This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 319de1ec Add the support of JSON.ARRINDEX command (#1865)
319de1ec is described below

commit 319de1ec3808e14473f031e4f73ed1a36ed93365
Author: skyitachi <[email protected]>
AuthorDate: Fri Nov 10 11:39:32 2023 +0800

    Add the support of JSON.ARRINDEX command (#1865)
---
 src/commands/cmd_json.cc                 | 52 ++++++++++++++++++++++++++++-
 src/types/json.h                         | 45 +++++++++++++++++++++++++
 src/types/redis_json.cc                  | 20 +++++++++++
 src/types/redis_json.h                   |  2 ++
 tests/cppunit/types/json_test.cc         | 57 ++++++++++++++++++++++++++++++++
 tests/gocase/unit/type/json/json_test.go | 38 +++++++++++++++++++++
 6 files changed, 213 insertions(+), 1 deletion(-)

diff --git a/src/commands/cmd_json.cc b/src/commands/cmd_json.cc
index 08bcf7ab..75e31323 100644
--- a/src/commands/cmd_json.cc
+++ b/src/commands/cmd_json.cc
@@ -320,6 +320,58 @@ class CommandJsonArrPop : public Commander {
   int64_t index_ = -1;
 };
 
+// JSON.ARRINDEX key path value [start [stop]]
+// Replies with one entry per value matched by `path`: the index of the first element equal to `value`
+// (-1 if absent) or nil if the matched value is not an array. A NotFound status from storage is replied as nil.
+class CommandJsonArrIndex : public Commander {
+ public:
+  Status Parse(const std::vector<std::string> &args) override {
+    if (args.size() > 6) {
+      return {Status::RedisExecErr, errWrongNumOfArguments};
+    }
+    start_ = 0;
+    end_ = std::numeric_limits<ssize_t>::max();
+
+    if (args.size() > 4) {
+      start_ = GET_OR_RET(ParseInt<ssize_t>(args[4], 10));
+    }
+    if (args.size() > 5) {
+      end_ = GET_OR_RET(ParseInt<ssize_t>(args[5], 10));
+    }
+    return Status::OK();
+  }
+
+  Status Execute(Server *svr, Connection *conn, std::string *output) override {
+    redis::Json json(svr->storage, conn->GetNamespace());
+
+    std::vector<ssize_t> result;
+
+    auto s = json.ArrIndex(args_[1], args_[2], args_[3], start_, end_, &result);
+
+    if (s.IsNotFound()) {
+      *output = redis::NilString();
+      return Status::OK();
+    }
+
+    if (!s.ok()) return {Status::RedisExecErr, s.ToString()};
+
+    *output = redis::MultiLen(result.size());
+    for (const auto &found_index : result) {
+      // NOT_ARRAY marks a matched value that is not an array; such entries are replied as nil.
+      if (found_index == NOT_ARRAY) {
+        *output += redis::NilString();
+        continue;
+      }
+      *output += redis::Integer(found_index);
+    }
+    return Status::OK();
+  }
+
+ private:
+  ssize_t start_;
+  ssize_t end_;
+};
+
 REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandJsonSet>("json.set", 4, "write", 1, 1, 1),
                         MakeCmdAttr<CommandJsonGet>("json.get", -2, "read-only", 1, 1, 1),
                         MakeCmdAttr<CommandJsonInfo>("json.info", 2, "read-only", 1, 1, 1),
@@ -329,5 +381,7 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandJsonSet>("json.set", 4, "write", 1, 1
                         MakeCmdAttr<CommandJsonToggle>("json.toggle", -2, "write", 1, 1, 1),
                         MakeCmdAttr<CommandJsonArrLen>("json.arrlen", -2, "read-only", 1, 1, 1),
                         MakeCmdAttr<CommandJsonObjkeys>("json.objkeys", -2, "read-only", 1, 1, 1),
-                        MakeCmdAttr<CommandJsonArrPop>("json.arrpop", -2, "write", 1, 1, 1), );
+                        MakeCmdAttr<CommandJsonArrPop>("json.arrpop", -2, "write", 1, 1, 1),
+                        MakeCmdAttr<CommandJsonArrIndex>("json.arrindex", -4, "read-only", 1, 1, 1), );
+
 }  // namespace redis
