This is an automated email from the ASF dual-hosted git repository. twice pushed a commit to branch value-add-vector in repository https://gitbox.apache.org/repos/asf/kvrocks.git
commit 474f19822321148253b10937b8d600f75442a9b4 Author: PragmaTwice <[email protected]> AuthorDate: Wed Jun 19 19:41:00 2024 +0900 feat(search): add vector type to kqir::Value --- src/search/indexer.cc | 105 +++++++++++++++++++++++++++++++++----------------- src/search/indexer.h | 3 ++ src/search/value.h | 9 ++++- 3 files changed, 79 insertions(+), 38 deletions(-) diff --git a/src/search/indexer.cc b/src/search/indexer.cc index 80fea7a9..7ce0b3d0 100644 --- a/src/search/indexer.cc +++ b/src/search/indexer.cc @@ -57,6 +57,73 @@ StatusOr<FieldValueRetriever> FieldValueRetriever::Create(IndexOnDataType type, } } +// placeholders, remove them after vector indexing is implemented +static bool IsVectorType(const redis::IndexFieldMetadata *) { return false; } +static size_t GetVectorDim(const redis::IndexFieldMetadata *) { return 1; } + +StatusOr<kqir::Value> FieldValueRetriever::ParseFromJson(const jsoncons::json &val, + const redis::IndexFieldMetadata *type) { + if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) { + if (!val.is_number() || val.is_string()) return {Status::NotOK, "json value cannot be string for numeric fields"}; + return kqir::MakeValue<kqir::Numeric>(val.as_double()); + } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) { + if (val.is_string()) { + const char delim[] = {tag->separator, '\0'}; + auto vec = util::Split(val.as_string(), delim); + return kqir::MakeValue<kqir::StringArray>(vec); + } else if (val.is_array()) { + std::vector<std::string> strs; + for (size_t i = 0; i < val.size(); ++i) { + if (!val[i].is_string()) + return {Status::NotOK, "json value should be string or array of strings for tag fields"}; + strs.push_back(val[i].as_string()); + } + return kqir::MakeValue<kqir::StringArray>(strs); + } else { + return {Status::NotOK, "json value should be string or array of strings for tag fields"}; + } + } else if (IsVectorType(type)) { + size_t dim = GetVectorDim(type); + if (!val.is_array()) return {Status::NotOK, "json value should be array of numbers for vector fields"}; + if (dim != val.size()) return {Status::NotOK, "the size of the json array is not equal to the dim of the vector"}; + std::vector<double> nums; + for (size_t i = 0; i < dim; ++i) { + if (!val[i].is_number() || val[i].is_string()) + return {Status::NotOK, "json value should be array of numbers for vector fields"}; + nums.push_back(val[i].as_double()); + } + return kqir::MakeValue<kqir::NumericArray>(nums); + } else { + return {Status::NotOK, "unknown field type to retrieve"}; + } +} + +StatusOr<kqir::Value> FieldValueRetriever::ParseFromHash(const std::string &value, + const redis::IndexFieldMetadata *type) { + if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) { + auto num = GET_OR_RET(ParseFloat(value)); + return kqir::MakeValue<kqir::Numeric>(num); + } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) { + const char delim[] = {tag->separator, '\0'}; + auto vec = util::Split(value, delim); + return kqir::MakeValue<kqir::StringArray>(vec); + } else if (IsVectorType(type)) { + const size_t dim = GetVectorDim(type); + if (value.size() != dim * sizeof(double)) { + return {Status::NotOK, "field value is too short or too long to be parsed as a vector"}; + } + std::vector<double> vec; + for (size_t i = 0; i < dim; ++i) { + // TODO: care about endian later + // TODO: currently only support 64bit floating point + vec.push_back(*(reinterpret_cast<const double *>(value.data()) + i)); + } + return kqir::MakeValue<kqir::NumericArray>(vec); + } else { + return {Status::NotOK, "unknown field type to retrieve"}; + } +} + StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field, const redis::IndexFieldMetadata *type) { if (std::holds_alternative<HashData>(db)) { auto &[hash, metadata, key] = std::get<HashData>(db); @@ -71,17 +138,7 @@ StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field, cons if (s.IsNotFound()) return {Status::NotFound, s.ToString()}; if (!s.ok()) return {Status::NotOK, s.ToString()}; - if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) { - auto num = GET_OR_RET(ParseFloat(value)); - return kqir::MakeValue<kqir::Numeric>(num); - } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) { - const char delim[] = {tag->separator, '\0'}; - auto vec = util::Split(value, delim); - return kqir::MakeValue<kqir::StringArray>(vec); - } else { - return {Status::NotOK, "unknown field type to retrieve"}; - } - + return ParseFromHash(value, type); } else if (std::holds_alternative<JsonData>(db)) { auto &value = std::get<JsonData>(db); @@ -91,31 +148,7 @@ StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field, cons return {Status::NotFound, "json value specified by the field (json path) should exist and be unique"}; auto val = s->value[0]; - if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) { - if (val.is_string()) return {Status::NotOK, "json value cannot be string for numeric fields"}; - return kqir::MakeValue<kqir::Numeric>(val.as_double()); - } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) { - if (val.is_string()) { - const char delim[] = {tag->separator, '\0'}; - auto vec = util::Split(val.as_string(), delim); - return kqir::MakeValue<kqir::StringArray>(vec); - } else if (val.is_array()) { - std::vector<std::string> strs; - for (size_t i = 0; i < val.size(); ++i) { - if (!val[i].is_string()) - return {Status::NotOK, "json value should be string or array of strings for tag fields"}; - strs.push_back(val[i].as_string()); - } - return kqir::MakeValue<kqir::StringArray>(strs); - } else { - return {Status::NotOK, "json value should be string or array of strings for tag fields"}; - } - } else { - return {Status::NotOK, "unknown field type to retrieve"}; - } - - return Status::OK(); - + return ParseFromJson(val, type); } else { return {Status::NotOK, "unknown redis data type to retrieve"}; } diff --git a/src/search/indexer.h b/src/search/indexer.h index 029944e2..8ffd503b 100644 --- a/src/search/indexer.h +++ b/src/search/indexer.h @@ -65,6 +65,9 @@ struct FieldValueRetriever { explicit FieldValueRetriever(JsonValue json) : db(std::in_place_type<JsonData>, std::move(json)) {} StatusOr<kqir::Value> Retrieve(std::string_view field, const redis::IndexFieldMetadata *type); + + static StatusOr<kqir::Value> ParseFromJson(const jsoncons::json &value, const redis::IndexFieldMetadata *type); + static StatusOr<kqir::Value> ParseFromHash(const std::string &value, const redis::IndexFieldMetadata *type); }; struct IndexUpdater { diff --git a/src/search/value.h b/src/search/value.h index f1a1e8b7..f3395717 100644 --- a/src/search/value.h +++ b/src/search/value.h @@ -40,8 +40,8 @@ using String = std::string; // e.g. a single tag using NumericArray = std::vector<Numeric>; // used for vector fields using StringArray = std::vector<String>; // used for tag fields, e.g. a list for tags -struct Value : std::variant<Null, Numeric, StringArray> { - using Base = std::variant<Null, Numeric, StringArray>; +struct Value : std::variant<Null, Numeric, StringArray, NumericArray> { + using Base = std::variant<Null, Numeric, StringArray, NumericArray>; using Base::Base; @@ -72,6 +72,9 @@ struct Value : std::variant<Null, Numeric, StringArray> { } else if (Is<StringArray>()) { return util::StringJoin( Get<StringArray>(), [](const auto &v) -> decltype(auto) { return v; }, sep); + } else if (Is<NumericArray>()) { + return util::StringJoin( + Get<NumericArray>(), [](const auto &v) -> decltype(auto) { return std::to_string(v); }, sep); } __builtin_unreachable(); @@ -87,6 +90,8 @@ struct Value : std::variant<Null, Numeric, StringArray> { char sep = tag ? tag->separator : ','; return util::StringJoin( Get<StringArray>(), [](const auto &v) -> decltype(auto) { return v; }, std::string(1, sep)); + } else if (Is<NumericArray>()) { + return util::StringJoin(Get<NumericArray>(), [](const auto &v) -> decltype(auto) { return std::to_string(v); }); } __builtin_unreachable();
