This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 4aa36ec5 feat(scan): Support arbitrary glob patterns (#2608)
4aa36ec5 is described below
commit 4aa36ec5efd4d4cb6e8250a674715c119d1c3b64
Author: Nathan <[email protected]>
AuthorDate: Mon Oct 28 22:22:36 2024 -0400
feat(scan): Support arbitrary glob patterns (#2608)
---
.github/config/typos.toml | 5 +
src/commands/cmd_server.cc | 16 +-
src/commands/scan_base.h | 14 +-
src/common/status.h | 1 -
src/common/string_util.cc | 228 +++++++++++++++++-----------
src/common/string_util.h | 15 +-
src/config/config.cc | 2 +-
src/server/server.cc | 11 +-
src/storage/redis_db.cc | 28 ++--
src/storage/redis_db.h | 9 +-
tests/cppunit/string_util_test.cc | 156 +++++++++++++++++++
tests/gocase/unit/keyspace/keyspace_test.go | 51 ++++++-
tests/gocase/unit/scan/scan_test.go | 51 ++++++-
13 files changed, 448 insertions(+), 139 deletions(-)
diff --git a/.github/config/typos.toml b/.github/config/typos.toml
index daae57c8..03518540 100644
--- a/.github/config/typos.toml
+++ b/.github/config/typos.toml
@@ -20,6 +20,11 @@ extend-exclude = [
".git/",
"src/vendor/",
"tests/gocase/util/slot.go",
+
+ # Uses short strings for testing glob matching
+ "tests/cppunit/string_util_test.cc",
+ "tests/gocase/unit/keyspace/keyspace_test.go",
+ "tests/gocase/unit/scan/scan_test.go",
]
ignore-hidden = false
diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc
index 3d98a5f5..600f23c0 100644
--- a/src/commands/cmd_server.cc
+++ b/src/commands/cmd_server.cc
@@ -23,6 +23,8 @@
#include "commands/scan_base.h"
#include "common/io_util.h"
#include "common/rdb_stream.h"
+#include "common/string_util.h"
+#include "common/time_util.h"
#include "config/config.h"
#include "error_constants.h"
#include "server/redis_connection.h"
@@ -30,8 +32,6 @@
#include "server/server.h"
#include "stats/disk_stats.h"
#include "storage/rdb/rdb.h"
-#include "string_util.h"
-#include "time_util.h"
namespace redis {
@@ -114,15 +114,15 @@ class CommandNamespace : public Commander {
class CommandKeys : public Commander {
public:
Status Execute(engine::Context &ctx, Server *srv, Connection *conn,
std::string *output) override {
- const std::string &prefix = args_[1];
+ const std::string &glob_pattern = args_[1];
std::vector<std::string> keys;
redis::Database redis(srv->storage, conn->GetNamespace());
- if (prefix.empty() || prefix.find('*') != prefix.size() - 1) {
- return {Status::RedisExecErr, "only keys prefix match was supported"};
+ if (const Status s = util::ValidateGlob(glob_pattern); !s.IsOK()) {
+ return {Status::RedisParseErr, "Invalid glob pattern: " + s.Msg()};
}
-
- const rocksdb::Status s = redis.Keys(ctx, prefix.substr(0, prefix.size() -
1), &keys);
+ const auto [prefix, suffix_glob] = util::SplitGlob(glob_pattern);
+ const rocksdb::Status s = redis.Keys(ctx, prefix, suffix_glob, &keys);
if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}
@@ -846,7 +846,7 @@ class CommandScan : public CommandScanBase {
std::vector<std::string> keys;
std::string end_key;
- auto s = redis_db.Scan(ctx, key_name, limit_, prefix_, &keys, &end_key,
type_);
+ const auto s = redis_db.Scan(ctx, key_name, limit_, prefix_, suffix_glob_,
&keys, &end_key, type_);
if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}
diff --git a/src/commands/scan_base.h b/src/commands/scan_base.h
index b3773b94..5a0d4ca1 100644
--- a/src/commands/scan_base.h
+++ b/src/commands/scan_base.h
@@ -23,8 +23,10 @@
#include "commander.h"
#include "commands/command_parser.h"
#include "error_constants.h"
+#include "glob.h"
#include "parse_util.h"
#include "server/server.h"
+#include "string_util.h"
namespace redis {
@@ -44,14 +46,11 @@ class CommandScanBase : public Commander {
Status ParseAdditionalFlags(Parser &parser) {
while (parser.Good()) {
if (parser.EatEqICase("match")) {
- prefix_ = GET_OR_RET(parser.TakeStr());
- // The match pattern should contain exactly one '*' at the end; remove
the * to
- // get the prefix to match.
- if (!prefix_.empty() && prefix_.find('*') == prefix_.size() - 1) {
- prefix_.pop_back();
- } else {
- return {Status::RedisParseErr, "currently only key prefix matching
is supported"};
+ const std::string glob_pattern = GET_OR_RET(parser.TakeStr());
+ if (const Status s = util::ValidateGlob(glob_pattern); !s.IsOK()) {
+ return {Status::RedisParseErr, "Invalid glob pattern: " + s.Msg()};
}
+ std::tie(prefix_, suffix_glob_) = util::SplitGlob(glob_pattern);
} else if (parser.EatEqICase("count")) {
limit_ = GET_OR_RET(parser.TakeInt());
if (limit_ <= 0) {
@@ -100,6 +99,7 @@ class CommandScanBase : public Commander {
protected:
std::string cursor_;
std::string prefix_;
+ std::string suffix_glob_ = "*";
int limit_ = 20;
RedisType type_ = kRedisNone;
};
diff --git a/src/common/status.h b/src/common/status.h
index 2bd610ad..aef74033 100644
--- a/src/common/status.h
+++ b/src/common/status.h
@@ -26,7 +26,6 @@
#include <algorithm>
#include <memory>
#include <string>
-#include <tuple>
#include <type_traits>
#include <utility>
diff --git a/src/common/string_util.cc b/src/common/string_util.cc
index 476e3173..cce64402 100644
--- a/src/common/string_util.cc
+++ b/src/common/string_util.cc
@@ -101,118 +101,174 @@ bool HasPrefix(const std::string &str, const
std::string &prefix) {
return !strncasecmp(str.data(), prefix.data(), prefix.size());
}
-int StringMatch(const std::string &pattern, const std::string &in, int nocase)
{
- return StringMatchLen(pattern.c_str(), pattern.length(), in.c_str(),
in.length(), nocase);
+Status ValidateGlob(std::string_view glob) {
+ for (size_t idx = 0; idx < glob.size(); ++idx) {
+ switch (glob[idx]) {
+ case '*':
+ case '?':
+ break;
+ case ']':
+ return {Status::NotOK, "Unmatched unescaped ]"};
+ case '\\':
+ if (idx == glob.size() - 1) {
+ return {Status::NotOK, "Trailing unescaped backslash"};
+ }
+ // Skip the next character: this is a literal so nothing can go wrong
+ idx++;
+ break;
+ case '[':
+ idx++; // Skip the opening bracket
+ while (idx < glob.size() && glob[idx] != ']') {
+ if (glob[idx] == '\\') {
+ idx += 2;
+ continue;
+ } else if (idx + 1 < glob.size() && glob[idx + 1] == '-') {
+ if (idx + 2 >= glob.size()) {
+ return {Status::NotOK, "Unterminated character range"};
+ }
+ // Skip the - and the end of the range
+ idx += 2;
+ }
+ idx++;
+ }
+ if (idx == glob.size()) {
+ return {Status::NotOK, "Unterminated [ group"};
+ }
+ break;
+ default:
+ // This is a literal: nothing can go wrong
+ break;
+ }
+ }
+ return Status::OK();
}
-// Glob-style pattern matching.
-int StringMatchLen(const char *pattern, size_t pattern_len, const char
*string, size_t string_len, int nocase) {
- while (pattern_len && string_len) {
+constexpr bool StringMatchImpl(std::string_view pattern, std::string_view
string, bool ignore_case,
+ bool *skip_longer_matches, size_t
recursion_depth = 0) {
+ // If we want to ignore case, this is equivalent to converting both the
pattern and the string to lowercase
+ const auto canonicalize = [ignore_case](unsigned char c) -> unsigned char {
+ return ignore_case ? static_cast<unsigned char>(std::tolower(c)) : c;
+ };
+
+ if (recursion_depth > 1000) return false;
+
+ while (!pattern.empty() && !string.empty()) {
switch (pattern[0]) {
case '*':
- while (pattern[1] == '*') {
- pattern++;
- pattern_len--;
+ // Optimization: collapse multiple * into one
+ while (pattern.size() >= 2 && pattern[1] == '*') {
+ pattern.remove_prefix(1);
}
-
- if (pattern_len == 1) return 1; /* match */
-
- while (string_len) {
- if (StringMatchLen(pattern + 1, pattern_len - 1, string, string_len,
nocase)) return 1; /* match */
- string++;
- string_len--;
+ // Optimization: If the '*' is the last character in the pattern, it
can match anything
+ if (pattern.length() == 1) return true;
+ while (!string.empty()) {
+ if (StringMatchImpl(pattern.substr(1), string, ignore_case,
skip_longer_matches, recursion_depth + 1))
+ return true;
+ if (*skip_longer_matches) return false;
+ string.remove_prefix(1);
}
- return 0; /* no match */
+ // There was no match for the rest of the pattern starting
+ // from anywhere in the rest of the string. If there were
+ // any '*' earlier in the pattern, we can terminate the
+ // search early without trying to match them to longer
+ // substrings. This is because a longer match for the
+ // earlier part of the pattern would require the rest of the
+ // pattern to match starting later in the string, and we
+ // have just determined that there is no match for the rest
+ // of the pattern starting from anywhere in the current
+ // string.
+ *skip_longer_matches = true;
+ return false;
case '?':
- string++;
- string_len--;
+ if (string.empty()) return false;
+ string.remove_prefix(1);
break;
case '[': {
- pattern++;
- pattern_len--;
- int not_symbol = pattern[0] == '^';
- if (not_symbol) {
- pattern++;
- pattern_len--;
- }
+ pattern.remove_prefix(1);
+ const bool invert = pattern[0] == '^';
+ if (invert) pattern.remove_prefix(1);
- int match = 0;
+ bool match = false;
while (true) {
- if (pattern[0] == '\\' && pattern_len >= 2) {
- pattern++;
- pattern_len--;
- if (pattern[0] == string[0]) match = 1;
+ if (pattern.empty()) {
+ // unterminated [ group: reject invalid pattern
+ return false;
} else if (pattern[0] == ']') {
break;
- } else if (pattern_len == 0) {
- pattern--;
- pattern_len++;
- break;
- } else if (pattern[1] == '-' && pattern_len >= 3) {
- int start = pattern[0];
- int end = pattern[2];
- int c = string[0];
- if (start > end) {
- int t = start;
- start = end;
- end = t;
- }
- if (nocase) {
- start = tolower(start);
- end = tolower(end);
- c = tolower(c);
- }
- pattern += 2;
- pattern_len -= 2;
- if (c >= start && c <= end) match = 1;
- } else {
- if (!nocase) {
- if (pattern[0] == string[0]) match = 1;
- } else {
- if (tolower(static_cast<int>(pattern[0])) ==
tolower(static_cast<int>(string[0]))) match = 1;
- }
+ } else if (pattern.length() >= 2 && pattern[0] == '\\') {
+ pattern.remove_prefix(1);
+ if (pattern[0] == string[0]) match = true;
+ } else if (pattern.length() >= 3 && pattern[1] == '-') {
+ unsigned char start = canonicalize(pattern[0]);
+ unsigned char end = canonicalize(pattern[2]);
+ if (start > end) std::swap(start, end);
+ const int c = canonicalize(string[0]);
+ pattern.remove_prefix(2);
+
+ if (c >= start && c <= end) match = true;
+ } else if (canonicalize(pattern[0]) == canonicalize(string[0])) {
+ match = true;
}
- pattern++;
- pattern_len--;
+ pattern.remove_prefix(1);
}
-
- if (not_symbol) match = !match;
-
- if (!match) return 0; /* no match */
-
- string++;
- string_len--;
+ if (invert) match = !match;
+ if (!match) return false;
+ string.remove_prefix(1);
break;
}
case '\\':
- if (pattern_len >= 2) {
- pattern++;
- pattern_len--;
+ if (pattern.length() >= 2) {
+ pattern.remove_prefix(1);
}
- /* fall through */
+ [[fallthrough]];
default:
- if (!nocase) {
- if (pattern[0] != string[0]) return 0; /* no match */
+ // Just a normal character
+ if (!ignore_case) {
+ if (pattern[0] != string[0]) return false;
} else {
- if (tolower(static_cast<int>(pattern[0])) !=
tolower(static_cast<int>(string[0]))) return 0; /* no match */
+ if (std::tolower((int)pattern[0]) != std::tolower((int)string[0]))
return false;
}
- string++;
- string_len--;
+ string.remove_prefix(1);
break;
}
- pattern++;
- pattern_len--;
- if (string_len == 0) {
- while (*pattern == '*') {
- pattern++;
- pattern_len--;
- }
- break;
- }
+ pattern.remove_prefix(1);
}
- if (pattern_len == 0 && string_len == 0) return 1;
- return 0;
+ // Now that either the pattern is empty or the string is empty, this is a
match iff
+ // the pattern consists only of '*', and the string is empty.
+ return string.empty() && std::all_of(pattern.begin(), pattern.end(), [](char
c) { return c == '*'; });
+}
+
+// Given a glob [pattern] and a string [string], return true iff the string
matches the glob.
+// If [ignore_case] is true, the match is case-insensitive.
+bool StringMatch(std::string_view glob, std::string_view str, bool
ignore_case) {
+ bool skip_longer_matches = false;
+ return StringMatchImpl(glob, str, ignore_case, &skip_longer_matches);
+}
+
+// Split a glob pattern into a literal prefix and a suffix containing
wildcards.
+// For example, if the user calls [KEYS bla*bla], this function will return
{"bla", "*bla"}.
+// This allows the caller of this function to optimize this call by performing
a
+// prefix-scan on "bla" and then filtering the results using the GlobMatches
function.
+std::pair<std::string, std::string> SplitGlob(std::string_view glob) {
+ // Stores the prefix of the glob pattern, with backslashes removed
+ std::string prefix;
+ // Find the first un-escaped '*', '?' or '[' character in [glob]
+ for (size_t idx = 0; idx < glob.size(); ++idx) {
+ if (glob[idx] == '*' || glob[idx] == '?' || glob[idx] == '[') {
+ // Return a pair of views: the part of the glob before the wildcard, and
the part after
+ return {prefix, std::string(glob.substr(idx))};
+ } else if (glob[idx] == '\\') {
+ // Skip checking whether the next character is a special character
+ ++idx;
+ // Append the escaped special character to the prefix
+ if (idx < glob.size()) prefix.push_back(glob[idx]);
+ } else {
+ prefix.push_back(glob[idx]);
+ }
+ }
+ // No wildcard found, return the entire string (without the backslashes) as
the prefix
+ return {prefix, ""};
}
std::vector<std::string> RegexMatch(const std::string &str, const std::string
&regex) {
diff --git a/src/common/string_util.h b/src/common/string_util.h
index 2dcb1080..f86590ad 100644
--- a/src/common/string_util.h
+++ b/src/common/string_util.h
@@ -20,7 +20,13 @@
#pragma once
-#include "status.h"
+#include <cstdint>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "common/status.h"
namespace util {
@@ -32,8 +38,11 @@ std::string Trim(std::string in, std::string_view chars);
std::vector<std::string> Split(std::string_view in, std::string_view delim);
std::vector<std::string> Split2KV(const std::string &in, const std::string
&delim);
bool HasPrefix(const std::string &str, const std::string &prefix);
-int StringMatch(const std::string &pattern, const std::string &in, int nocase);
-int StringMatchLen(const char *p, size_t plen, const char *s, size_t slen, int
nocase);
+
+Status ValidateGlob(std::string_view glob);
+bool StringMatch(std::string_view glob, std::string_view str, bool ignore_case
= false);
+std::pair<std::string, std::string> SplitGlob(std::string_view glob);
+
std::vector<std::string> RegexMatch(const std::string &str, const std::string
&regex);
std::string StringToHex(std::string_view input);
std::vector<std::string> TokenizeRedisProtocol(const std::string &value);
diff --git a/src/config/config.cc b/src/config/config.cc
index 57b2c7b0..f14dc78c 100644
--- a/src/config/config.cc
+++ b/src/config/config.cc
@@ -904,7 +904,7 @@ Status Config::Load(const CLIOptions &opts) {
void Config::Get(const std::string &key, std::vector<std::string> *values)
const {
values->clear();
for (const auto &iter : fields_) {
- if (util::StringMatch(key, iter.first, 1)) {
+ if (util::StringMatch(key, iter.first, true)) {
if (iter.second->IsMultiConfig()) {
for (const auto &p : util::Split(iter.second->ToString(), "\n")) {
values->emplace_back(iter.first);
diff --git a/src/server/server.cc b/src/server/server.cc
index 1de5534d..e569d12e 100644
--- a/src/server/server.cc
+++ b/src/server/server.cc
@@ -38,7 +38,7 @@
#include <utility>
#include "commands/commander.h"
-#include "config.h"
+#include "common/string_util.h"
#include "config/config.h"
#include "fmt/format.h"
#include "redis_connection.h"
@@ -46,7 +46,6 @@
#include "storage/redis_db.h"
#include "storage/scripting.h"
#include "storage/storage.h"
-#include "string_util.h"
#include "thread_util.h"
#include "time_util.h"
#include "version.h"
@@ -160,7 +159,7 @@ Status Server::Start() {
if (!config_->cluster_enabled) {
engine::Context no_txn_ctx =
engine::Context::NoTransactionContext(storage);
GET_OR_RET(index_mgr.Load(no_txn_ctx, kDefaultNamespace));
- for (auto [_, ns] : namespace_.List()) {
+ for (const auto &[_, ns] : namespace_.List()) {
GET_OR_RET(index_mgr.Load(no_txn_ctx, ns));
}
}
@@ -391,7 +390,7 @@ int Server::PublishMessage(const std::string &channel,
const std::string &msg) {
std::vector<std::string> patterns;
std::vector<ConnContext> to_publish_patterns_conn_ctxs;
for (const auto &iter : pubsub_patterns_) {
- if (util::StringMatch(iter.first, channel, 0)) {
+ if (util::StringMatch(iter.first, channel, false)) {
for (const auto &conn_ctx : iter.second) {
to_publish_patterns_conn_ctxs.emplace_back(conn_ctx);
patterns.emplace_back(iter.first);
@@ -463,7 +462,7 @@ void Server::GetChannelsByPattern(const std::string
&pattern, std::vector<std::s
std::lock_guard<std::mutex> guard(pubsub_channels_mu_);
for (const auto &iter : pubsub_channels_) {
- if (pattern.empty() || util::StringMatch(pattern, iter.first, 0)) {
+ if (pattern.empty() || util::StringMatch(pattern, iter.first, false)) {
channels->emplace_back(iter.first);
}
}
@@ -549,7 +548,7 @@ void Server::GetSChannelsByPattern(const std::string
&pattern, std::vector<std::
for (const auto &shard_channels : pubsub_shard_channels_) {
for (const auto &iter : shard_channels) {
- if (pattern.empty() || util::StringMatch(pattern, iter.first, 0)) {
+ if (pattern.empty() || util::StringMatch(pattern, iter.first, false)) {
channels->emplace_back(iter.first);
}
}
diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc
index 7fe83477..5eabd8d8 100644
--- a/src/storage/redis_db.cc
+++ b/src/storage/redis_db.cc
@@ -21,16 +21,15 @@
#include "redis_db.h"
#include <ctime>
-#include <map>
#include <utility>
#include "cluster/redis_slot.h"
#include "common/scope_exit.h"
+#include "common/string_util.h"
#include "db_util.h"
#include "parse_util.h"
#include "rocksdb/iterator.h"
#include "rocksdb/status.h"
-#include "server/server.h"
#include "storage/iterator.h"
#include "storage/redis_metadata.h"
#include "storage/storage.h"
@@ -249,11 +248,11 @@ rocksdb::Status Database::GetExpireTime(engine::Context
&ctx, const Slice &user_
}
rocksdb::Status Database::GetKeyNumStats(engine::Context &ctx, const
std::string &prefix, KeyNumStats *stats) {
- return Keys(ctx, prefix, nullptr, stats);
+ return Keys(ctx, prefix, "*", nullptr, stats);
}
-rocksdb::Status Database::Keys(engine::Context &ctx, const std::string
&prefix, std::vector<std::string> *keys,
- KeyNumStats *stats) {
+rocksdb::Status Database::Keys(engine::Context &ctx, const std::string
&prefix, const std::string &suffix_glob,
+ std::vector<std::string> *keys, KeyNumStats
*stats) {
uint16_t slot_id = 0;
std::string ns_prefix;
if (namespace_ != kDefaultNamespace || keys != nullptr) {
@@ -277,6 +276,10 @@ rocksdb::Status Database::Keys(engine::Context &ctx, const
std::string &prefix,
if (!ns_prefix.empty() && !iter->key().starts_with(ns_prefix)) {
break;
}
+ auto [_, user_key] = ExtractNamespaceKey(iter->key(),
storage_->IsSlotIdEncoded());
+ if (!util::StringMatch(suffix_glob,
user_key.ToString().substr(prefix.size()))) {
+ continue;
+ }
Metadata metadata(kRedisNone, false);
auto s = metadata.Decode(iter->value());
if (!s.ok()) continue;
@@ -293,7 +296,6 @@ rocksdb::Status Database::Keys(engine::Context &ctx, const
std::string &prefix,
}
}
if (keys) {
- auto [_, user_key] = ExtractNamespaceKey(iter->key(),
storage_->IsSlotIdEncoded());
keys->emplace_back(user_key.ToString());
}
}
@@ -319,8 +321,8 @@ rocksdb::Status Database::Keys(engine::Context &ctx, const
std::string &prefix,
}
rocksdb::Status Database::Scan(engine::Context &ctx, const std::string
&cursor, uint64_t limit,
- const std::string &prefix,
std::vector<std::string> *keys, std::string *end_cursor,
- RedisType type) {
+ const std::string &prefix, const std::string
&suffix_glob,
+ std::vector<std::string> *keys, std::string
*end_cursor, RedisType type) {
end_cursor->clear();
uint64_t cnt = 0;
uint16_t slot_start = 0;
@@ -366,6 +368,10 @@ rocksdb::Status Database::Scan(engine::Context &ctx, const
std::string &cursor,
if (metadata.Expired()) continue;
std::tie(std::ignore, user_key) =
ExtractNamespaceKey<std::string>(iter->key(), storage_->IsSlotIdEncoded());
+
+ if (!util::StringMatch(suffix_glob, user_key.substr(prefix.size()))) {
+ continue;
+ }
keys->emplace_back(user_key);
cnt++;
}
@@ -395,7 +401,7 @@ rocksdb::Status Database::Scan(engine::Context &ctx, const
std::string &cursor,
if (iter->Valid()) {
std::tie(std::ignore, user_key) =
ExtractNamespaceKey<std::string>(iter->key(), storage_->IsSlotIdEncoded());
auto res = std::mismatch(prefix.begin(), prefix.end(),
user_key.begin());
- if (res.first == prefix.end()) {
+ if (res.first == prefix.end() && util::StringMatch(suffix_glob,
user_key.substr(prefix.size()))) {
keys->emplace_back(user_key);
}
@@ -420,13 +426,13 @@ rocksdb::Status Database::RandomKey(engine::Context &ctx,
const std::string &cur
std::string end_cursor;
std::vector<std::string> keys;
- auto s = Scan(ctx, cursor, RANDOM_KEY_SCAN_LIMIT, "", &keys, &end_cursor);
+ auto s = Scan(ctx, cursor, RANDOM_KEY_SCAN_LIMIT, "", "*", &keys,
&end_cursor);
if (!s.ok()) {
return s;
}
if (keys.empty() && !cursor.empty()) {
// if reach the end, restart from beginning
- s = Scan(ctx, "", RANDOM_KEY_SCAN_LIMIT, "", &keys, &end_cursor);
+ s = Scan(ctx, "", RANDOM_KEY_SCAN_LIMIT, "", "*", &keys, &end_cursor);
if (!s.ok()) {
return s;
}
diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h
index 7111fed1..41ed3dae 100644
--- a/src/storage/redis_db.h
+++ b/src/storage/redis_db.h
@@ -20,7 +20,6 @@
#pragma once
-#include <map>
#include <optional>
#include <string>
#include <utility>
@@ -29,7 +28,6 @@
#include "cluster/cluster_defs.h"
#include "redis_metadata.h"
-#include "server/redis_reply.h"
#include "storage.h"
namespace redis {
@@ -119,11 +117,12 @@ class Database {
[[nodiscard]] rocksdb::Status FlushDB(engine::Context &ctx);
[[nodiscard]] rocksdb::Status FlushAll(engine::Context &ctx);
[[nodiscard]] rocksdb::Status GetKeyNumStats(engine::Context &ctx, const
std::string &prefix, KeyNumStats *stats);
- [[nodiscard]] rocksdb::Status Keys(engine::Context &ctx, const std::string
&prefix,
+ [[nodiscard]] rocksdb::Status Keys(engine::Context &ctx, const std::string
&prefix, const std::string &suffix_glob,
std::vector<std::string> *keys = nullptr,
KeyNumStats *stats = nullptr);
[[nodiscard]] rocksdb::Status Scan(engine::Context &ctx, const std::string
&cursor, uint64_t limit,
- const std::string &prefix,
std::vector<std::string> *keys,
- std::string *end_cursor = nullptr,
RedisType type = kRedisNone);
+ const std::string &prefix, const
std::string &suffix_glob,
+ std::vector<std::string> *keys,
std::string *end_cursor = nullptr,
+ RedisType type = kRedisNone);
[[nodiscard]] rocksdb::Status RandomKey(engine::Context &ctx, const
std::string &cursor, std::string *key);
std::string AppendNamespacePrefix(const Slice &user_key);
[[nodiscard]] rocksdb::Status ClearKeysOfSlotRange(engine::Context &ctx,
const rocksdb::Slice &ns,
diff --git a/tests/cppunit/string_util_test.cc
b/tests/cppunit/string_util_test.cc
index f95ccbff..1d24cf59 100644
--- a/tests/cppunit/string_util_test.cc
+++ b/tests/cppunit/string_util_test.cc
@@ -22,6 +22,7 @@
#include <gtest/gtest.h>
+#include <initializer_list>
#include <map>
#include <string>
#include <unordered_map>
@@ -84,6 +85,161 @@ TEST(StringUtil, HasPrefix) {
ASSERT_FALSE(util::HasPrefix("has", "has_prefix"));
}
+TEST(StringUtil, ValidateGlob) {
+ const auto expect_ok = [](std::string_view glob) {
+ const auto result = util::ValidateGlob(glob);
+ EXPECT_TRUE(result.IsOK()) << glob << ": " << result.Msg();
+ };
+
+ const auto expect_error = [](std::string_view glob, std::string_view
expected_error) {
+ const auto result = util::ValidateGlob(glob);
+ EXPECT_FALSE(result.IsOK());
+ EXPECT_EQ(result.Msg(), expected_error) << glob;
+ };
+
+ expect_ok("a");
+ expect_ok("\\*");
+ expect_ok("\\?");
+ expect_ok("\\[");
+ expect_ok("\\]");
+ expect_ok("a*");
+ expect_ok("a?");
+ expect_ok("[ab]");
+ expect_ok("[^ab]");
+ expect_ok("[a-c]");
+ // Surprisingly valid: this accepts the characters {a, b, c, e, f, g, -}
+ expect_ok("[a-c-e-g]");
+ expect_ok("[^a-c]");
+ expect_ok("[-]");
+ expect_ok("[\\]]");
+ expect_ok("[\\\\]");
+ expect_ok("[\\?]");
+ expect_ok("[\\*]");
+ expect_ok("[\\[]");
+
+ expect_error("[", "Unterminated [ group");
+ expect_error("]", "Unmatched unescaped ]");
+ expect_error("[a", "Unterminated [ group");
+ expect_error("\\", "Trailing unescaped backslash");
+
+ // Weird case: we open a character class, with the range 'a' to ']', but
then never close it
+ expect_error("[a-]", "Unterminated [ group");
+ expect_ok("[a-]]");
+}
+
+TEST(StringUtil, StringMatch) {
+ /* Some basic tests */
+ EXPECT_TRUE(util::StringMatch("a", "a"));
+ EXPECT_FALSE(util::StringMatch("a", "b"));
+ EXPECT_FALSE(util::StringMatch("a", "aa"));
+ EXPECT_FALSE(util::StringMatch("a", ""));
+ EXPECT_TRUE(util::StringMatch("", ""));
+ EXPECT_FALSE(util::StringMatch("", "a"));
+ EXPECT_TRUE(util::StringMatch("*", ""));
+ EXPECT_TRUE(util::StringMatch("*", "a"));
+
+ /* Simple character class tests */
+ EXPECT_TRUE(util::StringMatch("[a]", "a"));
+ EXPECT_FALSE(util::StringMatch("[a]", "b"));
+ EXPECT_FALSE(util::StringMatch("[^a]", "a"));
+ EXPECT_TRUE(util::StringMatch("[^a]", "b"));
+ EXPECT_TRUE(util::StringMatch("[ab]", "a"));
+ EXPECT_TRUE(util::StringMatch("[ab]", "b"));
+ EXPECT_FALSE(util::StringMatch("[ab]", "c"));
+ EXPECT_TRUE(util::StringMatch("[^ab]", "c"));
+ EXPECT_TRUE(util::StringMatch("[a-c]", "b"));
+ EXPECT_FALSE(util::StringMatch("[a-c]", "d"));
+
+ /* Corner cases in character class parsing */
+ EXPECT_TRUE(util::StringMatch("[a-c-e-g]", "-"));
+ EXPECT_FALSE(util::StringMatch("[a-c-e-g]", "d"));
+ EXPECT_TRUE(util::StringMatch("[a-c-e-g]", "f"));
+
+ /* Escaping */
+ EXPECT_TRUE(util::StringMatch("\\?", "?"));
+ EXPECT_FALSE(util::StringMatch("\\?", "a"));
+ EXPECT_TRUE(util::StringMatch("\\*", "*"));
+ EXPECT_FALSE(util::StringMatch("\\*", "a"));
+ EXPECT_TRUE(util::StringMatch("\\[", "["));
+ EXPECT_TRUE(util::StringMatch("\\]", "]"));
+ EXPECT_TRUE(util::StringMatch("\\\\", "\\"));
+ EXPECT_TRUE(util::StringMatch("[\\.]", "."));
+ EXPECT_TRUE(util::StringMatch("[\\-]", "-"));
+ EXPECT_TRUE(util::StringMatch("[\\[]", "["));
+ EXPECT_TRUE(util::StringMatch("[\\]]", "]"));
+ EXPECT_TRUE(util::StringMatch("[\\\\]", "\\"));
+ EXPECT_TRUE(util::StringMatch("[\\?]", "?"));
+ EXPECT_TRUE(util::StringMatch("[\\*]", "*"));
+
+ /* Simple wild cards */
+ EXPECT_TRUE(util::StringMatch("?", "a"));
+ EXPECT_FALSE(util::StringMatch("?", "aa"));
+ EXPECT_FALSE(util::StringMatch("??", "a"));
+ EXPECT_TRUE(util::StringMatch("?x?", "axb"));
+ EXPECT_FALSE(util::StringMatch("?x?", "abx"));
+ EXPECT_FALSE(util::StringMatch("?x?", "xab"));
+
+ /* Asterisk wild cards (backtracking) */
+ EXPECT_FALSE(util::StringMatch("*??", "a"));
+ EXPECT_TRUE(util::StringMatch("*??", "ab"));
+ EXPECT_TRUE(util::StringMatch("*??", "abc"));
+ EXPECT_TRUE(util::StringMatch("*??", "abcd"));
+ EXPECT_FALSE(util::StringMatch("??*", "a"));
+ EXPECT_TRUE(util::StringMatch("??*", "ab"));
+ EXPECT_TRUE(util::StringMatch("??*", "abc"));
+ EXPECT_TRUE(util::StringMatch("??*", "abcd"));
+ EXPECT_FALSE(util::StringMatch("?*?", "a"));
+ EXPECT_TRUE(util::StringMatch("?*?", "ab"));
+ EXPECT_TRUE(util::StringMatch("?*?", "abc"));
+ EXPECT_TRUE(util::StringMatch("?*?", "abcd"));
+ EXPECT_TRUE(util::StringMatch("*b", "b"));
+ EXPECT_TRUE(util::StringMatch("*b", "ab"));
+ EXPECT_FALSE(util::StringMatch("*b", "ba"));
+ EXPECT_TRUE(util::StringMatch("*b", "bb"));
+ EXPECT_TRUE(util::StringMatch("*b", "abb"));
+ EXPECT_TRUE(util::StringMatch("*b", "bab"));
+ EXPECT_TRUE(util::StringMatch("*bc", "abbc"));
+ EXPECT_TRUE(util::StringMatch("*bc", "bc"));
+ EXPECT_TRUE(util::StringMatch("*bc", "bbc"));
+ EXPECT_TRUE(util::StringMatch("*bc", "bcbc"));
+
+ /* Multiple asterisks (complex backtracking) */
+ EXPECT_TRUE(util::StringMatch("*ac*", "abacadaeafag"));
+ EXPECT_TRUE(util::StringMatch("*ac*ae*ag*", "abacadaeafag"));
+ EXPECT_TRUE(util::StringMatch("*a*b*[bc]*[ef]*g*", "abacadaeafag"));
+ EXPECT_FALSE(util::StringMatch("*a*b*[ef]*[cd]*g*", "abacadaeafag"));
+ EXPECT_TRUE(util::StringMatch("*abcd*", "abcabcabcabcdefg"));
+ EXPECT_TRUE(util::StringMatch("*ab*cd*", "abcabcabcabcdefg"));
+ EXPECT_TRUE(util::StringMatch("*abcd*abcdef*", "abcabcdabcdeabcdefg"));
+ EXPECT_FALSE(util::StringMatch("*abcd*", "abcabcabcabcefg"));
+ EXPECT_FALSE(util::StringMatch("*ab*cd*", "abcabcabcabcefg"));
+
+ /* Robustness to exponential blow-ups with lots of non-collapsible asterisks
*/
+ EXPECT_TRUE(
+ util::StringMatch("?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*a",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"));
+ EXPECT_FALSE(
+ util::StringMatch("?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*b",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"));
+}
+
+TEST(StringUtil, SplitGlob) {
+ using namespace std::string_literals;
+
+ // Basic functionality: no escaped characters
+ EXPECT_EQ(util::SplitGlob(""), std::make_pair(""s, ""s));
+ EXPECT_EQ(util::SplitGlob("string"), std::make_pair("string"s, ""s));
+ EXPECT_EQ(util::SplitGlob("string*"), std::make_pair("string"s, "*"s));
+ EXPECT_EQ(util::SplitGlob("*string"), std::make_pair(""s, "*string"s));
+ EXPECT_EQ(util::SplitGlob("str*ing"), std::make_pair("str"s, "*ing"s));
+ EXPECT_EQ(util::SplitGlob("string?"), std::make_pair("string"s, "?"s));
+ EXPECT_EQ(util::SplitGlob("?string"), std::make_pair(""s, "?string"s));
+ EXPECT_EQ(util::SplitGlob("ab[cd]ef"), std::make_pair("ab"s, "[cd]ef"s));
+
+ // Escaped characters; also tests that prefix is trimmed of backslashes
+ EXPECT_EQ(util::SplitGlob("str\\*ing*"), std::make_pair("str*ing"s, "*"s));
+ EXPECT_EQ(util::SplitGlob("str\\?ing?"), std::make_pair("str?ing"s, "?"s));
+ EXPECT_EQ(util::SplitGlob("str\\[ing[a]"), std::make_pair("str[ing"s,
"[a]"s));
+}
+
TEST(StringUtil, EscapeString) {
std::unordered_map<std::string, std::string> origin_to_escaped = {
{"abc", "abc"},
diff --git a/tests/gocase/unit/keyspace/keyspace_test.go
b/tests/gocase/unit/keyspace/keyspace_test.go
index 6fbb84a7..37b86afd 100644
--- a/tests/gocase/unit/keyspace/keyspace_test.go
+++ b/tests/gocase/unit/keyspace/keyspace_test.go
@@ -27,6 +27,7 @@ import (
"github.com/apache/kvrocks/tests/gocase/util"
"github.com/stretchr/testify/require"
+ "golang.org/x/exp/slices"
)
func TestKeyspace(t *testing.T) {
@@ -65,10 +66,6 @@ func TestKeyspace(t *testing.T) {
require.Equal(t, []string{"foo_a", "foo_b", "foo_c"}, keys)
})
- t.Run("KEYS with invalid pattern", func(t *testing.T) {
- require.Error(t, rdb.Keys(ctx, "*ab*").Err())
- })
-
t.Run("KEYS to get all keys", func(t *testing.T) {
keys := rdb.Keys(ctx, "*").Val()
sort.Slice(keys, func(i, j int) bool {
@@ -77,12 +74,58 @@ func TestKeyspace(t *testing.T) {
require.Equal(t, []string{"foo_a", "foo_b", "foo_c", "key_x",
"key_y", "key_z"}, keys)
})
+ t.Run("KEYS with invalid patterns", func(t *testing.T) {
+ require.Error(t, rdb.Keys(ctx, "[").Err())
+ require.Error(t, rdb.Keys(ctx, "\\").Err())
+ require.Error(t, rdb.Keys(ctx, "[a-]").Err())
+ require.Error(t, rdb.Keys(ctx, "[a").Err())
+ })
+
t.Run("DBSize", func(t *testing.T) {
require.NoError(t, rdb.Do(ctx, "dbsize", "scan").Err())
time.Sleep(100 * time.Millisecond)
require.EqualValues(t, 6, rdb.Do(ctx, "dbsize").Val())
})
+ t.Run("KEYS with non-trivial patterns", func(t *testing.T) {
+ require.NoError(t, rdb.FlushDB(ctx).Err())
+ for _, key := range []string{"aa", "aab", "aabb", "ab", "abb"} {
+ require.NoError(t, rdb.Set(ctx, key, "hello", 0).Err())
+ }
+
+ keys := rdb.Keys(ctx, "a*").Val()
+ slices.Sort(keys)
+ require.Equal(t, []string{"aa", "aab", "aabb", "ab", "abb"},
keys)
+
+ keys = rdb.Keys(ctx, "aa").Val()
+ slices.Sort(keys)
+ require.Equal(t, []string{"aa"}, keys)
+
+ keys = rdb.Keys(ctx, "aa*").Val()
+ slices.Sort(keys)
+ require.Equal(t, []string{"aa", "aab", "aabb"}, keys)
+
+ keys = rdb.Keys(ctx, "a?").Val()
+ slices.Sort(keys)
+ require.Equal(t, []string{"aa", "ab"}, keys)
+
+ keys = rdb.Keys(ctx, "a*?").Val()
+ slices.Sort(keys)
+ require.Equal(t, []string{"aa", "aab", "aabb", "ab", "abb"},
keys)
+
+ keys = rdb.Keys(ctx, "ab*").Val()
+ slices.Sort(keys)
+ require.Equal(t, []string{"ab", "abb"}, keys)
+
+ keys = rdb.Keys(ctx, "*ab").Val()
+ slices.Sort(keys)
+ require.Equal(t, []string{"aab", "ab"}, keys)
+
+ keys = rdb.Keys(ctx, "*ab*").Val()
+ slices.Sort(keys)
+ require.Equal(t, []string{"aab", "aabb", "ab", "abb"}, keys)
+ })
+
t.Run("DEL all keys", func(t *testing.T) {
vals := rdb.Keys(ctx, "*").Val()
require.EqualValues(t, len(vals), rdb.Del(ctx, vals...).Val())
diff --git a/tests/gocase/unit/scan/scan_test.go
b/tests/gocase/unit/scan/scan_test.go
index cade5dcc..5d2f9fb9 100644
--- a/tests/gocase/unit/scan/scan_test.go
+++ b/tests/gocase/unit/scan/scan_test.go
@@ -77,7 +77,6 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx
context.Context) {
require.NoError(t, rdb.FlushDB(ctx).Err())
util.Populate(t, rdb, "", 1000, 10)
keys := scanAll(t, rdb)
- keys = slices.Compact(keys)
require.Len(t, keys, 1000)
})
@@ -85,7 +84,6 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx
context.Context) {
require.NoError(t, rdb.FlushDB(ctx).Err())
util.Populate(t, rdb, "", 1000, 10)
keys := scanAll(t, rdb, "count", 5)
- keys = slices.Compact(keys)
require.Len(t, keys, 1000)
})
@@ -93,15 +91,46 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx
context.Context) {
require.NoError(t, rdb.FlushDB(ctx).Err())
util.Populate(t, rdb, "key:", 1000, 10)
keys := scanAll(t, rdb, "match", "key:*")
- keys = slices.Compact(keys)
require.Len(t, keys, 1000)
})
- t.Run("SCAN MATCH invalid pattern", func(t *testing.T) {
+ t.Run("SCAN MATCH non-trivial pattern", func(t *testing.T) {
require.NoError(t, rdb.FlushDB(ctx).Err())
- util.Populate(t, rdb, "*ab", 1000, 10)
- // SCAN MATCH with invalid pattern should return an error
- require.Error(t, rdb.Do(context.Background(), "SCAN", "match",
"*ab*").Err())
+
+ for _, key := range []string{"aa", "aab", "aabb", "ab", "abb",
"ba"} {
+ require.NoError(t, rdb.Set(ctx, key, "hello", 0).Err())
+ }
+
+ keys := scanAll(t, rdb, "match", "a*")
+ require.Equal(t, []string{"aa", "aab", "aabb", "ab", "abb"},
keys)
+
+ keys = scanAll(t, rdb, "match", "aa")
+ require.Equal(t, []string{"aa"}, keys)
+
+ keys = scanAll(t, rdb, "match", "aa*")
+ require.Equal(t, []string{"aa", "aab", "aabb"}, keys)
+
+ keys = scanAll(t, rdb, "match", "a?")
+ require.Equal(t, []string{"aa", "ab"}, keys)
+
+ keys = scanAll(t, rdb, "match", "a*?")
+ require.Equal(t, []string{"aa", "aab", "aabb", "ab", "abb"},
keys)
+
+ keys = scanAll(t, rdb, "match", "ab*")
+ require.Equal(t, []string{"ab", "abb"}, keys)
+
+ keys = scanAll(t, rdb, "match", "*ab")
+ require.Equal(t, []string{"aab", "ab"}, keys)
+
+ keys = scanAll(t, rdb, "match", "*ab*")
+ require.Equal(t, []string{"aab", "aabb", "ab", "abb"}, keys)
+
+ // Special case: using [b]* instead of b* forces a full
scan of the keyspace,
+ // matching every result with the pattern. We ask for exactly
one key, but the
+ // first 5 keys don't match the pattern. This tests that SCAN
returns a valid
+ // cursor even when the first [limit] keys don't satisfy the
pattern.
+ keys = scanAll(t, rdb, "match", "[b]*", "count", "1")
+ require.Equal(t, []string{"ba"}, keys)
})
t.Run("SCAN guarantees check under write load", func(t *testing.T) {
@@ -226,6 +255,7 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx
context.Context) {
require.NoError(t, rdb.SAdd(ctx, "set",
elements...).Err())
keys, _, err := rdb.SScan(ctx, "set", 0, "",
10000).Result()
require.NoError(t, err)
+ slices.Sort(keys)
keys = slices.Compact(keys)
require.Len(t, keys, 100)
})
@@ -307,6 +337,11 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx
context.Context) {
require.NoError(t, rdb.Do(ctx, "SCAN", "0", "match", "a*",
"count", "1").Err())
util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "count", "1",
"match", "a*", "hello").Err(), ".*syntax error.*")
util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "count", "1",
"match", "a*", "hello", "hi").Err(), ".*syntax error.*")
+
+ util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "match",
"[").Err(), ".*Invalid glob pattern.*")
+ util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "match",
"\\").Err(), ".*Invalid glob pattern.*")
+ util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "match",
"[a").Err(), ".*Invalid glob pattern.*")
+ util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "match",
"[a-]").Err(), ".*Invalid glob pattern.*")
})
t.Run("SCAN with type args ", func(t *testing.T) {
@@ -406,6 +441,8 @@ func scanAll(t testing.TB, rdb *redis.Client, args
...interface{}) (keys []strin
keys = append(keys, keyList...)
if c == "0" {
+ slices.Sort(keys)
+ keys = slices.Compact(keys)
return
}
}