diff --git a/src/types/json.h b/src/types/json.h
index 1b78ec06..4aa4786b 100644
--- a/src/types/json.h
+++ b/src/types/json.h
@@ -33,6 +33,13 @@
 
 #include "status.h"
 
+// Sentinel entries in the result of JsonValue::ArrIndex. They are `inline` so that every translation unit
+// including this header shares one definition; inline code in this header binds references to them.
+// The value was not found in the searched range of the array.
+inline constexpr ssize_t NOT_FOUND_INDEX = -1;
+// The path matched a value that is not an array.
+inline constexpr ssize_t NOT_ARRAY = -2;
+
 struct JsonValue {
   JsonValue() = default;
   explicit JsonValue(jsoncons::basic_json<char> value) : value(std::move(value)) {}
@@ -173,6 +176,66 @@ struct JsonValue {
     return result_count;
   }
 
+  // Maps a JSON.ARRINDEX [start, end) window onto an array of length `len`.
+  // Negative values count from the end of the array, end == 0 means "up to the end", and both bounds are
+  // clamped into the array. The result may be an empty or inverted window (first >= second), and for
+  // len == 0 with a non-negative start `first` is -1, so callers must check it before building iterators.
+  static std::pair<ssize_t, ssize_t> NormalizeArrIndices(ssize_t start, ssize_t end, ssize_t len) {
+    if (start < 0) {
+      start = std::max<ssize_t>(0, len + start);
+    } else {
+      start = std::min<ssize_t>(start, len - 1);
+    }
+    if (end == 0) {
+      end = len;
+    } else if (end < 0) {
+      end = std::max<ssize_t>(0, len + end);
+    }
+    end = std::min<ssize_t>(end, len);
+    return {start, end};
+  }
+
+  // For every value matched by `path`, appends the index of the first element equal to `needle` inside the
+  // normalized [start, end) window, NOT_FOUND_INDEX if there is none, or NOT_ARRAY if the value is not an
+  // array. Returns NotOK with the message of any jsonpath_error thrown while evaluating `path`.
+  StatusOr<std::vector<ssize_t>> ArrIndex(std::string_view path, const jsoncons::json &needle, ssize_t start,
+                                          ssize_t end) const {
+    std::vector<ssize_t> result;
+    try {
+      jsoncons::jsonpath::json_query(value, path, [&](const std::string & /*path*/, const jsoncons::json &val) {
+        if (!val.is_array()) {
+          result.emplace_back(NOT_ARRAY);
+          return;
+        }
+        const auto len = static_cast<ssize_t>(val.size());
+        // Nothing to search; NormalizeArrIndices could also yield start == -1 here (before the array).
+        if (len == 0) {
+          result.emplace_back(NOT_FOUND_INDEX);
+          return;
+        }
+        auto [pstart, pend] = NormalizeArrIndices(start, end, len);
+        // An empty or inverted window (e.g. start=5, end=2) matches nothing; passing an inverted range to
+        // std::find would walk past the end of the array.
+        if (pstart >= pend) {
+          result.emplace_back(NOT_FOUND_INDEX);
+          return;
+        }
+        auto arr_begin = val.array_range().begin();
+        auto begin_it = arr_begin + pstart;
+        auto end_it = arr_begin + pend;
+        auto it = std::find(begin_it, end_it, needle);
+        if (it != end_it) {
+          result.emplace_back(it - arr_begin);
+          return;
+        }
+        result.emplace_back(NOT_FOUND_INDEX);
+      });
+    } catch (const jsoncons::jsonpath::jsonpath_error &e) {
+      return {Status::NotOK, e.what()};
+    }
+    return result;
+  }
+
   StatusOr<std::vector<std::string>> Type(std::string_view path) const {
     std::vector<std::string> types;
     try {
diff --git a/src/types/redis_json.cc b/src/types/redis_json.cc
index 94536605..3de38a88 100644
--- a/src/types/redis_json.cc
+++ b/src/types/redis_json.cc
@@ -180,6 +180,30 @@ rocksdb::Status Json::ArrAppend(const std::string &user_key, const std::string &
   return write(ns_key, &metadata, value);
 }
 
+// Parses `needle` as JSON and runs JsonValue::ArrIndex over the value stored at `user_key`, writing one entry
+// per value matched by `path` into `result`. Returns InvalidArgument if JsonValue::FromString rejects `needle`
+// or the search itself fails (e.g. a bad path), and the status of read() if the stored value cannot be loaded.
+rocksdb::Status Json::ArrIndex(const std::string &user_key, const std::string &path, const std::string &needle,
+                               ssize_t start, ssize_t end, std::vector<ssize_t> *result) {
+  auto ns_key = AppendNamespacePrefix(user_key);
+
+  auto needle_res = JsonValue::FromString(needle, storage_->GetConfig()->json_max_nesting_depth);
+  if (!needle_res) return rocksdb::Status::InvalidArgument(needle_res.Msg());
+  auto needle_value = *std::move(needle_res);
+
+  JsonMetadata metadata;
+  JsonValue value;
+  auto s = read(ns_key, &metadata, &value);
+  if (!s.ok()) return s;
+
+  auto index_res = value.ArrIndex(path, needle_value.value, start, end);
+  if (!index_res) return rocksdb::Status::InvalidArgument(index_res.Msg());
+  // index_res is not used afterwards, so move the indices out instead of copying them.
+  *result = *std::move(index_res);
+
+  return rocksdb::Status::OK();
+}
+
 rocksdb::Status Json::Type(const std::string &user_key, const std::string &path, std::vector<std::string> *results) {
   auto ns_key = AppendNamespacePrefix(user_key);
 
diff --git a/src/types/redis_json.h b/src/types/redis_json.h
index e1f5cde9..94a054ae 100644
--- a/src/types/redis_json.h
+++ b/src/types/redis_json.h
@@ -48,6 +48,10 @@ class Json : public Database {
                           std::vector<std::optional<std::vector<std::string>>> &keys);
   rocksdb::Status ArrPop(const std::string &user_key, const std::string &path, int64_t index,
                          std::vector<std::optional<JsonValue>> *results);
+  // Writes one entry per value matched by `path` into `result`: the first index of `needle` within the
+  // [start, end) window, NOT_FOUND_INDEX (-1) if it is absent, or NOT_ARRAY (-2) if the value is not an array.
+  rocksdb::Status ArrIndex(const std::string &user_key, const std::string &path, const std::string &needle,
+                           ssize_t start, ssize_t end, std::vector<ssize_t> *result);
 
  private:
   rocksdb::Status write(Slice ns_key, JsonMetadata *metadata, const JsonValue &json_val);
diff --git a/tests/cppunit/types/json_test.cc b/tests/cppunit/types/json_test.cc
index c1c10e8d..c371c6d6 100644
--- a/tests/cppunit/types/json_test.cc
+++ b/tests/cppunit/types/json_test.cc
@@ -398,3 +398,60 @@ TEST_F(RedisJsonTest, ArrPop) {
   ASSERT_EQ(res[3]->Dump().GetValue(), "1");
   res.clear();
 }
+
+TEST_F(RedisJsonTest, ArrIndex) {
+  std::vector<ssize_t> res;
+  // Same stop value JSON.ARRINDEX passes when the stop argument is omitted.
+  const ssize_t max_end = std::numeric_limits<ssize_t>::max();
+
+  ASSERT_TRUE(json_->Set(key_, "$", R"({"arr":[0, 1, 2, 3, 2, 1, 0]})").ok());
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 0, max_end, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{0});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "3", 0, max_end, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{3});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "4", 0, max_end, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{-1});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 1, max_end, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{6});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", -1, max_end, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{6});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 6, max_end, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{6});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 5, -1, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{-1});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 5, 0, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{6});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "2", -2, 6, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{-1});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "\"foo\"", 0, max_end, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{-1});
+
+  ASSERT_TRUE(json_->Set(key_, "$", R"({"arr":[0, 1, 2, 3, 4, 2, 1, 0]})").ok());
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "3", 0, max_end, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{3});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "2", 3, max_end, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{5});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "1", 0, max_end, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{1});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "2", 1, 4, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{2});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "6", 0, max_end, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{-1});
+
+  ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "3", 0, 2, &res).ok());
+  ASSERT_EQ(res, std::vector<ssize_t>{-1});
+}
diff --git a/tests/gocase/unit/type/json/json_test.go b/tests/gocase/unit/type/json/json_test.go
index 0d5f3fd9..2e97b3b8 100644
--- a/tests/gocase/unit/type/json/json_test.go
+++ b/tests/gocase/unit/type/json/json_test.go
@@ -275,4 +275,42 @@ func TestJson(t *testing.T) {
                require.Equal(t, rdb.Do(ctx, "JSON.GET", "a").Val(), `[99,false,99]`)
        })
 
