This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 64e7834f Add support for the KEEPTTL and GET options to the SET command (#1935)
64e7834f is described below

commit 64e7834fea79a81d5de125a8fe3ae8e00dc82ee6
Author: 纪华裕 <[email protected]>
AuthorDate: Mon Dec 18 18:13:48 2023 +0800

    Add support for the KEEPTTL and GET options to the SET command (#1935)
---
 src/commands/cmd_string.cc                     |  50 ++++++-----
 src/types/redis_string.cc                      | 116 ++++++++++++++++---------
 src/types/redis_string.h                       |  15 +++-
 tests/cppunit/types/string_test.cc             |   6 +-
 tests/gocase/unit/type/strings/strings_test.go |  88 +++++++++++++++++++
 5 files changed, 212 insertions(+), 63 deletions(-)

diff --git a/src/commands/cmd_string.cc b/src/commands/cmd_string.cc
index 99783172..a0e1a690 100644
--- a/src/commands/cmd_string.cc
+++ b/src/commands/cmd_string.cc
@@ -20,10 +20,12 @@
 
 #include <cstdint>
 #include <optional>
+#include <string>
 
 #include "commander.h"
 #include "commands/command_parser.h"
 #include "error_constants.h"
+#include "server/redis_reply.h"
 #include "server/server.h"
 #include "storage/redis_db.h"
 #include "time_util.h"
@@ -131,16 +133,16 @@ class CommandGetSet : public Commander {
  public:
   Status Execute(Server *srv, Connection *conn, std::string *output) override {
     redis::String string_db(srv->storage, conn->GetNamespace());
-    std::string old_value;
-    auto s = string_db.GetSet(args_[1], args_[2], &old_value);
-    if (!s.ok() && !s.IsNotFound()) {
+    std::optional<std::string> old_value;
+    auto s = string_db.GetSet(args_[1], args_[2], old_value);
+    if (!s.ok()) {
       return {Status::RedisExecErr, s.ToString()};
     }
 
-    if (s.IsNotFound()) {
-      *output = redis::NilString();
+    if (old_value.has_value()) {
+      *output = redis::BulkString(old_value.value());
     } else {
-      *output = redis::BulkString(old_value);
+      *output = redis::NilString();
     }
     return Status::OK();
   }
@@ -281,10 +283,14 @@ class CommandSet : public Commander {
     while (parser.Good()) {
       if (auto v = GET_OR_RET(ParseTTL(parser, ttl_flag))) {
         ttl_ = *v;
+      } else if (parser.EatEqICaseFlag("KEEPTTL", ttl_flag)) {
+        keep_ttl_ = true;
       } else if (parser.EatEqICaseFlag("NX", set_flag)) {
-        set_flag_ = NX;
+        set_flag_ = StringSetType::NX;
       } else if (parser.EatEqICaseFlag("XX", set_flag)) {
-        set_flag_ = XX;
+        set_flag_ = StringSetType::XX;
+      } else if (parser.EatEqICase("GET")) {
+        get_ = true;
       } else {
         return parser.InvalidSyntax();
       }
@@ -294,7 +300,7 @@ class CommandSet : public Commander {
   }
 
   Status Execute(Server *srv, Connection *conn, std::string *output) override {
-    bool ret = false;
+    std::optional<std::string> ret;
     redis::String string_db(srv->storage, conn->GetNamespace());
 
     if (ttl_ < 0) {
@@ -307,29 +313,33 @@ class CommandSet : public Commander {
     }
 
     rocksdb::Status s;
-    if (set_flag_ == NX) {
-      s = string_db.SetNX(args_[1], args_[2], ttl_, &ret);
-    } else if (set_flag_ == XX) {
-      s = string_db.SetXX(args_[1], args_[2], ttl_, &ret);
-    } else {
-      s = string_db.SetEX(args_[1], args_[2], ttl_);
-    }
+    s = string_db.Set(args_[1], args_[2], {ttl_, set_flag_, get_, keep_ttl_}, 
ret);
 
     if (!s.ok()) {
       return {Status::RedisExecErr, s.ToString()};
     }
 
-    if (set_flag_ != NONE && !ret) {
-      *output = redis::NilString();
+    if (get_) {
+      if (ret.has_value()) {
+        *output = redis::BulkString(ret.value());
+      } else {
+        *output = redis::NilString();
+      }
     } else {
-      *output = redis::SimpleString("OK");
+      if (ret.has_value()) {
+        *output = redis::SimpleString("OK");
+      } else {
+        *output = redis::NilString();
+      }
     }
     return Status::OK();
   }
 
  private:
   uint64_t ttl_ = 0;
-  enum { NONE, NX, XX } set_flag_ = NONE;
+  bool get_ = false;
+  bool keep_ttl_ = false;
+  StringSetType set_flag_ = StringSetType::NONE;
 };
 
 class CommandSetEX : public Commander {
diff --git a/src/types/redis_string.cc b/src/types/redis_string.cc
index 6153fe4e..420d3994 100644
--- a/src/types/redis_string.cc
+++ b/src/types/redis_string.cc
@@ -23,7 +23,6 @@
 #include <cmath>
 #include <cstddef>
 #include <cstdint>
-#include <limits>
 #include <optional>
 #include <string>
 
@@ -187,20 +186,10 @@ rocksdb::Status String::GetEx(const std::string 
&user_key, std::string *value, u
   return rocksdb::Status::OK();
 }
 
-rocksdb::Status String::GetSet(const std::string &user_key, const std::string 
&new_value, std::string *old_value) {
-  std::string ns_key = AppendNamespacePrefix(user_key);
-
-  LockGuard guard(storage_->GetLockManager(), ns_key);
-  rocksdb::Status s = getValue(ns_key, old_value);
-  if (!s.ok() && !s.IsNotFound()) return s;
-
-  std::string raw_value;
-  Metadata metadata(kRedisString, false);
-  metadata.Encode(&raw_value);
-  raw_value.append(new_value);
-  auto write_status = updateRawValue(ns_key, raw_value);
-  // prev status was used to tell whether old value was empty or not
-  return !write_status.ok() ? write_status : s;
+rocksdb::Status String::GetSet(const std::string &user_key, const std::string 
&new_value,
+                               std::optional<std::string> &old_value) {
+  auto s = Set(user_key, new_value, {/*ttl=*/0, StringSetType::NONE, 
/*get=*/true, /*keep_ttl=*/false}, old_value);
+  return s;
 }
 rocksdb::Status String::GetDel(const std::string &user_key, std::string 
*value) {
   std::string ns_key = AppendNamespacePrefix(user_key);
@@ -217,38 +206,87 @@ rocksdb::Status String::Set(const std::string &user_key, 
const std::string &valu
   return MSet(pairs, /*ttl=*/0, /*lock=*/true);
 }
 
-rocksdb::Status String::SetEX(const std::string &user_key, const std::string 
&value, uint64_t ttl) {
-  std::vector<StringPair> pairs{StringPair{user_key, value}};
-  return MSet(pairs, /*ttl=*/ttl, /*lock=*/true);
-}
+rocksdb::Status String::Set(const std::string &user_key, const std::string 
&value, StringSetArgs args,
+                            std::optional<std::string> &ret) {
+  std::string ns_key = AppendNamespacePrefix(user_key);
 
-rocksdb::Status String::SetNX(const std::string &user_key, const std::string 
&value, uint64_t ttl, bool *flag) {
-  std::vector<StringPair> pairs{StringPair{user_key, value}};
-  return MSetNX(pairs, ttl, flag);
-}
+  LockGuard guard(storage_->GetLockManager(), ns_key);
 
-rocksdb::Status String::SetXX(const std::string &user_key, const std::string 
&value, uint64_t ttl, bool *flag) {
-  *flag = false;
-  int exists = 0;
+  // Get old value for NX/XX/GET/KEEPTTL option
+  std::string old_raw_value;
+  auto s = getRawValue(ns_key, &old_raw_value);
+  if (!s.ok() && !s.IsNotFound() && !s.IsInvalidArgument()) return s;
+  auto old_key_found = !s.IsNotFound();
+  // The reply following Redis doc: https://redis.io/commands/set/
+  // Handle GET option
+  if (args.get) {
+    if (s.IsInvalidArgument()) {
+      return s;
+    }
+    if (old_key_found) {
+      // if GET option given: return The previous value of the key.
+      auto offset = Metadata::GetOffsetAfterExpire(old_raw_value[0]);
+      ret = std::make_optional(old_raw_value.substr(offset));
+    } else {
+      // if GET option given, the key didn't exist before: return nil
+      ret = std::nullopt;
+    }
+  }
+
+  // Handle NX/XX option
+  if (old_key_found && args.type == StringSetType::NX) {
+    // if GET option not given, operation aborted: return nil
+    if (!args.get) ret = std::nullopt;
+    return rocksdb::Status::OK();
+  } else if (!old_key_found && args.type == StringSetType::XX) {
+    // if GET option not given, operation aborted: return nil
+    if (!args.get) ret = std::nullopt;
+    return rocksdb::Status::OK();
+  } else {
+    // if GET option not given, make ret not nil
+    if (!args.get) ret = "";
+  }
+
+  // Handle expire time
   uint64_t expire = 0;
-  if (ttl > 0) {
+  if (args.ttl > 0) {
     uint64_t now = util::GetTimeStampMS();
-    expire = now + ttl;
+    expire = now + args.ttl;
+  } else if (args.keep_ttl && old_key_found) {
+    Metadata metadata(kRedisString, false);
+    auto s = metadata.Decode(old_raw_value);
+    if (!s.ok()) {
+      return s;
+    }
+    expire = metadata.expire;
   }
 
-  std::string ns_key = AppendNamespacePrefix(user_key);
-  LockGuard guard(storage_->GetLockManager(), ns_key);
-  auto s = Exists({user_key}, &exists);
-  if (!s.ok()) return s;
-  if (exists != 1) return rocksdb::Status::OK();
-
-  *flag = true;
-  std::string raw_value;
+  // Create new value
+  std::string new_raw_value;
   Metadata metadata(kRedisString, false);
   metadata.expire = expire;
-  metadata.Encode(&raw_value);
-  raw_value.append(value);
-  return updateRawValue(ns_key, raw_value);
+  metadata.Encode(&new_raw_value);
+  new_raw_value.append(value);
+  return updateRawValue(ns_key, new_raw_value);
+}
+
+rocksdb::Status String::SetEX(const std::string &user_key, const std::string 
&value, uint64_t ttl) {
+  std::optional<std::string> ret;
+  return Set(user_key, value, {ttl, StringSetType::NONE, /*get=*/false, 
/*keep_ttl=*/false}, ret);
+}
+
+rocksdb::Status String::SetNX(const std::string &user_key, const std::string 
&value, uint64_t ttl, bool *flag) {
+  std::optional<std::string> ret;
+  auto s = Set(user_key, value, {ttl, StringSetType::NX, /*get=*/false, 
/*keep_ttl=*/false}, ret);
+  *flag = ret.has_value();
+  return s;
+}
+
+rocksdb::Status String::SetXX(const std::string &user_key, const std::string 
&value, uint64_t ttl, bool *flag) {
+  std::optional<std::string> ret;
+  auto s = Set(user_key, value, {ttl, StringSetType::XX, /*get=*/false, 
/*keep_ttl=*/false}, ret);
+  *flag = ret.has_value();
+  return s;
 }
 
 rocksdb::Status String::SetRange(const std::string &user_key, size_t offset, 
const std::string &value,
diff --git a/src/types/redis_string.h b/src/types/redis_string.h
index 41be5bdd..bfb4ef99 100644
--- a/src/types/redis_string.h
+++ b/src/types/redis_string.h
@@ -21,6 +21,7 @@
 #pragma once
 
 #include <cstdint>
+#include <optional>
 #include <string>
 #include <vector>
 
@@ -32,6 +33,15 @@ struct StringPair {
   Slice value;
 };
 
+enum class StringSetType { NONE, NX, XX };
+
+struct StringSetArgs {
+  uint64_t ttl;
+  StringSetType type;
+  bool get;
+  bool keep_ttl;
+};
+
 namespace redis {
 
 class String : public Database {
@@ -40,9 +50,12 @@ class String : public Database {
   rocksdb::Status Append(const std::string &user_key, const std::string 
&value, uint64_t *new_size);
   rocksdb::Status Get(const std::string &user_key, std::string *value);
   rocksdb::Status GetEx(const std::string &user_key, std::string *value, 
uint64_t ttl, bool persist);
-  rocksdb::Status GetSet(const std::string &user_key, const std::string 
&new_value, std::string *old_value);
+  rocksdb::Status GetSet(const std::string &user_key, const std::string 
&new_value,
+                         std::optional<std::string> &old_value);
   rocksdb::Status GetDel(const std::string &user_key, std::string *value);
   rocksdb::Status Set(const std::string &user_key, const std::string &value);
+  rocksdb::Status Set(const std::string &user_key, const std::string &value, 
StringSetArgs args,
+                      std::optional<std::string> &ret);
   rocksdb::Status SetEX(const std::string &user_key, const std::string &value, 
uint64_t ttl);
   rocksdb::Status SetNX(const std::string &user_key, const std::string &value, 
uint64_t ttl, bool *flag);
   rocksdb::Status SetXX(const std::string &user_key, const std::string &value, 
uint64_t ttl, bool *flag);
diff --git a/tests/cppunit/types/string_test.cc 
b/tests/cppunit/types/string_test.cc
index fe916adc..631b4ec8 100644
--- a/tests/cppunit/types/string_test.cc
+++ b/tests/cppunit/types/string_test.cc
@@ -142,15 +142,15 @@ TEST_F(RedisStringTest, GetSet) {
   rocksdb::Env::Default()->GetCurrentTime(&now);
   std::vector<std::string> values = {"a", "b", "c", "d"};
   for (size_t i = 0; i < values.size(); i++) {
-    std::string old_value;
+    std::optional<std::string> old_value;
     auto s = string_->Expire(key_, now * 1000 + 100000);
-    string_->GetSet(key_, values[i], &old_value);
+    string_->GetSet(key_, values[i], old_value);
     if (i != 0) {
       EXPECT_EQ(values[i - 1], old_value);
       auto s = string_->TTL(key_, &ttl);
       EXPECT_TRUE(ttl == -1);
     } else {
-      EXPECT_TRUE(old_value.empty());
+      EXPECT_TRUE(!old_value.has_value());
     }
   }
   auto s = string_->Del(key_);
diff --git a/tests/gocase/unit/type/strings/strings_test.go 
b/tests/gocase/unit/type/strings/strings_test.go
index f255228d..fc799fc5 100644
--- a/tests/gocase/unit/type/strings/strings_test.go
+++ b/tests/gocase/unit/type/strings/strings_test.go
@@ -678,6 +678,94 @@ func TestString(t *testing.T) {
                util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
        })
 
+       t.Run("Extended SET KEEPTTL and EX/PX/EXAT/PXAT option", func(t 
*testing.T) {
+               require.NoError(t, rdb.Del(ctx, "foo").Err())
+               require.Error(t, rdb.Do(ctx, "SET", "foo", "xx", "keepttl", 
"ex", "100").Err())
+               require.Error(t, rdb.Do(ctx, "SET", "foo", "xx", "keepttl", 
"px", "100").Err())
+               require.Error(t, rdb.Do(ctx, "SET", "foo", "xx", "keepttl", 
"exat", "100").Err())
+               require.Error(t, rdb.Do(ctx, "SET", "foo", "xx", "keepttl", 
"pxat", "100").Err())
+       })
+
+       t.Run("Extended SET KEEPTTL WITH option", func(t *testing.T) {
+               require.NoError(t, rdb.Del(ctx, "foo").Err())
+               require.Equal(t, "OK", rdb.SetArgs(ctx, "foo", "xx", 
redis.SetArgs{KeepTTL: true}).Val())
+               ttl := rdb.TTL(ctx, "foo").Val()
+               require.Equal(t, time.Duration(-1), ttl)
+               require.Equal(t, "OK", rdb.Set(ctx, "foo", "bar", 
10*time.Second).Val())
+               require.Equal(t, "OK", rdb.SetArgs(ctx, "foo", "xx", 
redis.SetArgs{KeepTTL: true}).Val())
+               ttl = rdb.TTL(ctx, "foo").Val()
+               util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
+       })
+
+       t.Run("Extended SET GET option", func(t *testing.T) {
+               require.NoError(t, rdb.Del(ctx, "foo").Err())
+               require.Equal(t, "", rdb.SetArgs(ctx, "foo", "bar", 
redis.SetArgs{Get: true}).Val())
+               require.Equal(t, "bar", rdb.SetArgs(ctx, "foo", "xx", 
redis.SetArgs{Get: true}).Val())
+               require.Equal(t, "xx", rdb.Get(ctx, "foo").Val())
+       })
+
+       t.Run("Extended SET GET and NX option", func(t *testing.T) {
+               require.NoError(t, rdb.Del(ctx, "foo").Err())
+               require.Equal(t, "", rdb.SetArgs(ctx, "foo", "xx", 
redis.SetArgs{Get: true, Mode: "NX"}).Val())
+               require.Equal(t, "xx", rdb.Get(ctx, "foo").Val())
+               require.Equal(t, "OK", rdb.Set(ctx, "foo", "bar", 0).Val())
+               require.Equal(t, "bar", rdb.SetArgs(ctx, "foo", "xx", 
redis.SetArgs{Get: true, Mode: "NX"}).Val())
+               require.Equal(t, "bar", rdb.Get(ctx, "foo").Val())
+       })
+
+       t.Run("Extended SET GET and XX option", func(t *testing.T) {
+               require.NoError(t, rdb.Del(ctx, "foo").Err())
+               require.Equal(t, "", rdb.SetArgs(ctx, "foo", "xx", 
redis.SetArgs{Get: true, Mode: "XX"}).Val())
+               require.Equal(t, "", rdb.Get(ctx, "foo").Val())
+               require.Equal(t, "OK", rdb.Set(ctx, "foo", "bar", 0).Val())
+               require.Equal(t, "bar", rdb.SetArgs(ctx, "foo", "xx", 
redis.SetArgs{Get: true, Mode: "XX"}).Val())
+               require.Equal(t, "xx", rdb.Get(ctx, "foo").Val())
+       })
+
+       t.Run("Extended SET GET and KEEPTTL option", func(t *testing.T) {
+               require.NoError(t, rdb.Del(ctx, "foo").Err())
+               require.Equal(t, "", rdb.SetArgs(ctx, "foo", "xx", 
redis.SetArgs{Get: true, KeepTTL: true}).Val())
+               ttl := rdb.TTL(ctx, "foo").Val()
+               require.Equal(t, time.Duration(-1), ttl)
+               require.Equal(t, "OK", rdb.Set(ctx, "foo", "bar", 
10*time.Second).Val())
+               require.Equal(t, "bar", rdb.SetArgs(ctx, "foo", "xx", 
redis.SetArgs{Get: true, KeepTTL: true}).Val())
+               ttl = rdb.TTL(ctx, "foo").Val()
+               util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
+       })
+
+       t.Run("Extended SET GET and EX option", func(t *testing.T) {
+               require.NoError(t, rdb.Del(ctx, "foo").Err())
+               require.Equal(t, nil, rdb.Do(ctx, "SET", "foo", "bar", "ex", 
"10", "get").Val())
+               ttl := rdb.TTL(ctx, "foo").Val()
+               util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
+       })
+
+       t.Run("Extended SET GET and PX option", func(t *testing.T) {
+               require.NoError(t, rdb.Del(ctx, "foo").Err())
+               require.Equal(t, nil, rdb.Do(ctx, "SET", "foo", "bar", "px", 
"10000", "get").Val())
+               ttl := rdb.TTL(ctx, "foo").Val()
+               util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
+       })
+
+       t.Run("Extended SET GET and EXAT option", func(t *testing.T) {
+               require.NoError(t, rdb.Del(ctx, "foo").Err())
+
+               expireAt := 
strconv.FormatInt(time.Now().Add(10*time.Second).Unix(), 10)
+               require.Equal(t, nil, rdb.Do(ctx, "SET", "foo", "bar", "exat", 
expireAt, "get").Val())
+               ttl := rdb.TTL(ctx, "foo").Val()
+               util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
+       })
+
+       t.Run("Extended SET GET and PXAT option", func(t *testing.T) {
+               require.NoError(t, rdb.Del(ctx, "foo").Err())
+
+               expireAt := 
strconv.FormatInt(time.Now().Add(10*time.Second).UnixMilli(), 10)
+               require.Equal(t, nil, rdb.Do(ctx, "SET", "foo", "bar", "pxat", 
expireAt, "get").Val())
+
+               ttl := rdb.TTL(ctx, "foo").Val()
+               util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
+       })
+
        t.Run("GETRANGE with huge ranges, Github issue redis/redis#1844", 
func(t *testing.T) {
                require.NoError(t, rdb.Set(ctx, "foo", "bar", 0).Err())
                require.Equal(t, "bar", rdb.GetRange(ctx, "foo", 0, 
2094967291).Val())

Reply via email to