This is an automated email from the ASF dual-hosted git repository.
hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 64e7834f Add support for the KEEPTTL and GET options to the SET
command (#1935)
64e7834f is described below
commit 64e7834fea79a81d5de125a8fe3ae8e00dc82ee6
Author: 纪华裕 <[email protected]>
AuthorDate: Mon Dec 18 18:13:48 2023 +0800
Add support for the KEEPTTL and GET options to the SET command (#1935)
---
src/commands/cmd_string.cc | 50 ++++++-----
src/types/redis_string.cc | 116 ++++++++++++++++---------
src/types/redis_string.h | 15 +++-
tests/cppunit/types/string_test.cc | 6 +-
tests/gocase/unit/type/strings/strings_test.go | 88 +++++++++++++++++++
5 files changed, 212 insertions(+), 63 deletions(-)
diff --git a/src/commands/cmd_string.cc b/src/commands/cmd_string.cc
index 99783172..a0e1a690 100644
--- a/src/commands/cmd_string.cc
+++ b/src/commands/cmd_string.cc
@@ -20,10 +20,12 @@
#include <cstdint>
#include <optional>
+#include <string>
#include "commander.h"
#include "commands/command_parser.h"
#include "error_constants.h"
+#include "server/redis_reply.h"
#include "server/server.h"
#include "storage/redis_db.h"
#include "time_util.h"
@@ -131,16 +133,16 @@ class CommandGetSet : public Commander {
public:
Status Execute(Server *srv, Connection *conn, std::string *output) override {
redis::String string_db(srv->storage, conn->GetNamespace());
- std::string old_value;
- auto s = string_db.GetSet(args_[1], args_[2], &old_value);
- if (!s.ok() && !s.IsNotFound()) {
+ std::optional<std::string> old_value;
+ auto s = string_db.GetSet(args_[1], args_[2], old_value);
+ if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}
- if (s.IsNotFound()) {
- *output = redis::NilString();
+ if (old_value.has_value()) {
+ *output = redis::BulkString(old_value.value());
} else {
- *output = redis::BulkString(old_value);
+ *output = redis::NilString();
}
return Status::OK();
}
@@ -281,10 +283,14 @@ class CommandSet : public Commander {
while (parser.Good()) {
if (auto v = GET_OR_RET(ParseTTL(parser, ttl_flag))) {
ttl_ = *v;
+ } else if (parser.EatEqICaseFlag("KEEPTTL", ttl_flag)) {
+ keep_ttl_ = true;
} else if (parser.EatEqICaseFlag("NX", set_flag)) {
- set_flag_ = NX;
+ set_flag_ = StringSetType::NX;
} else if (parser.EatEqICaseFlag("XX", set_flag)) {
- set_flag_ = XX;
+ set_flag_ = StringSetType::XX;
+ } else if (parser.EatEqICase("GET")) {
+ get_ = true;
} else {
return parser.InvalidSyntax();
}
@@ -294,7 +300,7 @@ class CommandSet : public Commander {
}
Status Execute(Server *srv, Connection *conn, std::string *output) override {
- bool ret = false;
+ std::optional<std::string> ret;
redis::String string_db(srv->storage, conn->GetNamespace());
if (ttl_ < 0) {
@@ -307,29 +313,33 @@ class CommandSet : public Commander {
}
rocksdb::Status s;
- if (set_flag_ == NX) {
- s = string_db.SetNX(args_[1], args_[2], ttl_, &ret);
- } else if (set_flag_ == XX) {
- s = string_db.SetXX(args_[1], args_[2], ttl_, &ret);
- } else {
- s = string_db.SetEX(args_[1], args_[2], ttl_);
- }
+ s = string_db.Set(args_[1], args_[2], {ttl_, set_flag_, get_, keep_ttl_},
ret);
if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}
- if (set_flag_ != NONE && !ret) {
- *output = redis::NilString();
+ if (get_) {
+ if (ret.has_value()) {
+ *output = redis::BulkString(ret.value());
+ } else {
+ *output = redis::NilString();
+ }
} else {
- *output = redis::SimpleString("OK");
+ if (ret.has_value()) {
+ *output = redis::SimpleString("OK");
+ } else {
+ *output = redis::NilString();
+ }
}
return Status::OK();
}
private:
uint64_t ttl_ = 0;
- enum { NONE, NX, XX } set_flag_ = NONE;
+ bool get_ = false;
+ bool keep_ttl_ = false;
+ StringSetType set_flag_ = StringSetType::NONE;
};
class CommandSetEX : public Commander {
diff --git a/src/types/redis_string.cc b/src/types/redis_string.cc
index 6153fe4e..420d3994 100644
--- a/src/types/redis_string.cc
+++ b/src/types/redis_string.cc
@@ -23,7 +23,6 @@
#include <cmath>
#include <cstddef>
#include <cstdint>
-#include <limits>
#include <optional>
#include <string>
@@ -187,20 +186,10 @@ rocksdb::Status String::GetEx(const std::string
&user_key, std::string *value, u
return rocksdb::Status::OK();
}
-rocksdb::Status String::GetSet(const std::string &user_key, const std::string
&new_value, std::string *old_value) {
- std::string ns_key = AppendNamespacePrefix(user_key);
-
- LockGuard guard(storage_->GetLockManager(), ns_key);
- rocksdb::Status s = getValue(ns_key, old_value);
- if (!s.ok() && !s.IsNotFound()) return s;
-
- std::string raw_value;
- Metadata metadata(kRedisString, false);
- metadata.Encode(&raw_value);
- raw_value.append(new_value);
- auto write_status = updateRawValue(ns_key, raw_value);
- // prev status was used to tell whether old value was empty or not
- return !write_status.ok() ? write_status : s;
+rocksdb::Status String::GetSet(const std::string &user_key, const std::string
&new_value,
+ std::optional<std::string> &old_value) {
+ auto s = Set(user_key, new_value, {/*ttl=*/0, StringSetType::NONE,
/*get=*/true, /*keep_ttl=*/false}, old_value);
+ return s;
}
rocksdb::Status String::GetDel(const std::string &user_key, std::string
*value) {
std::string ns_key = AppendNamespacePrefix(user_key);
@@ -217,38 +206,87 @@ rocksdb::Status String::Set(const std::string &user_key,
const std::string &valu
return MSet(pairs, /*ttl=*/0, /*lock=*/true);
}
-rocksdb::Status String::SetEX(const std::string &user_key, const std::string
&value, uint64_t ttl) {
- std::vector<StringPair> pairs{StringPair{user_key, value}};
- return MSet(pairs, /*ttl=*/ttl, /*lock=*/true);
-}
+rocksdb::Status String::Set(const std::string &user_key, const std::string
&value, StringSetArgs args,
+ std::optional<std::string> &ret) {
+ std::string ns_key = AppendNamespacePrefix(user_key);
-rocksdb::Status String::SetNX(const std::string &user_key, const std::string
&value, uint64_t ttl, bool *flag) {
- std::vector<StringPair> pairs{StringPair{user_key, value}};
- return MSetNX(pairs, ttl, flag);
-}
+ LockGuard guard(storage_->GetLockManager(), ns_key);
-rocksdb::Status String::SetXX(const std::string &user_key, const std::string
&value, uint64_t ttl, bool *flag) {
- *flag = false;
- int exists = 0;
+ // Get old value for NX/XX/GET/KEEPTTL option
+ std::string old_raw_value;
+ auto s = getRawValue(ns_key, &old_raw_value);
+ if (!s.ok() && !s.IsNotFound() && !s.IsInvalidArgument()) return s;
+ auto old_key_found = !s.IsNotFound();
+ // The reply following Redis doc: https://redis.io/commands/set/
+ // Handle GET option
+ if (args.get) {
+ if (s.IsInvalidArgument()) {
+ return s;
+ }
+ if (old_key_found) {
+ // if GET option given: return The previous value of the key.
+ auto offset = Metadata::GetOffsetAfterExpire(old_raw_value[0]);
+ ret = std::make_optional(old_raw_value.substr(offset));
+ } else {
+ // if GET option given, the key didn't exist before: return nil
+ ret = std::nullopt;
+ }
+ }
+
+ // Handle NX/XX option
+ if (old_key_found && args.type == StringSetType::NX) {
+ // if GET option not given, operation aborted: return nil
+ if (!args.get) ret = std::nullopt;
+ return rocksdb::Status::OK();
+ } else if (!old_key_found && args.type == StringSetType::XX) {
+ // if GET option not given, operation aborted: return nil
+ if (!args.get) ret = std::nullopt;
+ return rocksdb::Status::OK();
+ } else {
+ // if GET option not given, make ret not nil
+ if (!args.get) ret = "";
+ }
+
+ // Handle expire time
uint64_t expire = 0;
- if (ttl > 0) {
+ if (args.ttl > 0) {
uint64_t now = util::GetTimeStampMS();
- expire = now + ttl;
+ expire = now + args.ttl;
+ } else if (args.keep_ttl && old_key_found) {
+ Metadata metadata(kRedisString, false);
+ auto s = metadata.Decode(old_raw_value);
+ if (!s.ok()) {
+ return s;
+ }
+ expire = metadata.expire;
}
- std::string ns_key = AppendNamespacePrefix(user_key);
- LockGuard guard(storage_->GetLockManager(), ns_key);
- auto s = Exists({user_key}, &exists);
- if (!s.ok()) return s;
- if (exists != 1) return rocksdb::Status::OK();
-
- *flag = true;
- std::string raw_value;
+ // Create new value
+ std::string new_raw_value;
Metadata metadata(kRedisString, false);
metadata.expire = expire;
- metadata.Encode(&raw_value);
- raw_value.append(value);
- return updateRawValue(ns_key, raw_value);
+ metadata.Encode(&new_raw_value);
+ new_raw_value.append(value);
+ return updateRawValue(ns_key, new_raw_value);
+}
+
+rocksdb::Status String::SetEX(const std::string &user_key, const std::string
&value, uint64_t ttl) {
+ std::optional<std::string> ret;
+ return Set(user_key, value, {ttl, StringSetType::NONE, /*get=*/false,
/*keep_ttl=*/false}, ret);
+}
+
+rocksdb::Status String::SetNX(const std::string &user_key, const std::string
&value, uint64_t ttl, bool *flag) {
+ std::optional<std::string> ret;
+ auto s = Set(user_key, value, {ttl, StringSetType::NX, /*get=*/false,
/*keep_ttl=*/false}, ret);
+ *flag = ret.has_value();
+ return s;
+}
+
+rocksdb::Status String::SetXX(const std::string &user_key, const std::string
&value, uint64_t ttl, bool *flag) {
+ std::optional<std::string> ret;
+ auto s = Set(user_key, value, {ttl, StringSetType::XX, /*get=*/false,
/*keep_ttl=*/false}, ret);
+ *flag = ret.has_value();
+ return s;
}
rocksdb::Status String::SetRange(const std::string &user_key, size_t offset,
const std::string &value,
diff --git a/src/types/redis_string.h b/src/types/redis_string.h
index 41be5bdd..bfb4ef99 100644
--- a/src/types/redis_string.h
+++ b/src/types/redis_string.h
@@ -21,6 +21,7 @@
#pragma once
#include <cstdint>
+#include <optional>
#include <string>
#include <vector>
@@ -32,6 +33,15 @@ struct StringPair {
Slice value;
};
+enum class StringSetType { NONE, NX, XX };
+
+struct StringSetArgs {
+ uint64_t ttl;
+ StringSetType type;
+ bool get;
+ bool keep_ttl;
+};
+
namespace redis {
class String : public Database {
@@ -40,9 +50,12 @@ class String : public Database {
rocksdb::Status Append(const std::string &user_key, const std::string
&value, uint64_t *new_size);
rocksdb::Status Get(const std::string &user_key, std::string *value);
rocksdb::Status GetEx(const std::string &user_key, std::string *value,
uint64_t ttl, bool persist);
- rocksdb::Status GetSet(const std::string &user_key, const std::string
&new_value, std::string *old_value);
+ rocksdb::Status GetSet(const std::string &user_key, const std::string
&new_value,
+ std::optional<std::string> &old_value);
rocksdb::Status GetDel(const std::string &user_key, std::string *value);
rocksdb::Status Set(const std::string &user_key, const std::string &value);
+ rocksdb::Status Set(const std::string &user_key, const std::string &value,
StringSetArgs args,
+ std::optional<std::string> &ret);
rocksdb::Status SetEX(const std::string &user_key, const std::string &value,
uint64_t ttl);
rocksdb::Status SetNX(const std::string &user_key, const std::string &value,
uint64_t ttl, bool *flag);
rocksdb::Status SetXX(const std::string &user_key, const std::string &value,
uint64_t ttl, bool *flag);
diff --git a/tests/cppunit/types/string_test.cc
b/tests/cppunit/types/string_test.cc
index fe916adc..631b4ec8 100644
--- a/tests/cppunit/types/string_test.cc
+++ b/tests/cppunit/types/string_test.cc
@@ -142,15 +142,15 @@ TEST_F(RedisStringTest, GetSet) {
rocksdb::Env::Default()->GetCurrentTime(&now);
std::vector<std::string> values = {"a", "b", "c", "d"};
for (size_t i = 0; i < values.size(); i++) {
- std::string old_value;
+ std::optional<std::string> old_value;
auto s = string_->Expire(key_, now * 1000 + 100000);
- string_->GetSet(key_, values[i], &old_value);
+ string_->GetSet(key_, values[i], old_value);
if (i != 0) {
EXPECT_EQ(values[i - 1], old_value);
auto s = string_->TTL(key_, &ttl);
EXPECT_TRUE(ttl == -1);
} else {
- EXPECT_TRUE(old_value.empty());
+ EXPECT_TRUE(!old_value.has_value());
}
}
auto s = string_->Del(key_);
diff --git a/tests/gocase/unit/type/strings/strings_test.go
b/tests/gocase/unit/type/strings/strings_test.go
index f255228d..fc799fc5 100644
--- a/tests/gocase/unit/type/strings/strings_test.go
+++ b/tests/gocase/unit/type/strings/strings_test.go
@@ -678,6 +678,94 @@ func TestString(t *testing.T) {
util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
})
+ t.Run("Extended SET KEEPTTL and EX/PX/EXAT/PXAT option", func(t
*testing.T) {
+ require.NoError(t, rdb.Del(ctx, "foo").Err())
+ require.Error(t, rdb.Do(ctx, "SET", "foo", "xx", "keepttl",
"ex", "100").Err())
+ require.Error(t, rdb.Do(ctx, "SET", "foo", "xx", "keepttl",
"px", "100").Err())
+ require.Error(t, rdb.Do(ctx, "SET", "foo", "xx", "keepttl",
"exat", "100").Err())
+ require.Error(t, rdb.Do(ctx, "SET", "foo", "xx", "keepttl",
"pxat", "100").Err())
+ })
+
+ t.Run("Extended SET KEEPTTL WITH option", func(t *testing.T) {
+ require.NoError(t, rdb.Del(ctx, "foo").Err())
+ require.Equal(t, "OK", rdb.SetArgs(ctx, "foo", "xx",
redis.SetArgs{KeepTTL: true}).Val())
+ ttl := rdb.TTL(ctx, "foo").Val()
+ require.Equal(t, time.Duration(-1), ttl)
+ require.Equal(t, "OK", rdb.Set(ctx, "foo", "bar",
10*time.Second).Val())
+ require.Equal(t, "OK", rdb.SetArgs(ctx, "foo", "xx",
redis.SetArgs{KeepTTL: true}).Val())
+ ttl = rdb.TTL(ctx, "foo").Val()
+ util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
+ })
+
+ t.Run("Extended SET GET option", func(t *testing.T) {
+ require.NoError(t, rdb.Del(ctx, "foo").Err())
+ require.Equal(t, "", rdb.SetArgs(ctx, "foo", "bar",
redis.SetArgs{Get: true}).Val())
+ require.Equal(t, "bar", rdb.SetArgs(ctx, "foo", "xx",
redis.SetArgs{Get: true}).Val())
+ require.Equal(t, "xx", rdb.Get(ctx, "foo").Val())
+ })
+
+ t.Run("Extended SET GET and NX option", func(t *testing.T) {
+ require.NoError(t, rdb.Del(ctx, "foo").Err())
+ require.Equal(t, "", rdb.SetArgs(ctx, "foo", "xx",
redis.SetArgs{Get: true, Mode: "NX"}).Val())
+ require.Equal(t, "xx", rdb.Get(ctx, "foo").Val())
+ require.Equal(t, "OK", rdb.Set(ctx, "foo", "bar", 0).Val())
+ require.Equal(t, "bar", rdb.SetArgs(ctx, "foo", "xx",
redis.SetArgs{Get: true, Mode: "NX"}).Val())
+ require.Equal(t, "bar", rdb.Get(ctx, "foo").Val())
+ })
+
+ t.Run("Extended SET GET and XX option", func(t *testing.T) {
+ require.NoError(t, rdb.Del(ctx, "foo").Err())
+ require.Equal(t, "", rdb.SetArgs(ctx, "foo", "xx",
redis.SetArgs{Get: true, Mode: "XX"}).Val())
+ require.Equal(t, "", rdb.Get(ctx, "foo").Val())
+ require.Equal(t, "OK", rdb.Set(ctx, "foo", "bar", 0).Val())
+ require.Equal(t, "bar", rdb.SetArgs(ctx, "foo", "xx",
redis.SetArgs{Get: true, Mode: "XX"}).Val())
+ require.Equal(t, "xx", rdb.Get(ctx, "foo").Val())
+ })
+
+ t.Run("Extended SET GET and KEEPTTL option", func(t *testing.T) {
+ require.NoError(t, rdb.Del(ctx, "foo").Err())
+ require.Equal(t, "", rdb.SetArgs(ctx, "foo", "xx",
redis.SetArgs{Get: true, KeepTTL: true}).Val())
+ ttl := rdb.TTL(ctx, "foo").Val()
+ require.Equal(t, time.Duration(-1), ttl)
+ require.Equal(t, "OK", rdb.Set(ctx, "foo", "bar",
10*time.Second).Val())
+ require.Equal(t, "bar", rdb.SetArgs(ctx, "foo", "xx",
redis.SetArgs{Get: true, KeepTTL: true}).Val())
+ ttl = rdb.TTL(ctx, "foo").Val()
+ util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
+ })
+
+ t.Run("Extended SET GET and EX option", func(t *testing.T) {
+ require.NoError(t, rdb.Del(ctx, "foo").Err())
+ require.Equal(t, nil, rdb.Do(ctx, "SET", "foo", "bar", "ex",
"10", "get").Val())
+ ttl := rdb.TTL(ctx, "foo").Val()
+ util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
+ })
+
+ t.Run("Extended SET GET and PX option", func(t *testing.T) {
+ require.NoError(t, rdb.Del(ctx, "foo").Err())
+ require.Equal(t, nil, rdb.Do(ctx, "SET", "foo", "bar", "px",
"10000", "get").Val())
+ ttl := rdb.TTL(ctx, "foo").Val()
+ util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
+ })
+
+ t.Run("Extended SET GET and EXAT option", func(t *testing.T) {
+ require.NoError(t, rdb.Del(ctx, "foo").Err())
+
+ expireAt :=
strconv.FormatInt(time.Now().Add(10*time.Second).Unix(), 10)
+ require.Equal(t, nil, rdb.Do(ctx, "SET", "foo", "bar", "exat",
expireAt, "get").Val())
+ ttl := rdb.TTL(ctx, "foo").Val()
+ util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
+ })
+
+ t.Run("Extended SET GET and PXAT option", func(t *testing.T) {
+ require.NoError(t, rdb.Del(ctx, "foo").Err())
+
+ expireAt :=
strconv.FormatInt(time.Now().Add(10*time.Second).UnixMilli(), 10)
+ require.Equal(t, nil, rdb.Do(ctx, "SET", "foo", "bar", "pxat",
expireAt, "get").Val())
+
+ ttl := rdb.TTL(ctx, "foo").Val()
+ util.BetweenValues(t, ttl, 5*time.Second, 10*time.Second)
+ })
+
t.Run("GETRANGE with huge ranges, Github issue redis/redis#1844",
func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "foo", "bar", 0).Err())
require.Equal(t, "bar", rdb.GetRange(ctx, "foo", 0,
2094967291).Val())