+       t.Run("JSON.ARRINDEX basics", func(t *testing.T) {
+               arrIndexCmd := "JSON.ARRINDEX"
+               require.NoError(t, rdb.Do(ctx, "SET", "a", `1`).Err())
+               require.Error(t, rdb.Do(ctx, arrIndexCmd, "a", "$", `1`).Err())
+               require.NoError(t, rdb.Do(ctx, "DEL", "a").Err())
+
+               require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", ` {"x":1, "y": {"x":1} } `).Err())
+               require.Equal(t, []interface{}{}, rdb.Do(ctx, arrIndexCmd, "a", "$..k", `1`).Val())
+               require.Error(t, rdb.Do(ctx, arrIndexCmd, "a", "$").Err())
+               require.Error(t, rdb.Do(ctx, arrIndexCmd, "a", "$", ` 1, 2, 3`).Err())
+
+               require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"arr":[0,1,2,3,2,1,0]}`).Err())
+               require.Equal(t, []interface{}{int64(0)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`).Val())
+               require.Equal(t, []interface{}{int64(3)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `3`).Val())
+               require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `4`).Val())
+               require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 1).Val())
+               require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, -1).Val())
+               require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 6).Val())
+               require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 4, -0).Val())
+               require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 5, -1).Val())
+               require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 5, 0).Val())
+               require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `2`, -2, 6).Val())
+               require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `"foo"`).Val())
+
+               require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"arr":[0,1,2,3,4,2,1,0]}`).Err())
+
+               require.Equal(t, []interface{}{int64(3)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `3`).Val())
+               require.Equal(t, []interface{}{int64(5)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `2`, 3).Val())
+               require.Equal(t, []interface{}{int64(1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `1`).Val())
+               require.Equal(t, []interface{}{int64(2)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `2`, 1, 4).Val())
+               require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `6`).Val())
+               require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `3`, 0, 2).Val())
+
+               require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"arr":[0,1,2]}`).Err())
+               require.Equal(t, []interface{}{nil, nil, nil}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr.*", `1`).Val())
+               require.NoError(t, rdb.Do(ctx, "JSON.SET", "a1", "$", `{"arr":[[1],[2],[3]]}`).Err())
+               require.Equal(t, []interface{}{int64(0), int64(-1), int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a1", "$.arr.*", `1`).Val())
+       })
 }

Reply via email to