This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 4aa36ec5 feat(scan): Support arbitrary glob patterns (#2608)
4aa36ec5 is described below

commit 4aa36ec5efd4d4cb6e8250a674715c119d1c3b64
Author: Nathan <[email protected]>
AuthorDate: Mon Oct 28 22:22:36 2024 -0400

    feat(scan): Support arbitrary glob patterns (#2608)
---
 .github/config/typos.toml                   |   5 +
 src/commands/cmd_server.cc                  |  16 +-
 src/commands/scan_base.h                    |  14 +-
 src/common/status.h                         |   1 -
 src/common/string_util.cc                   | 228 +++++++++++++++++-----------
 src/common/string_util.h                    |  15 +-
 src/config/config.cc                        |   2 +-
 src/server/server.cc                        |  11 +-
 src/storage/redis_db.cc                     |  28 ++--
 src/storage/redis_db.h                      |   9 +-
 tests/cppunit/string_util_test.cc           | 156 +++++++++++++++++++
 tests/gocase/unit/keyspace/keyspace_test.go |  51 ++++++-
 tests/gocase/unit/scan/scan_test.go         |  51 ++++++-
 13 files changed, 448 insertions(+), 139 deletions(-)

diff --git a/.github/config/typos.toml b/.github/config/typos.toml
index daae57c8..03518540 100644
--- a/.github/config/typos.toml
+++ b/.github/config/typos.toml
@@ -20,6 +20,11 @@ extend-exclude = [
     ".git/",
     "src/vendor/",
     "tests/gocase/util/slot.go",
+    
+    # Uses short strings for testing glob matching
+    "tests/cppunit/string_util_test.cc",
+    "tests/gocase/unit/keyspace/keyspace_test.go",
+    "tests/gocase/unit/scan/scan_test.go",
 ]
 ignore-hidden = false
 
diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc
index 3d98a5f5..600f23c0 100644
--- a/src/commands/cmd_server.cc
+++ b/src/commands/cmd_server.cc
@@ -23,6 +23,8 @@
 #include "commands/scan_base.h"
 #include "common/io_util.h"
 #include "common/rdb_stream.h"
+#include "common/string_util.h"
+#include "common/time_util.h"
 #include "config/config.h"
 #include "error_constants.h"
 #include "server/redis_connection.h"
@@ -30,8 +32,6 @@
 #include "server/server.h"
 #include "stats/disk_stats.h"
 #include "storage/rdb/rdb.h"
-#include "string_util.h"
-#include "time_util.h"
 
 namespace redis {
 
@@ -114,15 +114,15 @@ class CommandNamespace : public Commander {
 class CommandKeys : public Commander {
  public:
   Status Execute(engine::Context &ctx, Server *srv, Connection *conn, 
std::string *output) override {
-    const std::string &prefix = args_[1];
+    const std::string &glob_pattern = args_[1];
     std::vector<std::string> keys;
     redis::Database redis(srv->storage, conn->GetNamespace());
 
-    if (prefix.empty() || prefix.find('*') != prefix.size() - 1) {
-      return {Status::RedisExecErr, "only keys prefix match was supported"};
+    if (const Status s = util::ValidateGlob(glob_pattern); !s.IsOK()) {
+      return {Status::RedisParseErr, "Invalid glob pattern: " + s.Msg()};
     }
-
-    const rocksdb::Status s = redis.Keys(ctx, prefix.substr(0, prefix.size() - 
1), &keys);
+    const auto [prefix, suffix_glob] = util::SplitGlob(glob_pattern);
+    const rocksdb::Status s = redis.Keys(ctx, prefix, suffix_glob, &keys);
     if (!s.ok()) {
       return {Status::RedisExecErr, s.ToString()};
     }
@@ -846,7 +846,7 @@ class CommandScan : public CommandScanBase {
     std::vector<std::string> keys;
     std::string end_key;
 
-    auto s = redis_db.Scan(ctx, key_name, limit_, prefix_, &keys, &end_key, 
type_);
+    const auto s = redis_db.Scan(ctx, key_name, limit_, prefix_, suffix_glob_, 
&keys, &end_key, type_);
     if (!s.ok()) {
       return {Status::RedisExecErr, s.ToString()};
     }
diff --git a/src/commands/scan_base.h b/src/commands/scan_base.h
index b3773b94..5a0d4ca1 100644
--- a/src/commands/scan_base.h
+++ b/src/commands/scan_base.h
@@ -23,8 +23,10 @@
 #include "commander.h"
 #include "commands/command_parser.h"
 #include "error_constants.h"
+#include "glob.h"
 #include "parse_util.h"
 #include "server/server.h"
+#include "string_util.h"
 
 namespace redis {
 
@@ -44,14 +46,11 @@ class CommandScanBase : public Commander {
   Status ParseAdditionalFlags(Parser &parser) {
     while (parser.Good()) {
       if (parser.EatEqICase("match")) {
-        prefix_ = GET_OR_RET(parser.TakeStr());
-        // The match pattern should contain exactly one '*' at the end; remove 
the * to
-        // get the prefix to match.
-        if (!prefix_.empty() && prefix_.find('*') == prefix_.size() - 1) {
-          prefix_.pop_back();
-        } else {
-          return {Status::RedisParseErr, "currently only key prefix matching 
is supported"};
+        const std::string glob_pattern = GET_OR_RET(parser.TakeStr());
+        if (const Status s = util::ValidateGlob(glob_pattern); !s.IsOK()) {
+          return {Status::RedisParseErr, "Invalid glob pattern: " + s.Msg()};
         }
+        std::tie(prefix_, suffix_glob_) = util::SplitGlob(glob_pattern);
       } else if (parser.EatEqICase("count")) {
         limit_ = GET_OR_RET(parser.TakeInt());
         if (limit_ <= 0) {
@@ -100,6 +99,7 @@ class CommandScanBase : public Commander {
  protected:
   std::string cursor_;
   std::string prefix_;
+  std::string suffix_glob_ = "*";
   int limit_ = 20;
   RedisType type_ = kRedisNone;
 };
diff --git a/src/common/status.h b/src/common/status.h
index 2bd610ad..aef74033 100644
--- a/src/common/status.h
+++ b/src/common/status.h
@@ -26,7 +26,6 @@
 #include <algorithm>
 #include <memory>
 #include <string>
-#include <tuple>
 #include <type_traits>
 #include <utility>
 
diff --git a/src/common/string_util.cc b/src/common/string_util.cc
index 476e3173..cce64402 100644
--- a/src/common/string_util.cc
+++ b/src/common/string_util.cc
@@ -101,118 +101,186 @@ bool HasPrefix(const std::string &str, const std::string &prefix) {
   return !strncasecmp(str.data(), prefix.data(), prefix.size());
 }
 
-int StringMatch(const std::string &pattern, const std::string &in, int nocase) {
-  return StringMatchLen(pattern.c_str(), pattern.length(), in.c_str(), in.length(), nocase);
+// Check that [glob] is a well-formed glob pattern.
+// A pattern is rejected if it contains an unescaped ']' outside a '[...]' group, ends with an
+// unescaped backslash, opens a '[' group that is never closed, or contains a character range
+// with no end character (e.g. "[a-").
+// Returns Status::OK() for a valid pattern, otherwise a NotOK status describing the problem.
+Status ValidateGlob(std::string_view glob) {
+  for (size_t idx = 0; idx < glob.size(); ++idx) {
+    switch (glob[idx]) {
+      case '*':
+      case '?':
+        break;
+      case ']':
+        return {Status::NotOK, "Unmatched unescaped ]"};
+      case '\\':
+        if (idx == glob.size() - 1) {
+          return {Status::NotOK, "Trailing unescaped backslash"};
+        }
+        // Skip the next character: this is a literal so nothing can go wrong
+        idx++;
+        break;
+      case '[':
+        idx++;  // Skip the opening bracket
+        while (idx < glob.size() && glob[idx] != ']') {
+          if (glob[idx] == '\\') {
+            idx += 2;
+            continue;
+          } else if (idx + 1 < glob.size() && glob[idx + 1] == '-') {
+            if (idx + 2 >= glob.size()) {
+              return {Status::NotOK, "Unterminated character range"};
+            }
+            // Skip the - and the end of the range
+            idx += 2;
+          }
+          idx++;
+        }
+        // A backslash as the last character of the pattern advances idx by 2, past the end,
+        // so any idx at or beyond the end means the group was never closed.
+        if (idx >= glob.size()) {
+          return {Status::NotOK, "Unterminated [ group"};
+        }
+        break;
+      default:
+        // This is a literal: nothing can go wrong
+        break;
+    }
+  }
+  return Status::OK();
 }
 
-// Glob-style pattern matching.
-int StringMatchLen(const char *pattern, size_t pattern_len, const char *string, size_t string_len, int nocase) {
-  while (pattern_len && string_len) {
+// Recursive glob matcher behind StringMatch.
+// [skip_longer_matches] is shared by every recursive call of a single match. It is set once a
+// '*' has failed to match the rest of the pattern at every remaining position of the string,
+// which lets the enclosing '*' expansions give up early (see the comment in the '*' case).
+// [recursion_depth] counts nested '*' expansions; past 1000 the match is abandoned and
+// reported as a mismatch.
+constexpr bool StringMatchImpl(std::string_view pattern, std::string_view string, bool ignore_case,
+                               bool *skip_longer_matches, size_t recursion_depth = 0) {
+  // If we want to ignore case, this is equivalent to converting both the pattern and the string to lowercase
+  const auto canonicalize = [ignore_case](unsigned char c) -> unsigned char {
+    return ignore_case ? static_cast<unsigned char>(std::tolower(c)) : c;
+  };
+
+  if (recursion_depth > 1000) return false;
+
+  while (!pattern.empty() && !string.empty()) {
     switch (pattern[0]) {
       case '*':
-        while (pattern[1] == '*') {
-          pattern++;
-          pattern_len--;
+        // Optimization: collapse multiple * into one
+        while (pattern.size() >= 2 && pattern[1] == '*') {
+          pattern.remove_prefix(1);
         }
-
-        if (pattern_len == 1) return 1; /* match */
-
-        while (string_len) {
-          if (StringMatchLen(pattern + 1, pattern_len - 1, string, string_len, nocase)) return 1; /* match */
-          string++;
-          string_len--;
+        // Optimization: If the '*' is the last character in the pattern, it can match anything
+        if (pattern.length() == 1) return true;
+        while (!string.empty()) {
+          if (StringMatchImpl(pattern.substr(1), string, ignore_case, skip_longer_matches, recursion_depth + 1))
+            return true;
+          if (*skip_longer_matches) return false;
+          string.remove_prefix(1);
         }
-        return 0; /* no match */
+        // There was no match for the rest of the pattern starting
+        // from anywhere in the rest of the string. If there were
+        // any '*' earlier in the pattern, we can terminate the
+        // search early without trying to match them to longer
+        // substrings. This is because a longer match for the
+        // earlier part of the pattern would require the rest of the
+        // pattern to match starting later in the string, and we
+        // have just determined that there is no match for the rest
+        // of the pattern starting from anywhere in the current
+        // string.
+        *skip_longer_matches = true;
+        return false;
       case '?':
-        string++;
-        string_len--;
+        if (string.empty()) return false;
+        string.remove_prefix(1);
         break;
       case '[': {
-        pattern++;
-        pattern_len--;
-        int not_symbol = pattern[0] == '^';
-        if (not_symbol) {
-          pattern++;
-          pattern_len--;
-        }
+        pattern.remove_prefix(1);
+        // The group may be cut short (e.g. the pattern "["), so check for emptiness before peeking.
+        const bool invert = !pattern.empty() && pattern[0] == '^';
+        if (invert) pattern.remove_prefix(1);
 
-        int match = 0;
+        bool match = false;
         while (true) {
-          if (pattern[0] == '\\' && pattern_len >= 2) {
-            pattern++;
-            pattern_len--;
-            if (pattern[0] == string[0]) match = 1;
+          if (pattern.empty()) {
+            // unterminated [ group: reject invalid pattern
+            return false;
           } else if (pattern[0] == ']') {
             break;
-          } else if (pattern_len == 0) {
-            pattern--;
-            pattern_len++;
-            break;
-          } else if (pattern[1] == '-' && pattern_len >= 3) {
-            int start = pattern[0];
-            int end = pattern[2];
-            int c = string[0];
-            if (start > end) {
-              int t = start;
-              start = end;
-              end = t;
-            }
-            if (nocase) {
-              start = tolower(start);
-              end = tolower(end);
-              c = tolower(c);
-            }
-            pattern += 2;
-            pattern_len -= 2;
-            if (c >= start && c <= end) match = 1;
-          } else {
-            if (!nocase) {
-              if (pattern[0] == string[0]) match = 1;
-            } else {
-              if (tolower(static_cast<int>(pattern[0])) == tolower(static_cast<int>(string[0]))) match = 1;
-            }
+          } else if (pattern.length() >= 2 && pattern[0] == '\\') {
+            pattern.remove_prefix(1);
+            if (pattern[0] == string[0]) match = true;
+          } else if (pattern.length() >= 3 && pattern[1] == '-') {
+            unsigned char start = canonicalize(pattern[0]);
+            unsigned char end = canonicalize(pattern[2]);
+            if (start > end) std::swap(start, end);
+            const int c = canonicalize(string[0]);
+            pattern.remove_prefix(2);
+
+            if (c >= start && c <= end) match = true;
+          } else if (canonicalize(pattern[0]) == canonicalize(string[0])) {
+            match = true;
           }
-          pattern++;
-          pattern_len--;
+          pattern.remove_prefix(1);
         }
-
-        if (not_symbol) match = !match;
-
-        if (!match) return 0; /* no match */
-
-        string++;
-        string_len--;
+        if (invert) match = !match;
+        if (!match) return false;
+        string.remove_prefix(1);
         break;
       }
       case '\\':
-        if (pattern_len >= 2) {
-          pattern++;
-          pattern_len--;
+        if (pattern.length() >= 2) {
+          pattern.remove_prefix(1);
         }
-        /* fall through */
+        [[fallthrough]];
       default:
-        if (!nocase) {
-          if (pattern[0] != string[0]) return 0; /* no match */
-        } else {
-          if (tolower(static_cast<int>(pattern[0])) != tolower(static_cast<int>(string[0]))) return 0; /* no match */
-        }
-        string++;
-        string_len--;
+        // Just a normal character. Compare through canonicalize so that, where char is signed, bytes
+        // >= 0x80 reach std::tolower as unsigned char values instead of negative ints (undefined behavior).
+        if (canonicalize(pattern[0]) != canonicalize(string[0])) return false;
+        string.remove_prefix(1);
         break;
     }
-    pattern++;
-    pattern_len--;
-    if (string_len == 0) {
-      while (*pattern == '*') {
-        pattern++;
-        pattern_len--;
-      }
-      break;
-    }
+    pattern.remove_prefix(1);
   }
 
-  if (pattern_len == 0 && string_len == 0) return 1;
-  return 0;
+  // Now that either the pattern is empty or the string is empty, this is a match iff
+  // the pattern consists only of '*', and the string is empty.
+  return string.empty() && std::all_of(pattern.begin(), pattern.end(), [](char c) { return c == '*'; });
+}
+
+// Given a glob [glob] and a string [str], return true iff [str] matches the glob.
+// If [ignore_case] is true, the match is case-insensitive.
+// The glob is not validated here (see ValidateGlob); an unterminated '[' group never matches.
+bool StringMatch(std::string_view glob, std::string_view str, bool ignore_case) {
+  bool skip_longer_matches = false;
+  return StringMatchImpl(glob, str, ignore_case, &skip_longer_matches);
+}
+
+// Split a glob pattern into a literal prefix and a suffix containing wildcards.
+// For example, if the user calls [KEYS bla*bla], this function will return {"bla", "*bla"}.
+// This allows the caller of this function to optimize this call by performing a
+// prefix-scan on "bla" and then filtering the results using the StringMatch function.
+std::pair<std::string, std::string> SplitGlob(std::string_view glob) {
+  // Stores the prefix of the glob pattern, with backslashes removed
+  std::string prefix;
+  // Find the first un-escaped '*', '?' or '[' character in [glob]
+  for (size_t idx = 0; idx < glob.size(); ++idx) {
+    if (glob[idx] == '*' || glob[idx] == '?' || glob[idx] == '[') {
+      // Return the unescaped prefix and the rest of the glob, starting at the wildcard
+      return {prefix, std::string(glob.substr(idx))};
+    } else if (glob[idx] == '\\') {
+      // Skip checking whether the next character is a special character
+      ++idx;
+      // Append the escaped character to the prefix (a trailing lone backslash is dropped)
+      if (idx < glob.size()) prefix.push_back(glob[idx]);
+    } else {
+      prefix.push_back(glob[idx]);
+    }
+  }
+  // No wildcard found, return the entire string (without the backslashes) as the prefix
+  return {prefix, ""};
 }
 
 std::vector<std::string> RegexMatch(const std::string &str, const std::string &regex) {
diff --git a/src/common/string_util.h b/src/common/string_util.h
index 2dcb1080..f86590ad 100644
--- a/src/common/string_util.h
+++ b/src/common/string_util.h
@@ -20,7 +20,13 @@
 
 #pragma once
 
-#include "status.h"
+#include <cstdint>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "common/status.h"
 
 namespace util {
 
@@ -32,8 +38,11 @@ std::string Trim(std::string in, std::string_view chars);
 std::vector<std::string> Split(std::string_view in, std::string_view delim);
 std::vector<std::string> Split2KV(const std::string &in, const std::string 
&delim);
 bool HasPrefix(const std::string &str, const std::string &prefix);
-int StringMatch(const std::string &pattern, const std::string &in, int nocase);
-int StringMatchLen(const char *p, size_t plen, const char *s, size_t slen, int 
nocase);
+
+Status ValidateGlob(std::string_view glob);
+bool StringMatch(std::string_view glob, std::string_view str, bool ignore_case 
= false);
+std::pair<std::string, std::string> SplitGlob(std::string_view glob);
+
 std::vector<std::string> RegexMatch(const std::string &str, const std::string 
&regex);
 std::string StringToHex(std::string_view input);
 std::vector<std::string> TokenizeRedisProtocol(const std::string &value);
diff --git a/src/config/config.cc b/src/config/config.cc
index 57b2c7b0..f14dc78c 100644
--- a/src/config/config.cc
+++ b/src/config/config.cc
@@ -904,7 +904,7 @@ Status Config::Load(const CLIOptions &opts) {
 void Config::Get(const std::string &key, std::vector<std::string> *values) 
const {
   values->clear();
   for (const auto &iter : fields_) {
-    if (util::StringMatch(key, iter.first, 1)) {
+    if (util::StringMatch(key, iter.first, true)) {
       if (iter.second->IsMultiConfig()) {
         for (const auto &p : util::Split(iter.second->ToString(), "\n")) {
           values->emplace_back(iter.first);
diff --git a/src/server/server.cc b/src/server/server.cc
index 1de5534d..e569d12e 100644
--- a/src/server/server.cc
+++ b/src/server/server.cc
@@ -38,7 +38,7 @@
 #include <utility>
 
 #include "commands/commander.h"
-#include "config.h"
+#include "common/string_util.h"
 #include "config/config.h"
 #include "fmt/format.h"
 #include "redis_connection.h"
@@ -46,7 +46,6 @@
 #include "storage/redis_db.h"
 #include "storage/scripting.h"
 #include "storage/storage.h"
-#include "string_util.h"
 #include "thread_util.h"
 #include "time_util.h"
 #include "version.h"
@@ -160,7 +159,7 @@ Status Server::Start() {
   if (!config_->cluster_enabled) {
     engine::Context no_txn_ctx = 
engine::Context::NoTransactionContext(storage);
     GET_OR_RET(index_mgr.Load(no_txn_ctx, kDefaultNamespace));
-    for (auto [_, ns] : namespace_.List()) {
+    for (const auto &[_, ns] : namespace_.List()) {
       GET_OR_RET(index_mgr.Load(no_txn_ctx, ns));
     }
   }
@@ -391,7 +390,7 @@ int Server::PublishMessage(const std::string &channel, 
const std::string &msg) {
   std::vector<std::string> patterns;
   std::vector<ConnContext> to_publish_patterns_conn_ctxs;
   for (const auto &iter : pubsub_patterns_) {
-    if (util::StringMatch(iter.first, channel, 0)) {
+    if (util::StringMatch(iter.first, channel, false)) {
       for (const auto &conn_ctx : iter.second) {
         to_publish_patterns_conn_ctxs.emplace_back(conn_ctx);
         patterns.emplace_back(iter.first);
@@ -463,7 +462,7 @@ void Server::GetChannelsByPattern(const std::string 
&pattern, std::vector<std::s
   std::lock_guard<std::mutex> guard(pubsub_channels_mu_);
 
   for (const auto &iter : pubsub_channels_) {
-    if (pattern.empty() || util::StringMatch(pattern, iter.first, 0)) {
+    if (pattern.empty() || util::StringMatch(pattern, iter.first, false)) {
       channels->emplace_back(iter.first);
     }
   }
@@ -549,7 +548,7 @@ void Server::GetSChannelsByPattern(const std::string 
&pattern, std::vector<std::
 
   for (const auto &shard_channels : pubsub_shard_channels_) {
     for (const auto &iter : shard_channels) {
-      if (pattern.empty() || util::StringMatch(pattern, iter.first, 0)) {
+      if (pattern.empty() || util::StringMatch(pattern, iter.first, false)) {
         channels->emplace_back(iter.first);
       }
     }
diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc
index 7fe83477..5eabd8d8 100644
--- a/src/storage/redis_db.cc
+++ b/src/storage/redis_db.cc
@@ -21,16 +21,15 @@
 #include "redis_db.h"
 
 #include <ctime>
-#include <map>
 #include <utility>
 
 #include "cluster/redis_slot.h"
 #include "common/scope_exit.h"
+#include "common/string_util.h"
 #include "db_util.h"
 #include "parse_util.h"
 #include "rocksdb/iterator.h"
 #include "rocksdb/status.h"
-#include "server/server.h"
 #include "storage/iterator.h"
 #include "storage/redis_metadata.h"
 #include "storage/storage.h"
@@ -249,11 +248,11 @@ rocksdb::Status Database::GetExpireTime(engine::Context 
&ctx, const Slice &user_
 }
 
 rocksdb::Status Database::GetKeyNumStats(engine::Context &ctx, const 
std::string &prefix, KeyNumStats *stats) {
-  return Keys(ctx, prefix, nullptr, stats);
+  return Keys(ctx, prefix, "*", nullptr, stats);
 }
 
-rocksdb::Status Database::Keys(engine::Context &ctx, const std::string 
&prefix, std::vector<std::string> *keys,
-                               KeyNumStats *stats) {
+rocksdb::Status Database::Keys(engine::Context &ctx, const std::string 
&prefix, const std::string &suffix_glob,
+                               std::vector<std::string> *keys, KeyNumStats 
*stats) {
   uint16_t slot_id = 0;
   std::string ns_prefix;
   if (namespace_ != kDefaultNamespace || keys != nullptr) {
@@ -277,6 +276,10 @@ rocksdb::Status Database::Keys(engine::Context &ctx, const 
std::string &prefix,
       if (!ns_prefix.empty() && !iter->key().starts_with(ns_prefix)) {
         break;
       }
+      auto [_, user_key] = ExtractNamespaceKey(iter->key(), 
storage_->IsSlotIdEncoded());
+      if (!util::StringMatch(suffix_glob, 
user_key.ToString().substr(prefix.size()))) {
+        continue;
+      }
       Metadata metadata(kRedisNone, false);
       auto s = metadata.Decode(iter->value());
       if (!s.ok()) continue;
@@ -293,7 +296,6 @@ rocksdb::Status Database::Keys(engine::Context &ctx, const 
std::string &prefix,
         }
       }
       if (keys) {
-        auto [_, user_key] = ExtractNamespaceKey(iter->key(), 
storage_->IsSlotIdEncoded());
         keys->emplace_back(user_key.ToString());
       }
     }
@@ -319,8 +321,8 @@ rocksdb::Status Database::Keys(engine::Context &ctx, const 
std::string &prefix,
 }
 
 rocksdb::Status Database::Scan(engine::Context &ctx, const std::string 
&cursor, uint64_t limit,
-                               const std::string &prefix, 
std::vector<std::string> *keys, std::string *end_cursor,
-                               RedisType type) {
+                               const std::string &prefix, const std::string 
&suffix_glob,
+                               std::vector<std::string> *keys, std::string 
*end_cursor, RedisType type) {
   end_cursor->clear();
   uint64_t cnt = 0;
   uint16_t slot_start = 0;
@@ -366,6 +368,10 @@ rocksdb::Status Database::Scan(engine::Context &ctx, const 
std::string &cursor,
 
       if (metadata.Expired()) continue;
       std::tie(std::ignore, user_key) = 
ExtractNamespaceKey<std::string>(iter->key(), storage_->IsSlotIdEncoded());
+
+      if (!util::StringMatch(suffix_glob, user_key.substr(prefix.size()))) {
+        continue;
+      }
       keys->emplace_back(user_key);
       cnt++;
     }
@@ -395,7 +401,7 @@ rocksdb::Status Database::Scan(engine::Context &ctx, const 
std::string &cursor,
         if (iter->Valid()) {
           std::tie(std::ignore, user_key) = 
ExtractNamespaceKey<std::string>(iter->key(), storage_->IsSlotIdEncoded());
           auto res = std::mismatch(prefix.begin(), prefix.end(), 
user_key.begin());
-          if (res.first == prefix.end()) {
+          if (res.first == prefix.end() && util::StringMatch(suffix_glob, 
user_key.substr(prefix.size()))) {
             keys->emplace_back(user_key);
           }
 
@@ -420,13 +426,13 @@ rocksdb::Status Database::RandomKey(engine::Context &ctx, 
const std::string &cur
 
   std::string end_cursor;
   std::vector<std::string> keys;
-  auto s = Scan(ctx, cursor, RANDOM_KEY_SCAN_LIMIT, "", &keys, &end_cursor);
+  auto s = Scan(ctx, cursor, RANDOM_KEY_SCAN_LIMIT, "", "*", &keys, 
&end_cursor);
   if (!s.ok()) {
     return s;
   }
   if (keys.empty() && !cursor.empty()) {
     // if reach the end, restart from beginning
-    s = Scan(ctx, "", RANDOM_KEY_SCAN_LIMIT, "", &keys, &end_cursor);
+    s = Scan(ctx, "", RANDOM_KEY_SCAN_LIMIT, "", "*", &keys, &end_cursor);
     if (!s.ok()) {
       return s;
     }
diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h
index 7111fed1..41ed3dae 100644
--- a/src/storage/redis_db.h
+++ b/src/storage/redis_db.h
@@ -20,7 +20,6 @@
 
 #pragma once
 
-#include <map>
 #include <optional>
 #include <string>
 #include <utility>
@@ -29,7 +28,6 @@
 
 #include "cluster/cluster_defs.h"
 #include "redis_metadata.h"
-#include "server/redis_reply.h"
 #include "storage.h"
 
 namespace redis {
@@ -119,11 +117,12 @@ class Database {
   [[nodiscard]] rocksdb::Status FlushDB(engine::Context &ctx);
   [[nodiscard]] rocksdb::Status FlushAll(engine::Context &ctx);
   [[nodiscard]] rocksdb::Status GetKeyNumStats(engine::Context &ctx, const 
std::string &prefix, KeyNumStats *stats);
-  [[nodiscard]] rocksdb::Status Keys(engine::Context &ctx, const std::string 
&prefix,
+  [[nodiscard]] rocksdb::Status Keys(engine::Context &ctx, const std::string 
&prefix, const std::string &suffix_glob,
                                      std::vector<std::string> *keys = nullptr, 
KeyNumStats *stats = nullptr);
   [[nodiscard]] rocksdb::Status Scan(engine::Context &ctx, const std::string 
&cursor, uint64_t limit,
-                                     const std::string &prefix, 
std::vector<std::string> *keys,
-                                     std::string *end_cursor = nullptr, 
RedisType type = kRedisNone);
+                                     const std::string &prefix, const 
std::string &suffix_glob,
+                                     std::vector<std::string> *keys, 
std::string *end_cursor = nullptr,
+                                     RedisType type = kRedisNone);
   [[nodiscard]] rocksdb::Status RandomKey(engine::Context &ctx, const 
std::string &cursor, std::string *key);
   std::string AppendNamespacePrefix(const Slice &user_key);
   [[nodiscard]] rocksdb::Status ClearKeysOfSlotRange(engine::Context &ctx, 
const rocksdb::Slice &ns,
diff --git a/tests/cppunit/string_util_test.cc 
b/tests/cppunit/string_util_test.cc
index f95ccbff..1d24cf59 100644
--- a/tests/cppunit/string_util_test.cc
+++ b/tests/cppunit/string_util_test.cc
@@ -22,6 +22,7 @@
 
 #include <gtest/gtest.h>
 
+#include <initializer_list>
 #include <map>
 #include <string>
 #include <unordered_map>
@@ -84,6 +85,173 @@ TEST(StringUtil, HasPrefix) {
   ASSERT_FALSE(util::HasPrefix("has", "has_prefix"));
 }
 
+TEST(StringUtil, ValidateGlob) {
+  const auto expect_ok = [](std::string_view glob) {
+    const auto result = util::ValidateGlob(glob);
+    EXPECT_TRUE(result.IsOK()) << glob << ": " << result.Msg();
+  };
+
+  const auto expect_error = [](std::string_view glob, std::string_view expected_error) {
+    const auto result = util::ValidateGlob(glob);
+    EXPECT_FALSE(result.IsOK());
+    EXPECT_EQ(result.Msg(), expected_error) << glob;
+  };
+
+  expect_ok("a");
+  expect_ok("\\*");
+  expect_ok("\\?");
+  expect_ok("\\[");
+  expect_ok("\\]");
+  expect_ok("a*");
+  expect_ok("a?");
+  expect_ok("[ab]");
+  expect_ok("[^ab]");
+  expect_ok("[a-c]");
+  // Surprisingly valid: this accepts the characters {a, b, c, e, f, g, -}
+  expect_ok("[a-c-e-g]");
+  expect_ok("[^a-c]");
+  expect_ok("[-]");
+  expect_ok("[\\]]");
+  expect_ok("[\\\\]");
+  expect_ok("[\\?]");
+  expect_ok("[\\*]");
+  expect_ok("[\\[]");
+
+  expect_error("[", "Unterminated [ group");
+  expect_error("]", "Unmatched unescaped ]");
+  expect_error("[a", "Unterminated [ group");
+  expect_error("\\", "Trailing unescaped backslash");
+  // A backslash that ends the pattern inside an open group must not slip past validation
+  expect_error("[\\", "Unterminated [ group");
+  expect_error("[a\\", "Unterminated [ group");
+
+  // Weird case: we open a character class, with the range 'a' to ']', but then never close it
+  expect_error("[a-]", "Unterminated [ group");
+  expect_ok("[a-]]");
+}
+
+TEST(StringUtil, StringMatch) {
+  /* Some basic tests */
+  EXPECT_TRUE(util::StringMatch("a", "a"));
+  EXPECT_FALSE(util::StringMatch("a", "b"));
+  EXPECT_FALSE(util::StringMatch("a", "aa"));
+  EXPECT_FALSE(util::StringMatch("a", ""));
+  EXPECT_TRUE(util::StringMatch("", ""));
+  EXPECT_FALSE(util::StringMatch("", "a"));
+  EXPECT_TRUE(util::StringMatch("*", ""));
+  EXPECT_TRUE(util::StringMatch("*", "a"));
+
+  /* Simple character class tests */
+  EXPECT_TRUE(util::StringMatch("[a]", "a"));
+  EXPECT_FALSE(util::StringMatch("[a]", "b"));
+  EXPECT_FALSE(util::StringMatch("[^a]", "a"));
+  EXPECT_TRUE(util::StringMatch("[^a]", "b"));
+  EXPECT_TRUE(util::StringMatch("[ab]", "a"));
+  EXPECT_TRUE(util::StringMatch("[ab]", "b"));
+  EXPECT_FALSE(util::StringMatch("[ab]", "c"));
+  EXPECT_TRUE(util::StringMatch("[^ab]", "c"));
+  EXPECT_TRUE(util::StringMatch("[a-c]", "b"));
+  EXPECT_FALSE(util::StringMatch("[a-c]", "d"));
+
+  /* Corner cases in character class parsing */
+  EXPECT_TRUE(util::StringMatch("[a-c-e-g]", "-"));
+  EXPECT_FALSE(util::StringMatch("[a-c-e-g]", "d"));
+  EXPECT_TRUE(util::StringMatch("[a-c-e-g]", "f"));
+
+  /* Escaping */
+  EXPECT_TRUE(util::StringMatch("\\?", "?"));
+  EXPECT_FALSE(util::StringMatch("\\?", "a"));
+  EXPECT_TRUE(util::StringMatch("\\*", "*"));
+  EXPECT_FALSE(util::StringMatch("\\*", "a"));
+  EXPECT_TRUE(util::StringMatch("\\[", "["));
+  EXPECT_TRUE(util::StringMatch("\\]", "]"));
+  EXPECT_TRUE(util::StringMatch("\\\\", "\\"));
+  EXPECT_TRUE(util::StringMatch("[\\.]", "."));
+  EXPECT_TRUE(util::StringMatch("[\\-]", "-"));
+  EXPECT_TRUE(util::StringMatch("[\\[]", "["));
+  EXPECT_TRUE(util::StringMatch("[\\]]", "]"));
+  EXPECT_TRUE(util::StringMatch("[\\\\]", "\\"));
+  EXPECT_TRUE(util::StringMatch("[\\?]", "?"));
+  EXPECT_TRUE(util::StringMatch("[\\*]", "*"));
+
+  /* Simple wild cards */
+  EXPECT_TRUE(util::StringMatch("?", "a"));
+  EXPECT_FALSE(util::StringMatch("?", "aa"));
+  EXPECT_FALSE(util::StringMatch("??", "a"));
+  EXPECT_TRUE(util::StringMatch("?x?", "axb"));
+  EXPECT_FALSE(util::StringMatch("?x?", "abx"));
+  EXPECT_FALSE(util::StringMatch("?x?", "xab"));
+
+  /* Asterisk wild cards (backtracking) */
+  EXPECT_FALSE(util::StringMatch("*??", "a"));
+  EXPECT_TRUE(util::StringMatch("*??", "ab"));
+  EXPECT_TRUE(util::StringMatch("*??", "abc"));
+  EXPECT_TRUE(util::StringMatch("*??", "abcd"));
+  EXPECT_FALSE(util::StringMatch("??*", "a"));
+  EXPECT_TRUE(util::StringMatch("??*", "ab"));
+  EXPECT_TRUE(util::StringMatch("??*", "abc"));
+  EXPECT_TRUE(util::StringMatch("??*", "abcd"));
+  EXPECT_FALSE(util::StringMatch("?*?", "a"));
+  EXPECT_TRUE(util::StringMatch("?*?", "ab"));
+  EXPECT_TRUE(util::StringMatch("?*?", "abc"));
+  EXPECT_TRUE(util::StringMatch("?*?", "abcd"));
+  EXPECT_TRUE(util::StringMatch("*b", "b"));
+  EXPECT_TRUE(util::StringMatch("*b", "ab"));
+  EXPECT_FALSE(util::StringMatch("*b", "ba"));
+  EXPECT_TRUE(util::StringMatch("*b", "bb"));
+  EXPECT_TRUE(util::StringMatch("*b", "abb"));
+  EXPECT_TRUE(util::StringMatch("*b", "bab"));
+  EXPECT_TRUE(util::StringMatch("*bc", "abbc"));
+  EXPECT_TRUE(util::StringMatch("*bc", "bc"));
+  EXPECT_TRUE(util::StringMatch("*bc", "bbc"));
+  EXPECT_TRUE(util::StringMatch("*bc", "bcbc"));
+
+  /* Multiple asterisks (complex backtracking) */
+  EXPECT_TRUE(util::StringMatch("*ac*", "abacadaeafag"));
+  EXPECT_TRUE(util::StringMatch("*ac*ae*ag*", "abacadaeafag"));
+  EXPECT_TRUE(util::StringMatch("*a*b*[bc]*[ef]*g*", "abacadaeafag"));
+  EXPECT_FALSE(util::StringMatch("*a*b*[ef]*[cd]*g*", "abacadaeafag"));
+  EXPECT_TRUE(util::StringMatch("*abcd*", "abcabcabcabcdefg"));
+  EXPECT_TRUE(util::StringMatch("*ab*cd*", "abcabcabcabcdefg"));
+  EXPECT_TRUE(util::StringMatch("*abcd*abcdef*", "abcabcdabcdeabcdefg"));
+  EXPECT_FALSE(util::StringMatch("*abcd*", "abcabcabcabcefg"));
+  EXPECT_FALSE(util::StringMatch("*ab*cd*", "abcabcabcabcefg"));
+
+  /* Robustness to exponential blow-ups with lots of non-collapsible asterisks */
+  EXPECT_TRUE(
+      util::StringMatch("?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*a", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"));
+  EXPECT_FALSE(
+      util::StringMatch("?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*b", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"));
+
+  /* Malformed patterns never match, and never read past the end of the pattern */
+  EXPECT_FALSE(util::StringMatch("[", "a"));
+  EXPECT_FALSE(util::StringMatch("[^", "a"));
+  EXPECT_FALSE(util::StringMatch("[a", "a"));
+
+  /* Case-insensitive comparison of bytes outside the ASCII range */
+  EXPECT_TRUE(util::StringMatch("\xe9", "\xe9", true));
+  EXPECT_FALSE(util::StringMatch("\xe9", "\xe8", true));
+}
+
+TEST(StringUtil, SplitGlob) {
+  using namespace std::string_literals;
+
+  // Basic functionality: no escaped characters
+  EXPECT_EQ(util::SplitGlob(""), std::make_pair(""s, ""s));
+  EXPECT_EQ(util::SplitGlob("string"), std::make_pair("string"s, ""s));
+  EXPECT_EQ(util::SplitGlob("string*"), std::make_pair("string"s, "*"s));
+  EXPECT_EQ(util::SplitGlob("*string"), std::make_pair(""s, "*string"s));
+  EXPECT_EQ(util::SplitGlob("str*ing"), std::make_pair("str"s, "*ing"s));
+  EXPECT_EQ(util::SplitGlob("string?"), std::make_pair("string"s, "?"s));
+  EXPECT_EQ(util::SplitGlob("?string"), std::make_pair(""s, "?string"s));
+  EXPECT_EQ(util::SplitGlob("ab[cd]ef"), std::make_pair("ab"s, "[cd]ef"s));
+
+  // Escaped characters; also tests that prefix is trimmed of backslashes
+  EXPECT_EQ(util::SplitGlob("str\\*ing*"), std::make_pair("str*ing"s, "*"s));
+  EXPECT_EQ(util::SplitGlob("str\\?ing?"), std::make_pair("str?ing"s, "?"s));
+  EXPECT_EQ(util::SplitGlob("str\\[ing[a]"), std::make_pair("str[ing"s, "[a]"s));
+}
+
 TEST(StringUtil, EscapeString) {
   std::unordered_map<std::string, std::string> origin_to_escaped = {
       {"abc", "abc"},
diff --git a/tests/gocase/unit/keyspace/keyspace_test.go 
b/tests/gocase/unit/keyspace/keyspace_test.go
index 6fbb84a7..37b86afd 100644
--- a/tests/gocase/unit/keyspace/keyspace_test.go
+++ b/tests/gocase/unit/keyspace/keyspace_test.go
@@ -27,6 +27,7 @@ import (
 
        "github.com/apache/kvrocks/tests/gocase/util"
        "github.com/stretchr/testify/require"
+       "golang.org/x/exp/slices"
 )
 
 func TestKeyspace(t *testing.T) {
@@ -65,10 +66,6 @@ func TestKeyspace(t *testing.T) {
                require.Equal(t, []string{"foo_a", "foo_b", "foo_c"}, keys)
        })
 
-       t.Run("KEYS with invalid pattern", func(t *testing.T) {
-               require.Error(t, rdb.Keys(ctx, "*ab*").Err())
-       })
-
        t.Run("KEYS to get all keys", func(t *testing.T) {
                keys := rdb.Keys(ctx, "*").Val()
                sort.Slice(keys, func(i, j int) bool {
@@ -77,12 +74,58 @@ func TestKeyspace(t *testing.T) {
                require.Equal(t, []string{"foo_a", "foo_b", "foo_c", "key_x", 
"key_y", "key_z"}, keys)
        })
 
+       t.Run("KEYS with invalid patterns", func(t *testing.T) {
+               require.Error(t, rdb.Keys(ctx, "[").Err())
+               require.Error(t, rdb.Keys(ctx, "\\").Err())
+               require.Error(t, rdb.Keys(ctx, "[a-]").Err())
+               require.Error(t, rdb.Keys(ctx, "[a").Err())
+       })
+
        t.Run("DBSize", func(t *testing.T) {
                require.NoError(t, rdb.Do(ctx, "dbsize", "scan").Err())
                time.Sleep(100 * time.Millisecond)
                require.EqualValues(t, 6, rdb.Do(ctx, "dbsize").Val())
        })
 
+       t.Run("KEYS with non-trivial patterns", func(t *testing.T) {
+               require.NoError(t, rdb.FlushDB(ctx).Err())
+               for _, key := range []string{"aa", "aab", "aabb", "ab", "abb"} {
+                       require.NoError(t, rdb.Set(ctx, key, "hello", 0).Err())
+               }
+
+               keys := rdb.Keys(ctx, "a*").Val()
+               slices.Sort(keys)
+               require.Equal(t, []string{"aa", "aab", "aabb", "ab", "abb"}, 
keys)
+
+               keys = rdb.Keys(ctx, "aa").Val()
+               slices.Sort(keys)
+               require.Equal(t, []string{"aa"}, keys)
+
+               keys = rdb.Keys(ctx, "aa*").Val()
+               slices.Sort(keys)
+               require.Equal(t, []string{"aa", "aab", "aabb"}, keys)
+
+               keys = rdb.Keys(ctx, "a?").Val()
+               slices.Sort(keys)
+               require.Equal(t, []string{"aa", "ab"}, keys)
+
+               keys = rdb.Keys(ctx, "a*?").Val()
+               slices.Sort(keys)
+               require.Equal(t, []string{"aa", "aab", "aabb", "ab", "abb"}, 
keys)
+
+               keys = rdb.Keys(ctx, "ab*").Val()
+               slices.Sort(keys)
+               require.Equal(t, []string{"ab", "abb"}, keys)
+
+               keys = rdb.Keys(ctx, "*ab").Val()
+               slices.Sort(keys)
+               require.Equal(t, []string{"aab", "ab"}, keys)
+
+               keys = rdb.Keys(ctx, "*ab*").Val()
+               slices.Sort(keys)
+               require.Equal(t, []string{"aab", "aabb", "ab", "abb"}, keys)
+       })
+
        t.Run("DEL all keys", func(t *testing.T) {
                vals := rdb.Keys(ctx, "*").Val()
                require.EqualValues(t, len(vals), rdb.Del(ctx, vals...).Val())
diff --git a/tests/gocase/unit/scan/scan_test.go 
b/tests/gocase/unit/scan/scan_test.go
index cade5dcc..5d2f9fb9 100644
--- a/tests/gocase/unit/scan/scan_test.go
+++ b/tests/gocase/unit/scan/scan_test.go
@@ -77,7 +77,6 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx 
context.Context) {
                require.NoError(t, rdb.FlushDB(ctx).Err())
                util.Populate(t, rdb, "", 1000, 10)
                keys := scanAll(t, rdb)
-               keys = slices.Compact(keys)
                require.Len(t, keys, 1000)
        })
 
@@ -85,7 +84,6 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx 
context.Context) {
                require.NoError(t, rdb.FlushDB(ctx).Err())
                util.Populate(t, rdb, "", 1000, 10)
                keys := scanAll(t, rdb, "count", 5)
-               keys = slices.Compact(keys)
                require.Len(t, keys, 1000)
        })
 
@@ -93,15 +91,46 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx 
context.Context) {
                require.NoError(t, rdb.FlushDB(ctx).Err())
                util.Populate(t, rdb, "key:", 1000, 10)
                keys := scanAll(t, rdb, "match", "key:*")
-               keys = slices.Compact(keys)
                require.Len(t, keys, 1000)
        })
 
-       t.Run("SCAN MATCH invalid pattern", func(t *testing.T) {
+       t.Run("SCAN MATCH non-trivial pattern", func(t *testing.T) {
                require.NoError(t, rdb.FlushDB(ctx).Err())
-               util.Populate(t, rdb, "*ab", 1000, 10)
-               // SCAN MATCH with invalid pattern should return an error
-               require.Error(t, rdb.Do(context.Background(), "SCAN", "match", 
"*ab*").Err())
+
+               for _, key := range []string{"aa", "aab", "aabb", "ab", "abb", 
"ba"} {
+                       require.NoError(t, rdb.Set(ctx, key, "hello", 0).Err())
+               }
+
+               keys := scanAll(t, rdb, "match", "a*")
+               require.Equal(t, []string{"aa", "aab", "aabb", "ab", "abb"}, 
keys)
+
+               keys = scanAll(t, rdb, "match", "aa")
+               require.Equal(t, []string{"aa"}, keys)
+
+               keys = scanAll(t, rdb, "match", "aa*")
+               require.Equal(t, []string{"aa", "aab", "aabb"}, keys)
+
+               keys = scanAll(t, rdb, "match", "a?")
+               require.Equal(t, []string{"aa", "ab"}, keys)
+
+               keys = scanAll(t, rdb, "match", "a*?")
+               require.Equal(t, []string{"aa", "aab", "aabb", "ab", "abb"}, 
keys)
+
+               keys = scanAll(t, rdb, "match", "ab*")
+               require.Equal(t, []string{"ab", "abb"}, keys)
+
+               keys = scanAll(t, rdb, "match", "*ab")
+               require.Equal(t, []string{"aab", "ab"}, keys)
+
+               keys = scanAll(t, rdb, "match", "*ab*")
+               require.Equal(t, []string{"aab", "aabb", "ab", "abb"}, keys)
+
+               // Special case: using [b]* instead of b* forces a full scan of the keyspace,
+               // matching every result against the pattern. We request one key per call
+               // (COUNT 1), and the first 5 keys don't match the pattern. This tests that SCAN
+               // returns a valid cursor even when the first [limit] keys don't satisfy the pattern.
+               keys = scanAll(t, rdb, "match", "[b]*", "count", "1")
+               require.Equal(t, []string{"ba"}, keys)
        })
 
        t.Run("SCAN guarantees check under write load", func(t *testing.T) {
@@ -226,6 +255,7 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx 
context.Context) {
                        require.NoError(t, rdb.SAdd(ctx, "set", 
elements...).Err())
                        keys, _, err := rdb.SScan(ctx, "set", 0, "", 
10000).Result()
                        require.NoError(t, err)
+                       slices.Sort(keys)
                        keys = slices.Compact(keys)
                        require.Len(t, keys, 100)
                })
@@ -307,6 +337,11 @@ func ScanTest(t *testing.T, rdb *redis.Client, ctx 
context.Context) {
                require.NoError(t, rdb.Do(ctx, "SCAN", "0", "match", "a*", 
"count", "1").Err())
                util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "count", "1", 
"match", "a*", "hello").Err(), ".*syntax error.*")
                util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "count", "1", 
"match", "a*", "hello", "hi").Err(), ".*syntax error.*")
+
+               util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "match", 
"[").Err(), ".*Invalid glob pattern.*")
+               util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "match", 
"\\").Err(), ".*Invalid glob pattern.*")
+               util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "match", 
"[a").Err(), ".*Invalid glob pattern.*")
+               util.ErrorRegexp(t, rdb.Do(ctx, "SCAN", "0", "match", 
"[a-]").Err(), ".*Invalid glob pattern.*")
        })
 
        t.Run("SCAN with type args ", func(t *testing.T) {
@@ -406,6 +441,8 @@ func scanAll(t testing.TB, rdb *redis.Client, args 
...interface{}) (keys []strin
                keys = append(keys, keyList...)
 
                if c == "0" {
+                       slices.Sort(keys)
+                       keys = slices.Compact(keys)
                        return
                }
        }


Reply via email to