This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 7d48490b feat(search): add vector type to kqir::Value (#2371)
7d48490b is described below

commit 7d48490ba1890d9bb0510c11350d7ae1a1f14dd0
Author: Twice <[email protected]>
AuthorDate: Thu Jun 20 00:49:49 2024 +0900

    feat(search): add vector type to kqir::Value (#2371)
---
 src/search/indexer.cc | 105 +++++++++++++++++++++++++++++++++-----------------
 src/search/indexer.h  |   3 ++
 src/search/value.h    |   9 ++++-
 3 files changed, 79 insertions(+), 38 deletions(-)

diff --git a/src/search/indexer.cc b/src/search/indexer.cc
index 80fea7a9..7ce0b3d0 100644
--- a/src/search/indexer.cc
+++ b/src/search/indexer.cc
@@ -57,6 +57,73 @@ StatusOr<FieldValueRetriever> FieldValueRetriever::Create(IndexOnDataType type,
   }
 }
 
+// placeholders, remove them after vector indexing is implemented
+static bool IsVectorType(const redis::IndexFieldMetadata *) { return false; }
+static size_t GetVectorDim(const redis::IndexFieldMetadata *) { return 1; }
+
+StatusOr<kqir::Value> FieldValueRetriever::ParseFromJson(const jsoncons::json &val,
+                                                         const redis::IndexFieldMetadata *type) {
+  if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) {
+    if (!val.is_number() || val.is_string()) return {Status::NotOK, "json value cannot be string for numeric fields"};
+    return kqir::MakeValue<kqir::Numeric>(val.as_double());
+  } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {
+    if (val.is_string()) {
+      const char delim[] = {tag->separator, '\0'};
+      auto vec = util::Split(val.as_string(), delim);
+      return kqir::MakeValue<kqir::StringArray>(vec);
+    } else if (val.is_array()) {
+      std::vector<std::string> strs;
+      for (size_t i = 0; i < val.size(); ++i) {
+        if (!val[i].is_string())
+          return {Status::NotOK, "json value should be string or array of strings for tag fields"};
+        strs.push_back(val[i].as_string());
+      }
+      return kqir::MakeValue<kqir::StringArray>(strs);
+    } else {
+      return {Status::NotOK, "json value should be string or array of strings for tag fields"};
+    }
+  } else if (IsVectorType(type)) {
+    size_t dim = GetVectorDim(type);
+    if (!val.is_array()) return {Status::NotOK, "json value should be array of numbers for vector fields"};
+    if (dim != val.size()) return {Status::NotOK, "the size of the json array is not equal to the dim of the vector"};
+    std::vector<double> nums;
+    for (size_t i = 0; i < dim; ++i) {
+      if (!val[i].is_number() || val[i].is_string())
+        return {Status::NotOK, "json value should be array of numbers for vector fields"};
+      nums.push_back(val[i].as_double());
+    }
+    return kqir::MakeValue<kqir::NumericArray>(nums);
+  } else {
+    return {Status::NotOK, "unknown field type to retrieve"};
+  }
+}
+
+StatusOr<kqir::Value> FieldValueRetriever::ParseFromHash(const std::string &value,
+                                                         const redis::IndexFieldMetadata *type) {
+  if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) {
+    auto num = GET_OR_RET(ParseFloat(value));
+    return kqir::MakeValue<kqir::Numeric>(num);
+  } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {
+    const char delim[] = {tag->separator, '\0'};
+    auto vec = util::Split(value, delim);
+    return kqir::MakeValue<kqir::StringArray>(vec);
+  } else if (IsVectorType(type)) {
+    const size_t dim = GetVectorDim(type);
+    if (value.size() != dim * sizeof(double)) {
+      return {Status::NotOK, "field value is too short or too long to be parsed as a vector"};
+    }
+    std::vector<double> vec;
+    for (size_t i = 0; i < dim; ++i) {
+      // TODO: care about endian later
+      // TODO: currently only support 64bit floating point
+      vec.push_back(*(reinterpret_cast<const double *>(value.data()) + i));
+    }
+    return kqir::MakeValue<kqir::NumericArray>(vec);
+  } else {
+    return {Status::NotOK, "unknown field type to retrieve"};
+  }
+}
+
 StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field, const redis::IndexFieldMetadata *type) {
   if (std::holds_alternative<HashData>(db)) {
     auto &[hash, metadata, key] = std::get<HashData>(db);
@@ -71,17 +138,7 @@ StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field, cons
     if (s.IsNotFound()) return {Status::NotFound, s.ToString()};
     if (!s.ok()) return {Status::NotOK, s.ToString()};
 
-    if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) {
-      auto num = GET_OR_RET(ParseFloat(value));
-      return kqir::MakeValue<kqir::Numeric>(num);
-    } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {
-      const char delim[] = {tag->separator, '\0'};
-      auto vec = util::Split(value, delim);
-      return kqir::MakeValue<kqir::StringArray>(vec);
-    } else {
-      return {Status::NotOK, "unknown field type to retrieve"};
-    }
-
+    return ParseFromHash(value, type);
   } else if (std::holds_alternative<JsonData>(db)) {
     auto &value = std::get<JsonData>(db);
 
@@ -91,31 +148,7 @@ StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field, cons
       return {Status::NotFound, "json value specified by the field (json path) should exist and be unique"};
     auto val = s->value[0];
 
-    if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) {
-      if (val.is_string()) return {Status::NotOK, "json value cannot be string for numeric fields"};
-      return kqir::MakeValue<kqir::Numeric>(val.as_double());
-    } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {
-      if (val.is_string()) {
-        const char delim[] = {tag->separator, '\0'};
-        auto vec = util::Split(val.as_string(), delim);
-        return kqir::MakeValue<kqir::StringArray>(vec);
-      } else if (val.is_array()) {
-        std::vector<std::string> strs;
-        for (size_t i = 0; i < val.size(); ++i) {
-          if (!val[i].is_string())
-            return {Status::NotOK, "json value should be string or array of strings for tag fields"};
-          strs.push_back(val[i].as_string());
-        }
-        return kqir::MakeValue<kqir::StringArray>(strs);
-      } else {
-        return {Status::NotOK, "json value should be string or array of strings for tag fields"};
-      }
-    } else {
-      return {Status::NotOK, "unknown field type to retrieve"};
-    }
-
-    return Status::OK();
-
+    return ParseFromJson(val, type);
   } else {
     return {Status::NotOK, "unknown redis data type to retrieve"};
   }
