This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 7d48490b feat(search): add vector type to kqir::Value (#2371)
7d48490b is described below
commit 7d48490ba1890d9bb0510c11350d7ae1a1f14dd0
Author: Twice <[email protected]>
AuthorDate: Thu Jun 20 00:49:49 2024 +0900
feat(search): add vector type to kqir::Value (#2371)
---
src/search/indexer.cc | 105 +++++++++++++++++++++++++++++++++-----------------
src/search/indexer.h | 3 ++
src/search/value.h | 9 ++++-
3 files changed, 79 insertions(+), 38 deletions(-)
diff --git a/src/search/indexer.cc b/src/search/indexer.cc
index 80fea7a9..7ce0b3d0 100644
--- a/src/search/indexer.cc
+++ b/src/search/indexer.cc
@@ -57,6 +57,73 @@ StatusOr<FieldValueRetriever>
FieldValueRetriever::Create(IndexOnDataType type,
}
}
+// Placeholders: remove them after vector indexing is implemented.
+static bool IsVectorType(const redis::IndexFieldMetadata *) { return false; }  // stub: never reports a vector field
+static size_t GetVectorDim(const redis::IndexFieldMetadata *) { return 1; }  // stub: always reports a dimension of 1
+
+StatusOr<kqir::Value> FieldValueRetriever::ParseFromJson(const jsoncons::json
&val,
+ const
redis::IndexFieldMetadata *type) {  // Converts one JSON value to a kqir::Value, dispatching on the dynamic type of the field metadata.
+ if (auto numeric [[maybe_unused]] = dynamic_cast<const
redis::NumericFieldMetadata *>(type)) {  // numeric field: the JSON value must be a number
+ if (!val.is_number() || val.is_string()) return {Status::NotOK, "json
value cannot be string for numeric fields"};
+ return kqir::MakeValue<kqir::Numeric>(val.as_double());  // result holds kqir::Numeric
+ } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {  // tag field: accepts a string or an array of strings
+ if (val.is_string()) {
+ const char delim[] = {tag->separator, '\0'};
+ auto vec = util::Split(val.as_string(), delim);  // a single string is split on the tag separator
+ return kqir::MakeValue<kqir::StringArray>(vec);
+ } else if (val.is_array()) {  // array elements are taken one by one and are not split on the separator
+ std::vector<std::string> strs;
+ for (size_t i = 0; i < val.size(); ++i) {
+ if (!val[i].is_string())  // every element must itself be a string
+ return {Status::NotOK, "json value should be string or array of
strings for tag fields"};
+ strs.push_back(val[i].as_string());
+ }
+ return kqir::MakeValue<kqir::StringArray>(strs);
+ } else {
+ return {Status::NotOK, "json value should be string or array of strings
for tag fields"};
+ }
+ } else if (IsVectorType(type)) {  // vector field: expects an array of exactly dim numbers
+ size_t dim = GetVectorDim(type);  // expected element count
+ if (!val.is_array()) return {Status::NotOK, "json value should be array of
numbers for vector fields"};
+ if (dim != val.size()) return {Status::NotOK, "the size of the json array
is not equal to the dim of the vector"};
+ std::vector<double> nums;
+ for (size_t i = 0; i < dim; ++i) {  // indexing up to dim is in range: the size was checked to equal dim
+ if (!val[i].is_number() || val[i].is_string())  // every element must be a JSON number
+ return {Status::NotOK, "json value should be array of numbers for
vector fields"};
+ nums.push_back(val[i].as_double());
+ }
+ return kqir::MakeValue<kqir::NumericArray>(nums);  // result holds kqir::NumericArray
+ } else {  // metadata matched none of the handled field types
+ return {Status::NotOK, "unknown field type to retrieve"};
+ }
+}
+
+StatusOr<kqir::Value> FieldValueRetriever::ParseFromHash(const std::string
&value,
+ const
redis::IndexFieldMetadata *type) {  // Converts one raw hash field value to a kqir::Value, dispatching on the dynamic type of the field metadata.
+ if (auto numeric [[maybe_unused]] = dynamic_cast<const
redis::NumericFieldMetadata *>(type)) {  // numeric field: the value text is parsed as a float
+ auto num = GET_OR_RET(ParseFloat(value));  // presumably returns early with the error status on parse failure (macro not shown here) - confirm
+ return kqir::MakeValue<kqir::Numeric>(num);
+ } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {  // tag field: the value is split on the tag separator
+ const char delim[] = {tag->separator, '\0'};
+ auto vec = util::Split(value, delim);
+ return kqir::MakeValue<kqir::StringArray>(vec);
+ } else if (IsVectorType(type)) {  // vector field: the value is read as dim raw doubles
+ const size_t dim = GetVectorDim(type);
+ if (value.size() != dim * sizeof(double)) {  // byte length must match exactly before any element is read
+ return {Status::NotOK, "field value is too short or too long to be
parsed as a vector"};
+ }
+ std::vector<double> vec;
+ for (size_t i = 0; i < dim; ++i) {
+ // TODO: care about endian later
+ // TODO: currently only support 64bit floating point
+ vec.push_back(*(reinterpret_cast<const double *>(value.data()) + i));  // NOTE(review): reads through a double pointer into string storage, which has no double alignment guarantee; a memcpy per element would avoid that - confirm
+ }
+ return kqir::MakeValue<kqir::NumericArray>(vec);  // result holds kqir::NumericArray
+ } else {  // metadata matched none of the handled field types
+ return {Status::NotOK, "unknown field type to retrieve"};
+ }
+}
+
StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field,
const redis::IndexFieldMetadata *type) {
if (std::holds_alternative<HashData>(db)) {
auto &[hash, metadata, key] = std::get<HashData>(db);
@@ -71,17 +138,7 @@ StatusOr<kqir::Value>
FieldValueRetriever::Retrieve(std::string_view field, cons
if (s.IsNotFound()) return {Status::NotFound, s.ToString()};
if (!s.ok()) return {Status::NotOK, s.ToString()};
- if (auto numeric [[maybe_unused]] = dynamic_cast<const
redis::NumericFieldMetadata *>(type)) {
- auto num = GET_OR_RET(ParseFloat(value));
- return kqir::MakeValue<kqir::Numeric>(num);
- } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type))
{
- const char delim[] = {tag->separator, '\0'};
- auto vec = util::Split(value, delim);
- return kqir::MakeValue<kqir::StringArray>(vec);
- } else {
- return {Status::NotOK, "unknown field type to retrieve"};
- }
-
+ return ParseFromHash(value, type);
} else if (std::holds_alternative<JsonData>(db)) {
auto &value = std::get<JsonData>(db);
@@ -91,31 +148,7 @@ StatusOr<kqir::Value>
FieldValueRetriever::Retrieve(std::string_view field, cons
return {Status::NotFound, "json value specified by the field (json path)
should exist and be unique"};
auto val = s->value[0];
- if (auto numeric [[maybe_unused]] = dynamic_cast<const
redis::NumericFieldMetadata *>(type)) {
- if (val.is_string()) return {Status::NotOK, "json value cannot be string
for numeric fields"};
- return kqir::MakeValue<kqir::Numeric>(val.as_double());
- } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type))
{
- if (val.is_string()) {
- const char delim[] = {tag->separator, '\0'};
- auto vec = util::Split(val.as_string(), delim);
- return kqir::MakeValue<kqir::StringArray>(vec);
- } else if (val.is_array()) {
- std::vector<std::string> strs;
- for (size_t i = 0; i < val.size(); ++i) {
- if (!val[i].is_string())
- return {Status::NotOK, "json value should be string or array of
strings for tag fields"};
- strs.push_back(val[i].as_string());
- }
- return kqir::MakeValue<kqir::StringArray>(strs);
- } else {
- return {Status::NotOK, "json value should be string or array of
strings for tag fields"};
- }
- } else {
- return {Status::NotOK, "unknown field type to retrieve"};
- }
-
- return Status::OK();
-
+ return ParseFromJson(val, type);
} else {
return {Status::NotOK, "unknown redis data type to retrieve"};
}
diff --git a/src/search/indexer.h b/src/search/indexer.h
index 029944e2..8ffd503b 100644
--- a/src/search/indexer.h
+++ b/src/search/indexer.h
@@ -65,6 +65,9 @@ struct FieldValueRetriever {
explicit FieldValueRetriever(JsonValue json) :
db(std::in_place_type<JsonData>, std::move(json)) {}
StatusOr<kqir::Value> Retrieve(std::string_view field, const
redis::IndexFieldMetadata *type);
+
+ static StatusOr<kqir::Value> ParseFromJson(const jsoncons::json &value,
const redis::IndexFieldMetadata *type);
+ static StatusOr<kqir::Value> ParseFromHash(const std::string &value, const
redis::IndexFieldMetadata *type);
};
struct IndexUpdater {
diff --git a/src/search/value.h b/src/search/value.h
index f1a1e8b7..f3395717 100644
--- a/src/search/value.h
+++ b/src/search/value.h
@@ -40,8 +40,8 @@ using String = std::string; // e.g. a single tag
using NumericArray = std::vector<Numeric>; // used for vector fields
using StringArray = std::vector<String>; // used for tag fields, e.g. a
list for tags
-struct Value : std::variant<Null, Numeric, StringArray> {
- using Base = std::variant<Null, Numeric, StringArray>;
+struct Value : std::variant<Null, Numeric, StringArray, NumericArray> {
+ using Base = std::variant<Null, Numeric, StringArray, NumericArray>;
using Base::Base;
@@ -72,6 +72,9 @@ struct Value : std::variant<Null, Numeric, StringArray> {
} else if (Is<StringArray>()) {
return util::StringJoin(
Get<StringArray>(), [](const auto &v) -> decltype(auto) { return v;
}, sep);
+ } else if (Is<NumericArray>()) {
+ return util::StringJoin(
+ Get<NumericArray>(), [](const auto &v) -> decltype(auto) { return
std::to_string(v); }, sep);
}
__builtin_unreachable();
@@ -87,6 +90,8 @@ struct Value : std::variant<Null, Numeric, StringArray> {
char sep = tag ? tag->separator : ',';
return util::StringJoin(
Get<StringArray>(), [](const auto &v) -> decltype(auto) { return v;
}, std::string(1, sep));
+ } else if (Is<NumericArray>()) {
+ return util::StringJoin(Get<NumericArray>(), [](const auto &v) ->
decltype(auto) { return std::to_string(v); });
}
__builtin_unreachable();