diff --git a/src/search/indexer.h b/src/search/indexer.h
index 029944e2..8ffd503b 100644
--- a/src/search/indexer.h
+++ b/src/search/indexer.h
@@ -65,6 +65,9 @@ struct FieldValueRetriever {
   explicit FieldValueRetriever(JsonValue json) : db(std::in_place_type<JsonData>, std::move(json)) {}
 
   StatusOr<kqir::Value> Retrieve(std::string_view field, const redis::IndexFieldMetadata *type);
+
+  static StatusOr<kqir::Value> ParseFromJson(const jsoncons::json &value, const redis::IndexFieldMetadata *type);
+  static StatusOr<kqir::Value> ParseFromHash(const std::string &value, const redis::IndexFieldMetadata *type);
 };
 
 struct IndexUpdater {
diff --git a/src/search/value.h b/src/search/value.h
index f1a1e8b7..f3395717 100644
--- a/src/search/value.h
+++ b/src/search/value.h
@@ -40,8 +40,8 @@ using String = std::string;  // e.g. a single tag
 using NumericArray = std::vector<Numeric>;  // used for vector fields
 using StringArray = std::vector<String>;    // used for tag fields, e.g. a list for tags
 
-struct Value : std::variant<Null, Numeric, StringArray> {
-  using Base = std::variant<Null, Numeric, StringArray>;
+struct Value : std::variant<Null, Numeric, StringArray, NumericArray> {
+  using Base = std::variant<Null, Numeric, StringArray, NumericArray>;
 
   using Base::Base;
 
@@ -72,6 +72,9 @@ struct Value : std::variant<Null, Numeric, StringArray> {
     } else if (Is<StringArray>()) {
       return util::StringJoin(
           Get<StringArray>(), [](const auto &v) -> decltype(auto) { return v; }, sep);
+    } else if (Is<NumericArray>()) {
+      return util::StringJoin(
+          Get<NumericArray>(), [](const auto &v) -> decltype(auto) { return std::to_string(v); }, sep);
     }
 
     __builtin_unreachable();
@@ -87,6 +90,8 @@ struct Value : std::variant<Null, Numeric, StringArray> {
       char sep = tag ? tag->separator : ',';
       return util::StringJoin(
           Get<StringArray>(), [](const auto &v) -> decltype(auto) { return v; }, std::string(1, sep));
+    } else if (Is<NumericArray>()) {
+      return util::StringJoin(Get<NumericArray>(), [](const auto &v) -> decltype(auto) { return std::to_string(v); });
     }
 
     __builtin_unreachable();

Reply via email to