This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 12269d7c feat(search): Add HNSW encoding index & insertion/deletion
algorithm (#2368)
12269d7c is described below
commit 12269d7c36afa933c6eaa9bf75029acb5ebc59b9
Author: Rebecca Zhou <[email protected]>
AuthorDate: Fri Jul 12 21:55:49 2024 -0700
feat(search): Add HNSW encoding index & insertion/deletion algorithm (#2368)
Co-authored-by: hulk <[email protected]>
---
src/search/hnsw_indexer.cc | 552 ++++++++++++++++++++++++++++++++
src/search/hnsw_indexer.h | 115 +++++++
src/search/indexer.cc | 43 ++-
src/search/indexer.h | 2 +
src/search/search_encoding.h | 161 ++++++++++
tests/cppunit/hnsw_index_test.cc | 664 +++++++++++++++++++++++++++++++++++++++
tests/cppunit/hnsw_node_test.cc | 165 ++++++++++
7 files changed, 1693 insertions(+), 9 deletions(-)
diff --git a/src/search/hnsw_indexer.cc b/src/search/hnsw_indexer.cc
new file mode 100644
index 00000000..f03e4c95
--- /dev/null
+++ b/src/search/hnsw_indexer.cc
@@ -0,0 +1,552 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include "hnsw_indexer.h"
+
+#include <fmt/core.h>
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <queue>
+#include <random>
+#include <unordered_set>
+#include <vector>
+
+#include "db_util.h"
+
+namespace redis {
+
+HnswNode::HnswNode(NodeKey key, uint16_t level) : key(std::move(key)),
level(level) {}
+
+StatusOr<HnswNodeFieldMetadata> HnswNode::DecodeMetadata(const SearchKey&
search_key, engine::Storage* storage) const {
+ auto node_index_key = search_key.ConstructHnswNode(level, key);
+ rocksdb::PinnableSlice value;
+ auto s = storage->Get(rocksdb::ReadOptions(),
storage->GetCFHandle(ColumnFamilyID::Search), node_index_key, &value);
+ if (!s.ok()) return {Status::NotOK, s.ToString()};
+
+ HnswNodeFieldMetadata metadata;
+ s = metadata.Decode(&value);
+ if (!s.ok()) return {Status::NotOK, s.ToString()};
+ return metadata;
+}
+
+void HnswNode::PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey&
search_key, engine::Storage* storage,
+ rocksdb::WriteBatchBase* batch) const {
+ std::string updated_metadata;
+ node_meta->Encode(&updated_metadata);
+ batch->Put(storage->GetCFHandle(ColumnFamilyID::Search),
search_key.ConstructHnswNode(level, key), updated_metadata);
+}
+
+void HnswNode::DecodeNeighbours(const SearchKey& search_key, engine::Storage*
storage) {
+ neighbours.clear();
+ auto edge_prefix = search_key.ConstructHnswEdgeWithSingleEnd(level, key);
+ util::UniqueIterator iter(storage, storage->DefaultScanOptions(),
ColumnFamilyID::Search);
+ for (iter->Seek(edge_prefix); iter->Valid(); iter->Next()) {
+ if (!iter->key().starts_with(edge_prefix)) {
+ break;
+ }
+ auto neighbour_edge = iter->key();
+ neighbour_edge.remove_prefix(edge_prefix.size());
+ Slice neighbour;
+ GetSizedString(&neighbour_edge, &neighbour);
+ neighbours.push_back(neighbour.ToString());
+ }
+}
+
+Status HnswNode::AddNeighbour(const NodeKey& neighbour_key, const SearchKey&
search_key, engine::Storage* storage,
+ rocksdb::WriteBatchBase* batch) const {
+ auto edge_index_key = search_key.ConstructHnswEdge(level, key,
neighbour_key);
+ batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key,
Slice());
+
+ HnswNodeFieldMetadata node_metadata = GET_OR_RET(DecodeMetadata(search_key,
storage));
+ node_metadata.num_neighbours++;
+ PutMetadata(&node_metadata, search_key, storage, batch);
+ return Status::OK();
+}
+
+Status HnswNode::RemoveNeighbour(const NodeKey& neighbour_key, const
SearchKey& search_key, engine::Storage* storage,
+ rocksdb::WriteBatchBase* batch) const {
+ auto edge_index_key = search_key.ConstructHnswEdge(level, key,
neighbour_key);
+ auto s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search),
edge_index_key);
+ if (!s.ok()) {
+ return {Status::NotOK, fmt::format("failed to delete edge, {}",
s.ToString())};
+ }
+
+ HnswNodeFieldMetadata node_metadata = GET_OR_RET(DecodeMetadata(search_key,
storage));
+ node_metadata.num_neighbours--;
+ PutMetadata(&node_metadata, search_key, storage, batch);
+ return Status::OK();
+}
+
+Status VectorItem::Create(NodeKey key, const kqir::NumericArray& vector, const
HnswVectorFieldMetadata* metadata,
+ VectorItem* out) {
+ if (metadata->dim != vector.size()) {
+ return {Status::InvalidArgument, "VectorItem's metadata dimension must be
consistent with the vector itself."};
+ }
+
+ *out = VectorItem(std::move(key), vector, metadata);
+ return Status::OK();
+}
+
+Status VectorItem::Create(NodeKey key, kqir::NumericArray&& vector, const
HnswVectorFieldMetadata* metadata,
+ VectorItem* out) {
+ if (metadata->dim != vector.size()) {
+ return {Status::InvalidArgument, "VectorItem's metadata dimension must be
consistent with the vector itself."};
+ }
+
+ *out = VectorItem(std::move(key), std::move(vector), metadata);
+ return Status::OK();
+}
+
+bool VectorItem::operator==(const VectorItem& other) const { return key ==
other.key; }
+
+bool VectorItem::operator<(const VectorItem& other) const { return key <
other.key; }
+
+VectorItem::VectorItem(NodeKey&& key, const kqir::NumericArray& vector, const
HnswVectorFieldMetadata* metadata)
+ : key(std::move(key)), vector(vector), metadata(metadata) {}
+
+VectorItem::VectorItem(NodeKey&& key, kqir::NumericArray&& vector, const
HnswVectorFieldMetadata* metadata)
+ : key(std::move(key)), vector(std::move(vector)), metadata(metadata) {}
+
+StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem&
right) {
+ if (left.metadata->distance_metric != right.metadata->distance_metric ||
left.metadata->dim != right.metadata->dim)
+ return {Status::InvalidArgument, "Vectors must be of the same metric and
dimension to compute distance."};
+
+ auto metric = left.metadata->distance_metric;
+ auto dim = left.metadata->dim;
+
+ switch (metric) {
+ case DistanceMetric::L2: {
+ double dist = 0.0;
+ for (auto i = 0; i < dim; i++) {
+ double diff = left.vector[i] - right.vector[i];
+ dist += diff * diff;
+ }
+ return std::sqrt(dist);
+ }
+ case DistanceMetric::IP: {
+ double dist = 0.0;
+ for (auto i = 0; i < dim; i++) {
+ dist += left.vector[i] * right.vector[i];
+ }
+ return -dist;
+ }
+ case DistanceMetric::COSINE: {
+ double dist = 0.0;
+ double norm_left = 0.0;
+ double norm_right = 0.0;
+ for (auto i = 0; i < dim; i++) {
+ dist += left.vector[i] * right.vector[i];
+ norm_left += left.vector[i] * left.vector[i];
+ norm_right += right.vector[i] * right.vector[i];
+ }
+ auto similarity = dist / std::sqrt(norm_left * norm_right);
+ return 1.0 - similarity;
+ }
+ default:
+ __builtin_unreachable();
+ }
+}
+
+HnswIndex::HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata*
vector, engine::Storage* storage)
+ : search_key(search_key),
+ metadata(vector),
+ storage(storage),
+ m_level_normalization_factor(1.0 / std::log(metadata->m)) {
+ std::random_device rand_dev;
+ generator = std::mt19937(rand_dev());
+}
+
+uint16_t HnswIndex::RandomizeLayer() {
+ std::uniform_real_distribution<double> level_dist(0.0, 1.0);
+ double r = level_dist(generator);
+ double log_val = -std::log(r);
+ double layer_val = log_val * m_level_normalization_factor;
+ return static_cast<uint16_t>(std::floor(layer_val));
+}
+
+StatusOr<HnswIndex::NodeKey> HnswIndex::DefaultEntryPoint(uint16_t level)
const {
+ auto prefix = search_key.ConstructHnswLevelNodePrefix(level);
+ util::UniqueIterator it(storage, storage->DefaultScanOptions(),
ColumnFamilyID::Search);
+ it->Seek(prefix);
+
+ Slice node_key;
+ Slice node_key_dst;
+ if (it->Valid() && it->key().starts_with(prefix)) {
+ std::string node_key_str = it->key().ToString().substr(prefix.size());
+ node_key = Slice(node_key_str);
+ if (!GetSizedString(&node_key, &node_key_dst)) {
+ return {Status::NotOK, fmt::format("fail to decode the default node key
layer {}", level)};
+ }
+ return node_key_dst.ToString();
+ }
+ return {Status::NotFound, fmt::format("No node found in layer {}", level)};
+}
+
+StatusOr<std::vector<VectorItem>> HnswIndex::DecodeNodesToVectorItems(const
std::vector<NodeKey>& node_keys,
+ uint16_t
level, const SearchKey& search_key,
+
engine::Storage* storage,
+ const
HnswVectorFieldMetadata* metadata) {
+ std::vector<VectorItem> vector_items;
+ vector_items.reserve(node_keys.size());
+
+ for (const auto& neighbour_key : node_keys) {
+ HnswNode neighbour_node(neighbour_key, level);
+ auto neighbour_metadata_status = neighbour_node.DecodeMetadata(search_key,
storage);
+ if (!neighbour_metadata_status.IsOK()) {
+ continue; // Skip this neighbour if metadata can't be decoded
+ }
+ auto neighbour_metadata = neighbour_metadata_status.GetValue();
+ VectorItem item;
+ GET_OR_RET(VectorItem::Create(neighbour_key,
std::move(neighbour_metadata.vector), metadata, &item));
+ vector_items.emplace_back(std::move(item));
+ }
+ return vector_items;
+}
+
+Status HnswIndex::AddEdge(const NodeKey& node_key1, const NodeKey& node_key2,
uint16_t layer,
+ ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch)
const {
+ auto edge_index_key1 = search_key.ConstructHnswEdge(layer, node_key1,
node_key2);
+ auto s = batch->Put(storage->GetCFHandle(ColumnFamilyID::Search),
edge_index_key1, Slice());
+ if (!s.ok()) {
+ return {Status::NotOK, fmt::format("failed to add edge, {}",
s.ToString())};
+ }
+
+ auto edge_index_key2 = search_key.ConstructHnswEdge(layer, node_key2,
node_key1);
+ s = batch->Put(storage->GetCFHandle(ColumnFamilyID::Search),
edge_index_key2, Slice());
+ if (!s.ok()) {
+ return {Status::NotOK, fmt::format("failed to add edge, {}",
s.ToString())};
+ }
+ return Status::OK();
+}
+
+Status HnswIndex::RemoveEdge(const NodeKey& node_key1, const NodeKey&
node_key2, uint16_t layer,
+ ObserverOrUniquePtr<rocksdb::WriteBatchBase>&
batch) const {
+ auto edge_index_key1 = search_key.ConstructHnswEdge(layer, node_key1,
node_key2);
+ auto s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search),
edge_index_key1);
+ if (!s.ok()) {
+ return {Status::NotOK, fmt::format("failed to delete edge, {}",
s.ToString())};
+ }
+
+ auto edge_index_key2 = search_key.ConstructHnswEdge(layer, node_key2,
node_key1);
+ s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search),
edge_index_key2);
+ if (!s.ok()) {
+ return {Status::NotOK, fmt::format("failed to delete edge, {}",
s.ToString())};
+ }
+ return Status::OK();
+}
+
+StatusOr<std::vector<VectorItem>> HnswIndex::SelectNeighbors(const VectorItem&
vec,
+ const
std::vector<VectorItem>& vertors,
+ uint16_t layer)
const {
+ std::vector<std::pair<double, VectorItem>> distances;
+ distances.reserve(vertors.size());
+ for (const auto& candidate : vertors) {
+ auto dist = GET_OR_RET(ComputeSimilarity(vec, candidate));
+ distances.emplace_back(dist, candidate);
+ }
+
+ std::sort(distances.begin(), distances.end());
+ std::vector<VectorItem> selected_vs;
+
+ selected_vs.reserve(vertors.size());
+ uint16_t m_max = layer != 0 ? metadata->m : 2 * metadata->m;
+ for (auto i = 0; i < std::min(m_max, (uint16_t)distances.size()); i++) {
+ selected_vs.push_back(distances[i].second);
+ }
+ return selected_vs;
+}
+
+StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const
VectorItem& target_vector,
+ uint32_t ef_runtime,
+ const
std::vector<NodeKey>& entry_points) const {
+ std::vector<VectorItem> candidates;
+ std::unordered_set<NodeKey> visited;
+ std::priority_queue<std::pair<double, VectorItem>,
std::vector<std::pair<double, VectorItem>>, std::greater<>>
+ explore_heap;
+ std::priority_queue<std::pair<double, VectorItem>> result_heap;
+
+ for (const auto& entry_point_key : entry_points) {
+ HnswNode entry_node = HnswNode(entry_point_key, level);
+ auto entry_node_metadata =
GET_OR_RET(entry_node.DecodeMetadata(search_key, storage));
+
+ VectorItem entry_point_vector;
+ GET_OR_RET(
+ VectorItem::Create(entry_point_key,
std::move(entry_node_metadata.vector), metadata, &entry_point_vector));
+ auto dist = GET_OR_RET(ComputeSimilarity(target_vector,
entry_point_vector));
+
+ explore_heap.push(std::make_pair(dist, entry_point_vector));
+ result_heap.push(std::make_pair(dist, std::move(entry_point_vector)));
+ visited.insert(entry_point_key);
+ }
+
+ while (!explore_heap.empty()) {
+ auto [dist, current_vector] = explore_heap.top();
+ explore_heap.pop();
+ if (dist > result_heap.top().first) {
+ break;
+ }
+
+ auto current_node = HnswNode(current_vector.key, level);
+ current_node.DecodeNeighbours(search_key, storage);
+
+ for (const auto& neighbour_key : current_node.neighbours) {
+ if (visited.find(neighbour_key) != visited.end()) {
+ continue;
+ }
+ visited.insert(neighbour_key);
+
+ auto neighbour_node = HnswNode(neighbour_key, level);
+ auto neighbour_node_metadata =
GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));
+
+ VectorItem neighbour_node_vector;
+ GET_OR_RET(VectorItem::Create(neighbour_key,
std::move(neighbour_node_metadata.vector), metadata,
+ &neighbour_node_vector));
+
+ auto dist = GET_OR_RET(ComputeSimilarity(target_vector,
neighbour_node_vector));
+ explore_heap.push(std::make_pair(dist, neighbour_node_vector));
+ result_heap.push(std::make_pair(dist, neighbour_node_vector));
+ while (result_heap.size() > ef_runtime) {
+ result_heap.pop();
+ }
+ }
+ }
+
+ while (!result_heap.empty()) {
+ candidates.push_back(result_heap.top().second);
+ result_heap.pop();
+ }
+
+ std::reverse(candidates.begin(), candidates.end());
+ return candidates;
+}
+
+Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const
kqir::NumericArray& vector,
+
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
+ uint16_t target_level) const {
+ auto cf_handle = storage->GetCFHandle(ColumnFamilyID::Search);
+ VectorItem inserted_vector_item;
+ GET_OR_RET(VectorItem::Create(std::string(key), vector, metadata,
&inserted_vector_item));
+ std::vector<VectorItem> nearest_vec_items;
+
+ if (metadata->num_levels != 0) {
+ auto level = metadata->num_levels - 1;
+
+ auto default_entry_node = GET_OR_RET(DefaultEntryPoint(level));
+ std::vector<NodeKey> entry_points{default_entry_node};
+
+ for (; level > target_level; level--) {
+ nearest_vec_items = GET_OR_RET(SearchLayer(level, inserted_vector_item,
metadata->ef_runtime, entry_points));
+ entry_points = {nearest_vec_items[0].key};
+ }
+
+ for (; level >= 0; level--) {
+ nearest_vec_items = GET_OR_RET(SearchLayer(level, inserted_vector_item,
metadata->ef_construction, entry_points));
+ auto candidate_vec_items =
GET_OR_RET(SelectNeighbors(inserted_vector_item, nearest_vec_items, level));
+ auto node = HnswNode(std::string(key), level);
+ auto m_max = level == 0 ? 2 * metadata->m : metadata->m;
+
+ std::unordered_set<NodeKey> connected_edges_set;
+ std::unordered_map<NodeKey, std::unordered_set<NodeKey>>
deleted_edges_map;
+
+ // Check if candidate node has room for more outgoing edges
+ auto has_room_for_more_edges = [&](uint16_t
candidate_node_num_neighbours) {
+ return candidate_node_num_neighbours < m_max;
+ };
+
+ // Check if candidate node has room after some other nodes' are pruned
in current batch
+ auto has_room_after_deletions = [&](const HnswNode& candidate_node,
uint16_t candidate_node_num_neighbours) {
+ auto it = deleted_edges_map.find(candidate_node.key);
+ if (it != deleted_edges_map.end()) {
+ auto num_deleted_edges = static_cast<uint16_t>(it->second.size());
+ return (candidate_node_num_neighbours - num_deleted_edges) < m_max;
+ }
+ return false;
+ };
+
+ for (const auto& candidate_vec : candidate_vec_items) {
+ auto candidate_node = HnswNode(candidate_vec.key, level);
+ auto candidate_node_metadata =
GET_OR_RET(candidate_node.DecodeMetadata(search_key, storage));
+ uint16_t candidate_node_num_neighbours =
candidate_node_metadata.num_neighbours;
+
+ if (has_room_for_more_edges(candidate_node_num_neighbours) ||
+ has_room_after_deletions(candidate_node,
candidate_node_num_neighbours)) {
+ GET_OR_RET(AddEdge(inserted_vector_item.key, candidate_node.key,
level, batch));
+ connected_edges_set.insert(candidate_node.key);
+ continue;
+ }
+
+ // Re-evaluate the neighbours for the candidate node
+ candidate_node.DecodeNeighbours(search_key, storage);
+ auto candidate_node_neighbour_vec_items =
+ GET_OR_RET(DecodeNodesToVectorItems(candidate_node.neighbours,
level, search_key, storage, metadata));
+ candidate_node_neighbour_vec_items.push_back(inserted_vector_item);
+ auto sorted_neighbours_by_distance =
+ GET_OR_RET(SelectNeighbors(candidate_vec,
candidate_node_neighbour_vec_items, level));
+
+ bool inserted_node_is_selected =
+ std::find(sorted_neighbours_by_distance.begin(),
sorted_neighbours_by_distance.end(),
+ inserted_vector_item) !=
sorted_neighbours_by_distance.end();
+
+ if (inserted_node_is_selected) {
+ // Add the edge between candidate and inserted node
+ GET_OR_RET(AddEdge(inserted_vector_item.key, candidate_node.key,
level, batch));
+ connected_edges_set.insert(candidate_node.key);
+
+ auto find_deleted_item = [&](const std::vector<VectorItem>&
candidate_neighbours,
+ const std::vector<VectorItem>&
selected_neighbours) -> VectorItem {
+ auto it =
+ std::find_if(candidate_neighbours.begin(),
candidate_neighbours.end(), [&](const VectorItem& item) {
+ return std::find(selected_neighbours.begin(),
selected_neighbours.end(), item) ==
+ selected_neighbours.end();
+ });
+ return *it;
+ };
+
+ // Remove the edge for candidate and the pruned node
+ auto deleted_node =
find_deleted_item(candidate_node_neighbour_vec_items,
sorted_neighbours_by_distance);
+ GET_OR_RET(RemoveEdge(deleted_node.key, candidate_node.key, level,
batch));
+ deleted_edges_map[candidate_node.key].insert(deleted_node.key);
+ deleted_edges_map[deleted_node.key].insert(candidate_node.key);
+ }
+ }
+
+ // Update inserted node metadata
+ HnswNodeFieldMetadata
node_metadata(static_cast<uint16_t>(connected_edges_set.size()), vector);
+ node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
+
+ // Update modified nodes metadata
+ for (const auto& node_edges : deleted_edges_map) {
+ auto& current_node_key = node_edges.first;
+ auto current_node = HnswNode(current_node_key, level);
+ auto current_node_metadata =
GET_OR_RET(current_node.DecodeMetadata(search_key, storage));
+ auto new_num_neighbours = current_node_metadata.num_neighbours -
node_edges.second.size();
+ if (connected_edges_set.count(current_node_key) != 0) {
+ new_num_neighbours++;
+ connected_edges_set.erase(current_node_key);
+ }
+ current_node_metadata.num_neighbours = new_num_neighbours;
+ current_node.PutMetadata(&current_node_metadata, search_key, storage,
batch.Get());
+ }
+
+ for (const auto& current_node_key : connected_edges_set) {
+ auto current_node = HnswNode(current_node_key, level);
+ HnswNodeFieldMetadata current_node_metadata =
GET_OR_RET(current_node.DecodeMetadata(search_key, storage));
+ current_node_metadata.num_neighbours++;
+ current_node.PutMetadata(&current_node_metadata, search_key, storage,
batch.Get());
+ }
+
+ entry_points.clear();
+ for (const auto& new_entry_point : nearest_vec_items) {
+ entry_points.push_back(new_entry_point.key);
+ }
+ }
+ } else {
+ auto node = HnswNode(std::string(key), 0);
+ HnswNodeFieldMetadata node_metadata(0, vector);
+ node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
+ metadata->num_levels = 1;
+ }
+
+ while (target_level > metadata->num_levels - 1) {
+ auto node = HnswNode(std::string(key), metadata->num_levels);
+ HnswNodeFieldMetadata node_metadata(0, vector);
+ node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
+ metadata->num_levels++;
+ }
+
+ std::string encoded_index_metadata;
+ metadata->Encode(&encoded_index_metadata);
+ auto index_meta_key = search_key.ConstructFieldMeta();
+ batch->Put(cf_handle, index_meta_key, encoded_index_metadata);
+
+ return Status::OK();
+}
+
+Status HnswIndex::InsertVectorEntry(std::string_view key, const
kqir::NumericArray& vector,
+
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
+ auto target_level = RandomizeLayer();
+ return InsertVectorEntryInternal(key, vector, batch, target_level);
+}
+
+Status HnswIndex::DeleteVectorEntry(std::string_view key,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const {
+ std::string node_key(key);
+ for (uint16_t level = 0; level < metadata->num_levels; level++) {
+ auto node = HnswNode(node_key, level);
+ auto node_metadata_status = node.DecodeMetadata(search_key, storage);
+ if (!node_metadata_status.IsOK()) {
+ break;
+ }
+
+ auto node_metadata = std::move(node_metadata_status).GetValue();
+ auto node_index_key = search_key.ConstructHnswNode(level, key);
+ auto s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search),
node_index_key);
+ if (!s.ok()) {
+ return {Status::NotOK, s.ToString()};
+ }
+
+ node.DecodeNeighbours(search_key, storage);
+ for (const auto& neighbour_key : node.neighbours) {
+ GET_OR_RET(RemoveEdge(node_key, neighbour_key, level, batch));
+ auto neighbour_node = HnswNode(neighbour_key, level);
+ HnswNodeFieldMetadata neighbour_node_metadata =
GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));
+ neighbour_node_metadata.num_neighbours--;
+ neighbour_node.PutMetadata(&neighbour_node_metadata, search_key,
storage, batch.Get());
+ }
+ }
+
+ auto has_other_nodes_at_level = [&](uint16_t level, std::string_view
skip_key) -> bool {
+ auto prefix = search_key.ConstructHnswLevelNodePrefix(level);
+ util::UniqueIterator it(storage, storage->DefaultScanOptions(),
ColumnFamilyID::Search);
+ it->Seek(prefix);
+
+ Slice node_key;
+ Slice node_key_dst;
+ while (it->Valid() && it->key().starts_with(prefix)) {
+ std::string node_key_str = it->key().ToString().substr(prefix.size());
+ node_key = Slice(node_key_str);
+ if (!GetSizedString(&node_key, &node_key_dst)) {
+ continue;
+ }
+ if (node_key_dst.ToString() != skip_key) {
+ return true;
+ }
+ it->Next();
+ }
+ return false;
+ };
+
+ while (metadata->num_levels > 0) {
+ if (has_other_nodes_at_level(metadata->num_levels - 1, key)) {
+ break;
+ }
+ metadata->num_levels--;
+ }
+
+ std::string encoded_index_metadata;
+ metadata->Encode(&encoded_index_metadata);
+ auto index_meta_key = search_key.ConstructFieldMeta();
+ batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key,
encoded_index_metadata);
+
+ return Status::OK();
+}
+
+} // namespace redis
diff --git a/src/search/hnsw_indexer.h b/src/search/hnsw_indexer.h
new file mode 100644
index 00000000..30bdf94a
--- /dev/null
+++ b/src/search/hnsw_indexer.h
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <random>
+#include <string>
+#include <vector>
+
+#include "search/indexer.h"
+#include "search/search_encoding.h"
+#include "search/value.h"
+#include "storage/storage.h"
+
+namespace redis {
+
+class HnswIndex;
+
+struct HnswNode {
+ using NodeKey = std::string;
+ NodeKey key;
+ uint16_t level{};
+ std::vector<NodeKey> neighbours;
+
+ HnswNode(NodeKey key, uint16_t level);
+
+ StatusOr<HnswNodeFieldMetadata> DecodeMetadata(const SearchKey& search_key,
engine::Storage* storage) const;
+ void PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey&
search_key, engine::Storage* storage,
+ rocksdb::WriteBatchBase* batch) const;
+ void DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage);
+
+ // For testing purpose
+ Status AddNeighbour(const NodeKey& neighbour_key, const SearchKey&
search_key, engine::Storage* storage,
+ rocksdb::WriteBatchBase* batch) const;
+ Status RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey&
search_key, engine::Storage* storage,
+ rocksdb::WriteBatchBase* batch) const;
+ friend class HnswIndex;
+};
+
+struct VectorItem {
+ using NodeKey = HnswNode::NodeKey;
+
+ NodeKey key;
+ kqir::NumericArray vector;
+ const HnswVectorFieldMetadata* metadata;
+
+ VectorItem() : metadata(nullptr) {}
+
+ static Status Create(NodeKey key, const kqir::NumericArray& vector, const
HnswVectorFieldMetadata* metadata,
+ VectorItem* out);
+ static Status Create(NodeKey key, kqir::NumericArray&& vector, const
HnswVectorFieldMetadata* metadata,
+ VectorItem* out);
+
+ bool operator==(const VectorItem& other) const;
+ bool operator<(const VectorItem& other) const;
+
+ private:
+ VectorItem(NodeKey&& key, const kqir::NumericArray& vector, const
HnswVectorFieldMetadata* metadata);
+ VectorItem(NodeKey&& key, kqir::NumericArray&& vector, const
HnswVectorFieldMetadata* metadata);
+};
+
+StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem&
right);
+
+struct HnswIndex {
+ using NodeKey = HnswNode::NodeKey;
+
+ SearchKey search_key;
+ HnswVectorFieldMetadata* metadata;
+ engine::Storage* storage = nullptr;
+
+ std::mt19937 generator;
+ double m_level_normalization_factor;
+
+ HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector,
engine::Storage* storage);
+
+ static StatusOr<std::vector<VectorItem>> DecodeNodesToVectorItems(const
std::vector<NodeKey>& node_key,
+ uint16_t
level, const SearchKey& search_key,
+
engine::Storage* storage,
+ const
HnswVectorFieldMetadata* metadata);
+ uint16_t RandomizeLayer();
+ StatusOr<NodeKey> DefaultEntryPoint(uint16_t level) const;
+ Status AddEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t
layer,
+ ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
+ Status RemoveEdge(const NodeKey& node_key1, const NodeKey& node_key2,
uint16_t layer,
+ ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
+
+ StatusOr<std::vector<VectorItem>> SelectNeighbors(const VectorItem& vec,
const std::vector<VectorItem>& vectors,
+ uint16_t layer) const;
+ StatusOr<std::vector<VectorItem>> SearchLayer(uint16_t level, const
VectorItem& target_vector, uint32_t ef_runtime,
+ const std::vector<NodeKey>&
entry_points) const;
+ Status InsertVectorEntryInternal(std::string_view key, const
kqir::NumericArray& vector,
+
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch, uint16_t layer) const;
+ Status InsertVectorEntry(std::string_view key, const kqir::NumericArray&
vector,
+ ObserverOrUniquePtr<rocksdb::WriteBatchBase>&
batch);
+ Status DeleteVectorEntry(std::string_view key,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
+};
+
+} // namespace redis
diff --git a/src/search/indexer.cc b/src/search/indexer.cc
index 1212dd2f..576de073 100644
--- a/src/search/indexer.cc
+++ b/src/search/indexer.cc
@@ -25,6 +25,7 @@
#include "db_util.h"
#include "parse_util.h"
+#include "search/hnsw_indexer.h"
#include "search/search_encoding.h"
#include "search/value.h"
#include "storage/redis_metadata.h"
@@ -57,10 +58,6 @@ StatusOr<FieldValueRetriever>
FieldValueRetriever::Create(IndexOnDataType type,
}
}
-// placeholders, remove them after vector indexing is implemented
-static bool IsVectorType(const redis::IndexFieldMetadata *) { return false; }
-static size_t GetVectorDim(const redis::IndexFieldMetadata *) { return 1; }
-
StatusOr<kqir::Value> FieldValueRetriever::ParseFromJson(const jsoncons::json
&val,
const
redis::IndexFieldMetadata *type) {
if (auto numeric [[maybe_unused]] = dynamic_cast<const
redis::NumericFieldMetadata *>(type)) {
@@ -82,8 +79,8 @@ StatusOr<kqir::Value>
FieldValueRetriever::ParseFromJson(const jsoncons::json &v
} else {
return {Status::NotOK, "json value should be string or array of strings
for tag fields"};
}
- } else if (IsVectorType(type)) {
- size_t dim = GetVectorDim(type);
+ } else if (auto vector = dynamic_cast<const redis::HnswVectorFieldMetadata
*>(type)) {
+ const auto dim = vector->dim;
if (!val.is_array()) return {Status::NotOK, "json value should be array of
numbers for vector fields"};
if (dim != val.size()) return {Status::NotOK, "the size of the json array
is not equal to the dim of the vector"};
std::vector<double> nums;
@@ -107,8 +104,8 @@ StatusOr<kqir::Value>
FieldValueRetriever::ParseFromHash(const std::string &valu
const char delim[] = {tag->separator, '\0'};
auto vec = util::Split(value, delim);
return kqir::MakeValue<kqir::StringArray>(vec);
- } else if (IsVectorType(type)) {
- const size_t dim = GetVectorDim(type);
+ } else if (auto vector = dynamic_cast<const redis::HnswVectorFieldMetadata
*>(type)) {
+ const auto dim = vector->dim;
if (value.size() != dim * sizeof(double)) {
return {Status::NotOK, "field value is too short or too long to be
parsed as a vector"};
}
@@ -246,7 +243,7 @@ Status IndexUpdater::UpdateTagIndex(std::string_view key,
const kqir::Value &ori
Status IndexUpdater::UpdateNumericIndex(std::string_view key, const
kqir::Value &original, const kqir::Value &current,
const SearchKey &search_key, const
NumericFieldMetadata *num) const {
CHECK(original.IsNull() || original.Is<kqir::Numeric>());
- CHECK(original.IsNull() || original.Is<kqir::Numeric>());
+ CHECK(current.IsNull() || current.Is<kqir::Numeric>());
auto *storage = indexer->storage;
auto batch = storage->GetWriteBatchBase();
@@ -269,6 +266,32 @@ Status IndexUpdater::UpdateNumericIndex(std::string_view
key, const kqir::Value
return Status::OK();
}
+Status IndexUpdater::UpdateHnswVectorIndex(std::string_view key, const
kqir::Value &original,
+ const kqir::Value &current, const
SearchKey &search_key,
+ HnswVectorFieldMetadata *vector)
const {
+ CHECK(original.IsNull() || original.Is<kqir::NumericArray>());
+ CHECK(current.IsNull() || current.Is<kqir::NumericArray>());
+
+ auto storage = indexer->storage;
+ auto hnsw = HnswIndex(search_key, vector, storage);
+
+ if (!original.IsNull()) {
+ auto batch = storage->GetWriteBatchBase();
+ GET_OR_RET(hnsw.DeleteVectorEntry(key, batch));
+ auto s = storage->Write(storage->DefaultWriteOptions(),
batch->GetWriteBatch());
+ if (!s.ok()) return {Status::NotOK, s.ToString()};
+ }
+
+ if (!current.IsNull()) {
+ auto batch = storage->GetWriteBatchBase();
+ GET_OR_RET(hnsw.InsertVectorEntry(key, current.Get<kqir::NumericArray>(),
batch));
+ auto s = storage->Write(storage->DefaultWriteOptions(),
batch->GetWriteBatch());
+ if (!s.ok()) return {Status::NotOK, s.ToString()};
+ }
+
+ return Status::OK();
+}
+
Status IndexUpdater::UpdateIndex(const std::string &field, std::string_view
key, const kqir::Value &original,
const kqir::Value &current) const {
if (original == current) {
@@ -287,6 +310,8 @@ Status IndexUpdater::UpdateIndex(const std::string &field,
std::string_view key,
GET_OR_RET(UpdateTagIndex(key, original, current, search_key, tag));
} else if (auto numeric [[maybe_unused]] = dynamic_cast<NumericFieldMetadata
*>(metadata)) {
GET_OR_RET(UpdateNumericIndex(key, original, current, search_key,
numeric));
+ } else if (auto vector = dynamic_cast<HnswVectorFieldMetadata *>(metadata)) {
+ GET_OR_RET(UpdateHnswVectorIndex(key, original, current, search_key,
vector));
} else {
return {Status::NotOK, "Unexpected field type"};
}
diff --git a/src/search/indexer.h b/src/search/indexer.h
index 8ffd503b..e5e0aa4f 100644
--- a/src/search/indexer.h
+++ b/src/search/indexer.h
@@ -89,6 +89,8 @@ struct IndexUpdater {
const SearchKey &search_key, const TagFieldMetadata
*tag) const;
Status UpdateNumericIndex(std::string_view key, const kqir::Value &original,
const kqir::Value ¤t,
const SearchKey &search_key, const
NumericFieldMetadata *num) const;
+ Status UpdateHnswVectorIndex(std::string_view key, const kqir::Value
&original, const kqir::Value ¤t,
+ const SearchKey &search_key,
HnswVectorFieldMetadata *vector) const;
};
struct GlobalIndexer {
diff --git a/src/search/search_encoding.h b/src/search/search_encoding.h
index 68e248bb..2fbbde8c 100644
--- a/src/search/search_encoding.h
+++ b/src/search/search_encoding.h
@@ -33,6 +33,7 @@ enum class IndexOnDataType : uint8_t {
};
inline constexpr auto kErrorInsufficientLength = "insufficient length while
decoding metadata";
// Returned when a decoded vector payload's byte length disagrees with its
// declared dimension (see HnswNodeFieldMetadata::Decode).
inline constexpr auto kErrorIncorrectLength = "length is too short or too long to be parsed as a vector";
class IndexMetadata {
public:
@@ -76,6 +77,23 @@ enum class IndexFieldType : uint8_t {
TAG = 1,
NUMERIC = 2,
+
+ VECTOR = 3,
+};
+
// On-disk element type of a vector field; only 64-bit floats for now.
enum class VectorType : uint8_t {
  FLOAT64 = 1,
};
+
// Distance function used to compare two vectors.
enum class DistanceMetric : uint8_t {
  L2 = 0,      // Euclidean distance
  IP = 1,      // inner product (negated, per the ComputeSimilarity tests)
  COSINE = 2,  // cosine distance (1 - cosine similarity)
};
+
// Subkey tag distinguishing node-metadata entries from edge entries
// within a single HNSW level.
enum class HnswLevelType : uint8_t {
  NODE = 1,
  EDGE = 2,
};
struct SearchKey {
@@ -95,6 +113,26 @@ struct SearchKey {
void PutIndex(std::string *dst) const { PutSizedString(dst, index); }
  // Appends the one-byte level-entry tag (NODE or EDGE) to *dst.
  static void PutHnswLevelType(std::string *dst, HnswLevelType type) { PutFixed8(dst, uint8_t(type)); }
+
  // Appends the key prefix shared by all entries of one HNSW level:
  // namespace | FIELD tag | index | field | level.
  void PutHnswLevelPrefix(std::string *dst, uint16_t level) const {
    PutNamespace(dst);
    PutType(dst, SearchSubkeyType::FIELD);
    PutIndex(dst);
    PutSizedString(dst, field);
    PutFixed16(dst, level);
  }
+
  // Appends the prefix for node-metadata entries of the given level.
  void PutHnswLevelNodePrefix(std::string *dst, uint16_t level) const {
    PutHnswLevelPrefix(dst, level);
    PutHnswLevelType(dst, HnswLevelType::NODE);
  }
+
  // Appends the prefix for edge entries of the given level.
  void PutHnswLevelEdgePrefix(std::string *dst, uint16_t level) const {
    PutHnswLevelPrefix(dst, level);
    PutHnswLevelType(dst, HnswLevelType::EDGE);
  }
+
std::string ConstructIndexMeta() const {
std::string dst;
PutNamespace(&dst);
@@ -177,6 +215,34 @@ struct SearchKey {
PutSizedString(&dst, key);
return dst;
}
+
  // Builds the scan prefix that matches every node entry of one level.
  std::string ConstructHnswLevelNodePrefix(uint16_t level) const {
    std::string dst;
    PutHnswLevelNodePrefix(&dst, level);
    return dst;
  }
+
  // Builds the full key of one node's metadata at the given level.
  std::string ConstructHnswNode(uint16_t level, std::string_view key) const {
    std::string dst;
    PutHnswLevelNodePrefix(&dst, level);
    PutSizedString(&dst, key);
    return dst;
  }
+
  // Builds the scan prefix matching every edge whose first endpoint is `key`.
  std::string ConstructHnswEdgeWithSingleEnd(uint16_t level, std::string_view key) const {
    std::string dst;
    PutHnswLevelEdgePrefix(&dst, level);
    PutSizedString(&dst, key);
    return dst;
  }
+
  // Builds the full key of the directed edge key1 -> key2 at the given level.
  std::string ConstructHnswEdge(uint16_t level, std::string_view key1, std::string_view key2) const {
    std::string dst;
    PutHnswLevelEdgePrefix(&dst, level);
    PutSizedString(&dst, key1);
    PutSizedString(&dst, key2);
    return dst;
  }
};
struct IndexPrefixes {
@@ -236,6 +302,8 @@ struct IndexFieldMetadata {
return "tag";
case IndexFieldType::NUMERIC:
return "numeric";
+ case IndexFieldType::VECTOR:
+ return "vector";
default:
return "unknown";
}
@@ -291,6 +359,96 @@ struct NumericFieldMetadata : IndexFieldMetadata {
bool IsSortable() const override { return true; }
};
+struct HnswVectorFieldMetadata : IndexFieldMetadata {
+ VectorType vector_type;
+ uint16_t dim;
+ DistanceMetric distance_metric;
+
+ uint32_t initial_cap = 500000; // Initial vector capacity
+ uint16_t m = 16; // Max allowed outgoing edges per node
+ uint32_t ef_construction = 200; // Max potential outgoing edge candidates
during construction
+ uint32_t ef_runtime = 10; // Max top candidates held during KNN search
+ double epsilon = 0.01; // Relative factor setting search
boundaries in range queries
+ uint16_t num_levels = 0; // Number of levels in the HNSW graph
+
+ HnswVectorFieldMetadata() : IndexFieldMetadata(IndexFieldType::VECTOR) {}
+
+ void Encode(std::string *dst) const override {
+ IndexFieldMetadata::Encode(dst);
+ PutFixed8(dst, uint8_t(vector_type));
+ PutFixed16(dst, dim);
+ PutFixed8(dst, uint8_t(distance_metric));
+ PutFixed32(dst, initial_cap);
+ PutFixed16(dst, m);
+ PutFixed32(dst, ef_construction);
+ PutFixed32(dst, ef_runtime);
+ PutDouble(dst, epsilon);
+ PutFixed16(dst, num_levels);
+ }
+
+ rocksdb::Status Decode(Slice *input) override {
+ if (auto s = IndexFieldMetadata::Decode(input); !s.ok()) {
+ return s;
+ }
+
+ constexpr size_t required_size = sizeof(uint8_t) + sizeof(uint16_t) +
sizeof(uint8_t) + sizeof(uint32_t) +
+ sizeof(uint16_t) + sizeof(uint32_t) +
sizeof(uint32_t) + sizeof(uint64_t) +
+ sizeof(uint16_t);
+
+ if (input->size() < required_size) {
+ return rocksdb::Status::Corruption(kErrorInsufficientLength);
+ }
+
+ GetFixed8(input, (uint8_t *)(&vector_type));
+ GetFixed16(input, &dim);
+ GetFixed8(input, (uint8_t *)(&distance_metric));
+ GetFixed32(input, &initial_cap);
+ GetFixed16(input, &m);
+ GetFixed32(input, &ef_construction);
+ GetFixed32(input, &ef_runtime);
+ GetDouble(input, &epsilon);
+ GetFixed16(input, &num_levels);
+ return rocksdb::Status::OK();
+ }
+};
+
+struct HnswNodeFieldMetadata {
+ uint16_t num_neighbours;
+ std::vector<double> vector;
+
+ HnswNodeFieldMetadata() = default;
+ HnswNodeFieldMetadata(uint16_t num_neighbours, std::vector<double> vector)
+ : num_neighbours(num_neighbours), vector(std::move(vector)) {}
+
+ void Encode(std::string *dst) const {
+ PutFixed16(dst, num_neighbours);
+ PutFixed16(dst, static_cast<uint16_t>(vector.size()));
+ for (double element : vector) {
+ PutDouble(dst, element);
+ }
+ }
+
+ rocksdb::Status Decode(Slice *input) {
+ if (input->size() < 2 + 2) {
+ return rocksdb::Status::Corruption(kErrorInsufficientLength);
+ }
+ GetFixed16(input, (uint16_t *)(&num_neighbours));
+
+ uint16_t dim = 0;
+ GetFixed16(input, (uint16_t *)(&dim));
+
+ if (input->size() != dim * sizeof(double)) {
+ return rocksdb::Status::Corruption(kErrorIncorrectLength);
+ }
+ vector.resize(dim);
+
+ for (auto i = 0; i < dim; ++i) {
+ GetDouble(input, &vector[i]);
+ }
+ return rocksdb::Status::OK();
+ }
+};
+
inline rocksdb::Status IndexFieldMetadata::Decode(Slice *input,
std::unique_ptr<IndexFieldMetadata> &ptr) {
if (input->size() < 1) {
return rocksdb::Status::Corruption(kErrorInsufficientLength);
@@ -303,6 +461,9 @@ inline rocksdb::Status IndexFieldMetadata::Decode(Slice
*input, std::unique_ptr<
case IndexFieldType::NUMERIC:
ptr = std::make_unique<NumericFieldMetadata>();
break;
+ case IndexFieldType::VECTOR:
+ ptr = std::make_unique<HnswVectorFieldMetadata>();
+ break;
default:
return rocksdb::Status::Corruption("encountered unknown field type");
}
diff --git a/tests/cppunit/hnsw_index_test.cc b/tests/cppunit/hnsw_index_test.cc
new file mode 100644
index 00000000..e09e9830
--- /dev/null
+++ b/tests/cppunit/hnsw_index_test.cc
@@ -0,0 +1,664 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
#include <gtest/gtest.h>
#include <test_base.h>

#include <cmath>
#include <iostream>
#include <map>
#include <memory>
#include <random>
#include <unordered_set>
#include <vector>

#include "search/hnsw_indexer.h"
#include "search/indexer.h"
#include "search/search_encoding.h"
#include "search/value.h"
#include "storage/storage.h"
+
// Fixture providing an HNSW index over a 3-dimensional FLOAT64 field with
// m = 3 and L2 distance, backed by the TestBase-provided storage.
struct HnswIndexTest : TestBase {
  redis::HnswVectorFieldMetadata metadata;
  std::string ns = "hnsw_test_ns";
  std::string idx_name = "hnsw_test_idx";
  std::string key = "vector";
  std::unique_ptr<redis::HnswIndex> hnsw_index;

  HnswIndexTest() {
    metadata.vector_type = redis::VectorType::FLOAT64;
    metadata.dim = 3;
    metadata.m = 3;
    metadata.distance_metric = redis::DistanceMetric::L2;
    auto search_key = redis::SearchKey(ns, idx_name, key);
    hnsw_index = std::make_unique<redis::HnswIndex>(search_key, &metadata, storage_.get());
  }

  void TearDown() override { hnsw_index.reset(); }
};
+
// Verifies ComputeSimilarity under all three distance metrics on fixed
// 3-dimensional vectors. The fixture's metric is mutated in place and
// restored to L2 at the end.
TEST_F(HnswIndexTest, ComputeSimilarity) {
  redis::VectorItem vec1;
  auto status1 = redis::VectorItem::Create("1", {1.0, 1.2, 1.4}, hnsw_index->metadata, &vec1);
  ASSERT_TRUE(status1.IsOK());
  redis::VectorItem vec2;
  auto status2 = redis::VectorItem::Create("2", {3.0, 3.2, 3.4}, hnsw_index->metadata, &vec2);
  ASSERT_TRUE(status2.IsOK());
  redis::VectorItem vec3;  // identical to vec1
  auto status3 = redis::VectorItem::Create("3", {1.0, 1.2, 1.4}, hnsw_index->metadata, &vec3);
  ASSERT_TRUE(status3.IsOK());

  // L2: identical vectors have distance 0.
  auto s1 = redis::ComputeSimilarity(vec1, vec3);
  ASSERT_TRUE(s1.IsOK());
  double similarity = s1.GetValue();
  EXPECT_EQ(similarity, 0.0);

  // L2: every component differs by 2.0, so distance = sqrt(3 * 2^2) = sqrt(12).
  auto s2 = redis::ComputeSimilarity(vec1, vec2);
  ASSERT_TRUE(s2.IsOK());
  similarity = s2.GetValue();
  EXPECT_NEAR(similarity, std::sqrt(12), 1e-5);

  // IP: similarity is the negated inner product.
  hnsw_index->metadata->distance_metric = redis::DistanceMetric::IP;
  auto s3 = redis::ComputeSimilarity(vec1, vec2);
  ASSERT_TRUE(s3.IsOK());
  similarity = s3.GetValue();
  EXPECT_NEAR(similarity, -(1.0 * 3.0 + 1.2 * 3.2 + 1.4 * 3.4), 1e-5);

  // COSINE: similarity is 1 - cos(angle between the vectors).
  hnsw_index->metadata->distance_metric = redis::DistanceMetric::COSINE;
  double expected_res = (1.0 * 3.0 + 1.2 * 3.2 + 1.4 * 3.4) /
                        std::sqrt((1.0 * 1.0 + 1.2 * 1.2 + 1.4 * 1.4) * (3.0 * 3.0 + 3.2 * 3.2 + 3.4 * 3.4));
  auto s4 = redis::ComputeSimilarity(vec1, vec2);
  ASSERT_TRUE(s4.IsOK());
  similarity = s4.GetValue();
  EXPECT_NEAR(similarity, 1 - expected_res, 1e-5);

  // Restore the default metric for tests that run after this one.
  hnsw_index->metadata->distance_metric = redis::DistanceMetric::L2;
}
+
+TEST_F(HnswIndexTest, RandomizeLayer) {
+ constexpr size_t kSampleSize = 50000;
+
+ std::vector<uint16_t> layers;
+ layers.reserve(kSampleSize);
+
+ for (size_t i = 0; i < kSampleSize; ++i) {
+ layers.push_back(hnsw_index->RandomizeLayer());
+ EXPECT_GE(layers.back(), 0);
+ }
+
+ std::map<uint16_t, size_t> layer_frequency;
+ for (const auto& layer : layers) {
+ layer_frequency[layer]++;
+ }
+
+ uint16_t max_observed_layer = 0;
+ for (const auto& [layer, freq] : layer_frequency) {
+ // std::cout << "Layer: " << layer << " Frequency: " << freq << std::endl;
+ if (layer > max_observed_layer) {
+ max_observed_layer = layer;
+ }
+ }
+
+ // Calculate expected frequencies for each layer based on the theoretical
distribution
+ std::vector<double> expected_frequencies(max_observed_layer + 1, 0);
+ double normalization_factor = 1.0 / std::log(hnsw_index->metadata->m);
+ double total_probability = 0.0;
+
+ for (uint16_t i = 0; i <= max_observed_layer; ++i) {
+ total_probability += std::exp(-i / normalization_factor);
+ }
+
+ for (uint16_t i = 0; i <= max_observed_layer; ++i) {
+ double probability = std::exp(-i / normalization_factor) /
total_probability;
+ expected_frequencies[i] = kSampleSize * probability;
+ }
+
+ for (const auto& [layer, freq] : layer_frequency) {
+ if (layer < expected_frequencies.size() / 3) {
+ double expected_freq = expected_frequencies[layer];
+ double deviation = std::abs(static_cast<double>(freq) - expected_freq) /
expected_freq;
+ EXPECT_LE(deviation, 0.1) << "Layer: " << layer << " Frequency: " <<
freq << " Expected: " << expected_freq;
+ }
+ }
+}
+
// An empty index has no entry point at any level.
TEST_F(HnswIndexTest, DefaultEntryPointNotFound) {
  auto initial_result = hnsw_index->DefaultEntryPoint(0);
  ASSERT_EQ(initial_result.GetCode(), Status::NotFound);
}
+
// Persists three node metadata records at one level, then checks that
// DecodeNodesToVectorItems reads them back as VectorItems in key order.
TEST_F(HnswIndexTest, DecodeNodesToVectorItems) {
  uint16_t layer = 1;
  std::string node_key1 = "node1";
  std::string node_key2 = "node2";
  std::string node_key3 = "node3";

  redis::HnswNode node1(node_key1, layer);
  redis::HnswNode node2(node_key2, layer);
  redis::HnswNode node3(node_key3, layer);

  // num_neighbours = 0; vectors are the payloads we expect to read back.
  redis::HnswNodeFieldMetadata metadata1(0, {1, 2, 3});
  redis::HnswNodeFieldMetadata metadata2(0, {4, 5, 6});
  redis::HnswNodeFieldMetadata metadata3(0, {7, 8, 9});

  auto batch = storage_->GetWriteBatchBase();
  node1.PutMetadata(&metadata1, hnsw_index->search_key, hnsw_index->storage, batch.Get());
  node2.PutMetadata(&metadata2, hnsw_index->search_key, hnsw_index->storage, batch.Get());
  node3.PutMetadata(&metadata3, hnsw_index->search_key, hnsw_index->storage, batch.Get());
  auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
  ASSERT_TRUE(s.ok());

  std::vector<std::string> keys = {node_key1, node_key2, node_key3};

  auto s1 = hnsw_index->DecodeNodesToVectorItems(keys, layer, hnsw_index->search_key, hnsw_index->storage,
                                                 hnsw_index->metadata);
  ASSERT_TRUE(s1.IsOK());
  auto vector_items = s1.GetValue();
  ASSERT_EQ(vector_items.size(), 3);
  EXPECT_EQ(vector_items[0].key, node_key1);
  EXPECT_EQ(vector_items[1].key, node_key2);
  EXPECT_EQ(vector_items[2].key, node_key3);
  EXPECT_TRUE(vector_items[0].vector == std::vector<double>({1, 2, 3}));
  EXPECT_TRUE(vector_items[1].vector == std::vector<double>({4, 5, 6}));
  EXPECT_TRUE(vector_items[2].vector == std::vector<double>({7, 8, 9}));
}
+
+TEST_F(HnswIndexTest, SelectNeighbors) {
+ redis::VectorItem vec1;
+ auto status1 = redis::VectorItem::Create("1", {1.0, 1.0, 1.0},
hnsw_index->metadata, &vec1);
+ ASSERT_TRUE(status1.IsOK());
+
+ redis::VectorItem vec2;
+ auto status2 = redis::VectorItem::Create("2", {2.0, 2.0, 2.0},
hnsw_index->metadata, &vec2);
+ ASSERT_TRUE(status2.IsOK());
+
+ redis::VectorItem vec3;
+ auto status3 = redis::VectorItem::Create("3", {3.0, 3.0, 3.0},
hnsw_index->metadata, &vec3);
+ ASSERT_TRUE(status3.IsOK());
+
+ redis::VectorItem vec4;
+ auto status4 = redis::VectorItem::Create("4", {4.0, 4.0, 4.0},
hnsw_index->metadata, &vec4);
+ ASSERT_TRUE(status4.IsOK());
+
+ redis::VectorItem vec5;
+ auto status5 = redis::VectorItem::Create("5", {5.0, 5.0, 5.0},
hnsw_index->metadata, &vec5);
+ ASSERT_TRUE(status5.IsOK());
+
+ redis::VectorItem vec6;
+ auto status6 = redis::VectorItem::Create("6", {6.0, 6.0, 6.0},
hnsw_index->metadata, &vec6);
+ ASSERT_TRUE(status6.IsOK());
+
+ redis::VectorItem vec7;
+ auto status7 = redis::VectorItem::Create("7", {7.0, 7.0, 7.0},
hnsw_index->metadata, &vec7);
+ ASSERT_TRUE(status7.IsOK());
+
+ std::vector<redis::VectorItem> candidates = {vec3, vec2};
+ auto s1 = hnsw_index->SelectNeighbors(vec1, candidates, 1);
+ ASSERT_TRUE(s1.IsOK());
+ auto selected = s1.GetValue();
+ EXPECT_EQ(selected.size(), candidates.size());
+
+ EXPECT_EQ(selected[0].key, vec2.key);
+ EXPECT_EQ(selected[1].key, vec3.key);
+
+ candidates = {vec4, vec2, vec5, vec7, vec3, vec6};
+ auto s2 = hnsw_index->SelectNeighbors(vec1, candidates, 1);
+ ASSERT_TRUE(s2.IsOK());
+ selected = s2.GetValue();
+ EXPECT_EQ(selected.size(), 3);
+
+ EXPECT_EQ(selected[0].key, vec2.key);
+ EXPECT_EQ(selected[1].key, vec3.key);
+ EXPECT_EQ(selected[2].key, vec4.key);
+
+ candidates = {vec4, vec2, vec5, vec7, vec3, vec6};
+ auto s3 = hnsw_index->SelectNeighbors(vec1, candidates, 0);
+ ASSERT_TRUE(s3.IsOK());
+ selected = s3.GetValue();
+ EXPECT_EQ(selected.size(), 6);
+
+ EXPECT_EQ(selected[0].key, vec2.key);
+ EXPECT_EQ(selected[1].key, vec3.key);
+ EXPECT_EQ(selected[2].key, vec4.key);
+ EXPECT_EQ(selected[3].key, vec5.key);
+ EXPECT_EQ(selected[4].key, vec6.key);
+ EXPECT_EQ(selected[5].key, vec7.key);
+}
+
+TEST_F(HnswIndexTest, SearchLayer) {
+ uint16_t layer = 3;
+ std::string node_key1 = "node1";
+ std::string node_key2 = "node2";
+ std::string node_key3 = "node3";
+ std::string node_key4 = "node4";
+ std::string node_key5 = "node5";
+
+ redis::HnswNode node1(node_key1, layer);
+ redis::HnswNode node2(node_key2, layer);
+ redis::HnswNode node3(node_key3, layer);
+ redis::HnswNode node4(node_key4, layer);
+ redis::HnswNode node5(node_key5, layer);
+
+ redis::HnswNodeFieldMetadata metadata1(0, {1.0, 2.0, 3.0});
+ redis::HnswNodeFieldMetadata metadata2(0, {4.0, 5.0, 6.0});
+ redis::HnswNodeFieldMetadata metadata3(0, {7.0, 8.0, 9.0});
+ redis::HnswNodeFieldMetadata metadata4(0, {2.0, 3.0, 4.0});
+ redis::HnswNodeFieldMetadata metadata5(0, {6.0, 6.0, 7.0});
+
+ // Add Nodes
+ auto batch = storage_->GetWriteBatchBase();
+ node1.PutMetadata(&metadata1, hnsw_index->search_key, hnsw_index->storage,
batch.Get());
+ node2.PutMetadata(&metadata2, hnsw_index->search_key, hnsw_index->storage,
batch.Get());
+ node3.PutMetadata(&metadata3, hnsw_index->search_key, hnsw_index->storage,
batch.Get());
+ node4.PutMetadata(&metadata4, hnsw_index->search_key, hnsw_index->storage,
batch.Get());
+ node5.PutMetadata(&metadata5, hnsw_index->search_key, hnsw_index->storage,
batch.Get());
+ auto s = storage_->Write(storage_->DefaultWriteOptions(),
batch->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ // Add Neighbours
+ batch = storage_->GetWriteBatchBase();
+ auto s1 = node1.AddNeighbour("node2", hnsw_index->search_key,
hnsw_index->storage, batch.Get());
+ ASSERT_TRUE(s1.IsOK());
+ auto s2 = node1.AddNeighbour("node4", hnsw_index->search_key,
hnsw_index->storage, batch.Get());
+ ASSERT_TRUE(s2.IsOK());
+ auto s3 = node2.AddNeighbour("node1", hnsw_index->search_key,
hnsw_index->storage, batch.Get());
+ ASSERT_TRUE(s3.IsOK());
+ auto s4 = node2.AddNeighbour("node3", hnsw_index->search_key,
hnsw_index->storage, batch.Get());
+ ASSERT_TRUE(s1.IsOK());
+ auto s5 = node3.AddNeighbour("node2", hnsw_index->search_key,
hnsw_index->storage, batch.Get());
+ ASSERT_TRUE(s5.IsOK());
+ auto s6 = node3.AddNeighbour("node5", hnsw_index->search_key,
hnsw_index->storage, batch.Get());
+ ASSERT_TRUE(s6.IsOK());
+ auto s7 = node4.AddNeighbour("node1", hnsw_index->search_key,
hnsw_index->storage, batch.Get());
+ ASSERT_TRUE(s7.IsOK());
+ auto s8 = node5.AddNeighbour("node3", hnsw_index->search_key,
hnsw_index->storage, batch.Get());
+ ASSERT_TRUE(s8.IsOK());
+ s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ redis::VectorItem target_vector;
+ auto status = redis::VectorItem::Create("target", {2.0, 3.0, 4.0},
hnsw_index->metadata, &target_vector);
+ ASSERT_TRUE(status.IsOK());
+
+ // Test with multiple entry points
+ std::vector<std::string> entry_points = {"node3", "node2"};
+ uint32_t ef_runtime = 3;
+
+ auto s9 = hnsw_index->SearchLayer(layer, target_vector, ef_runtime,
entry_points);
+ ASSERT_TRUE(s9.IsOK());
+ auto candidates = s9.GetValue();
+
+ ASSERT_EQ(candidates.size(), ef_runtime);
+ EXPECT_EQ(candidates[0].key, "node4");
+ EXPECT_EQ(candidates[1].key, "node1");
+ EXPECT_EQ(candidates[2].key, "node2");
+
+ // Test with a single entry point
+ entry_points = {"node5"};
+ auto s10 = hnsw_index->SearchLayer(layer, target_vector, ef_runtime,
entry_points);
+ ASSERT_TRUE(s10.IsOK());
+ candidates = s10.GetValue();
+
+ ASSERT_EQ(candidates.size(), ef_runtime);
+ EXPECT_EQ(candidates[0].key, "node4");
+ EXPECT_EQ(candidates[1].key, "node1");
+ EXPECT_EQ(candidates[2].key, "node2");
+
+ // Test with different ef_runtime
+ ef_runtime = 10;
+ auto s11 = hnsw_index->SearchLayer(layer, target_vector, ef_runtime,
entry_points);
+ ASSERT_TRUE(s11.IsOK());
+ candidates = s11.GetValue();
+
+ ASSERT_EQ(candidates.size(), 5);
+ EXPECT_EQ(candidates[0].key, "node4");
+ EXPECT_EQ(candidates[1].key, "node1");
+ EXPECT_EQ(candidates[2].key, "node2");
+ EXPECT_EQ(candidates[3].key, "node5");
+ EXPECT_EQ(candidates[4].key, "node3");
+}
+
+TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) {
+ std::vector<double> vec1 = {11.0, 12.0, 13.0};
+ std::vector<double> vec2 = {14.0, 15.0, 16.0};
+ std::vector<double> vec3 = {17.0, 18.0, 19.0};
+ std::vector<double> vec4 = {12.0, 13.0, 14.0};
+ std::vector<double> vec5 = {15.0, 16.0, 17.0};
+
+ std::string key1 = "n1";
+ std::string key2 = "n2";
+ std::string key3 = "n3";
+ std::string key4 = "n4";
+ std::string key5 = "n5";
+
+ // Insert n1 into layer 1
+ uint16_t target_level = 1;
+ auto batch = storage_->GetWriteBatchBase();
+ auto s1 = hnsw_index->InsertVectorEntryInternal(key1, vec1, batch,
target_level);
+ ASSERT_TRUE(s1.IsOK());
+ auto s = storage_->Write(storage_->DefaultWriteOptions(),
batch->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ rocksdb::PinnableSlice value;
+ auto index_meta_key = hnsw_index->search_key.ConstructFieldMeta();
+ s = storage_->Get(rocksdb::ReadOptions(),
hnsw_index->storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key,
+ &value);
+ ASSERT_TRUE(s.ok());
+ redis::HnswVectorFieldMetadata decoded_metadata;
+ decoded_metadata.Decode(&value);
+ ASSERT_TRUE(decoded_metadata.num_levels == 2);
+
+ redis::HnswNode node1_layer0(key1, 0);
+ auto s2 = node1_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s2.IsOK());
+ redis::HnswNodeFieldMetadata node1_layer0_meta = s2.GetValue();
+ EXPECT_EQ(node1_layer0_meta.num_neighbours, 0);
+
+ redis::HnswNode node1_layer1(key1, 1);
+ auto s3 = node1_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s3.IsOK());
+ redis::HnswNodeFieldMetadata node1_layer1_meta = s2.GetValue();
+ EXPECT_EQ(node1_layer1_meta.num_neighbours, 0);
+
+ // Insert n2 into layer 3
+ batch = storage_->GetWriteBatchBase();
+ target_level = 3;
+ auto s4 = hnsw_index->InsertVectorEntryInternal(key2, vec2, batch,
target_level);
+ ASSERT_TRUE(s4.IsOK());
+ s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ index_meta_key = hnsw_index->search_key.ConstructFieldMeta();
+ s = storage_->Get(rocksdb::ReadOptions(),
hnsw_index->storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key,
+ &value);
+ ASSERT_TRUE(s.ok());
+ decoded_metadata.Decode(&value);
+ ASSERT_TRUE(decoded_metadata.num_levels == 4);
+
+ node1_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ EXPECT_EQ(node1_layer0.neighbours.size(), 1);
+ EXPECT_EQ(node1_layer0.neighbours[0], "n2");
+
+ node1_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ EXPECT_EQ(node1_layer1.neighbours.size(), 1);
+ EXPECT_EQ(node1_layer1.neighbours[0], "n2");
+
+ redis::HnswNode node2_layer0(key2, 0);
+ node2_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ EXPECT_EQ(node2_layer0.neighbours.size(), 1);
+ EXPECT_EQ(node2_layer0.neighbours[0], "n1");
+
+ redis::HnswNode node2_layer1(key2, 1);
+ node2_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ EXPECT_EQ(node2_layer1.neighbours.size(), 1);
+ EXPECT_EQ(node2_layer1.neighbours[0], "n1");
+
+ redis::HnswNode node2_layer2(key2, 2);
+ auto s5 = node2_layer2.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s5.IsOK());
+ redis::HnswNodeFieldMetadata node2_layer2_meta = s5.GetValue();
+ EXPECT_EQ(node2_layer2_meta.num_neighbours, 0);
+
+ redis::HnswNode node2_layer3(key2, 3);
+ auto s6 = node2_layer3.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s6.IsOK());
+ redis::HnswNodeFieldMetadata node2_layer3_meta = s6.GetValue();
+ EXPECT_EQ(node2_layer3_meta.num_neighbours, 0);
+
+ // Insert n3 into layer 2
+ batch = storage_->GetWriteBatchBase();
+ target_level = 2;
+ auto s7 = hnsw_index->InsertVectorEntryInternal(key3, vec3, batch,
target_level);
+ ASSERT_TRUE(s7.IsOK());
+ s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ index_meta_key = hnsw_index->search_key.ConstructFieldMeta();
+ s = storage_->Get(rocksdb::ReadOptions(),
hnsw_index->storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key,
+ &value);
+ ASSERT_TRUE(s.ok());
+ decoded_metadata.Decode(&value);
+ ASSERT_TRUE(decoded_metadata.num_levels == 4);
+
+ redis::HnswNode node3_layer2(key3, target_level);
+ auto s8 = node3_layer2.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s8.IsOK());
+ redis::HnswNodeFieldMetadata node3_layer2_meta = s8.GetValue();
+ EXPECT_EQ(node3_layer2_meta.num_neighbours, 1);
+ node3_layer2.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ EXPECT_EQ(node3_layer2.neighbours.size(), 1);
+ EXPECT_EQ(node3_layer2.neighbours[0], "n2");
+
+ redis::HnswNode node3_layer1(key3, 1);
+ auto s9 = node3_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s9.IsOK());
+ redis::HnswNodeFieldMetadata node3_layer1_meta = s9.GetValue();
+ EXPECT_EQ(node3_layer1_meta.num_neighbours, 2);
+ node3_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ EXPECT_EQ(node3_layer1.neighbours.size(), 2);
+ std::unordered_set<std::string> expected_set = {"n1", "n2"};
+ std::unordered_set<std::string> actual_set{node3_layer1.neighbours.begin(),
node3_layer1.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ // Insert n4 into layer 1
+ batch = storage_->GetWriteBatchBase();
+ target_level = 1;
+ auto s10 = hnsw_index->InsertVectorEntryInternal(key4, vec4, batch,
target_level);
+ ASSERT_TRUE(s10.IsOK());
+ s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ redis::HnswNode node4_layer0(key4, 0);
+ auto s11 = node4_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s11.IsOK());
+ redis::HnswNodeFieldMetadata node4_layer0_meta = s11.GetValue();
+ EXPECT_EQ(node4_layer0_meta.num_neighbours, 3);
+
+ auto s12 = node1_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s12.IsOK());
+ node1_layer1_meta = s12.GetValue();
+ EXPECT_EQ(node1_layer1_meta.num_neighbours, 3);
+ node1_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n2", "n3", "n4"};
+ actual_set = {node1_layer1.neighbours.begin(),
node1_layer1.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ auto s13 = node2_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s13.IsOK());
+ auto node2_layer1_meta = s13.GetValue();
+ EXPECT_EQ(node2_layer1_meta.num_neighbours, 3);
+ node2_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n1", "n3", "n4"};
+ actual_set = {node2_layer1.neighbours.begin(),
node2_layer1.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ auto s14 = node3_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s14.IsOK());
+ node3_layer1_meta = s14.GetValue();
+ EXPECT_EQ(node3_layer1_meta.num_neighbours, 3);
+ node3_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n1", "n2", "n4"};
+ actual_set = {node3_layer1.neighbours.begin(),
node3_layer1.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ // Insert n5 into layer 1
+ batch = storage_->GetWriteBatchBase();
+ auto s15 = hnsw_index->InsertVectorEntryInternal(key5, vec5, batch,
target_level);
+ ASSERT_TRUE(s15.IsOK());
+ s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ auto s16 = node2_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s16.IsOK());
+ node2_layer1_meta = s16.GetValue();
+ EXPECT_EQ(node2_layer1_meta.num_neighbours, 3);
+ node2_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n1", "n4", "n5"};
+ actual_set = {node2_layer1.neighbours.begin(),
node2_layer1.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ auto s17 = node3_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s17.IsOK());
+ node3_layer1_meta = s17.GetValue();
+ EXPECT_EQ(node3_layer1_meta.num_neighbours, 2);
+ node3_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n1", "n5"};
+ actual_set = {node3_layer1.neighbours.begin(),
node3_layer1.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ redis::HnswNode node4_layer1(key4, 1);
+ auto s18 = node4_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s18.IsOK());
+ auto node4_layer1_meta = s18.GetValue();
+ EXPECT_EQ(node4_layer1_meta.num_neighbours, 3);
+ node4_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n1", "n2", "n5"};
+ actual_set = {node4_layer1.neighbours.begin(),
node4_layer1.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ redis::HnswNode node5_layer1(key5, 1);
+ auto s19 = node5_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s19.IsOK());
+ auto node5_layer1_meta = s19.GetValue();
+ EXPECT_EQ(node5_layer1_meta.num_neighbours, 3);
+ node5_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n2", "n3", "n4"};
+ actual_set = {node5_layer1.neighbours.begin(),
node5_layer1.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ auto s20 = node1_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s20.IsOK());
+ node1_layer0_meta = s20.GetValue();
+ EXPECT_EQ(node1_layer0_meta.num_neighbours, 4);
+ node1_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n2", "n3", "n4", "n5"};
+ actual_set = {node1_layer0.neighbours.begin(),
node1_layer0.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ redis::HnswNode node5_layer0(key5, 0);
+ auto s21 = node5_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s21.IsOK());
+ auto node5_layer0_meta = s21.GetValue();
+ EXPECT_EQ(node5_layer0_meta.num_neighbours, 4);
+ node5_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n1", "n2", "n3", "n4"};
+ actual_set = {node5_layer0.neighbours.begin(),
node5_layer0.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ // Delete n2
+ batch = storage_->GetWriteBatchBase();
+ auto s22 = hnsw_index->DeleteVectorEntry(key2, batch);
+ ASSERT_TRUE(s22.IsOK());
+ s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ index_meta_key = hnsw_index->search_key.ConstructFieldMeta();
+ s = storage_->Get(rocksdb::ReadOptions(),
hnsw_index->storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key,
+ &value);
+ ASSERT_TRUE(s.ok());
+ decoded_metadata.Decode(&value);
+ ASSERT_TRUE(decoded_metadata.num_levels == 3);
+
+ auto s23 = node2_layer3.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ EXPECT_TRUE(!s23.IsOK());
+
+ auto s24 = node2_layer2.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ EXPECT_TRUE(!s24.IsOK());
+
+ auto s25 = node2_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ EXPECT_TRUE(!s25.IsOK());
+
+ auto s26 = node2_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ EXPECT_TRUE(!s26.IsOK());
+
+ auto s27 = node3_layer2.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s27.IsOK());
+ node3_layer2_meta = s27.GetValue();
+ EXPECT_EQ(node3_layer2_meta.num_neighbours, 0);
+
+ auto s28 = node1_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s28.IsOK());
+ node1_layer1_meta = s28.GetValue();
+ EXPECT_EQ(node1_layer1_meta.num_neighbours, 2);
+ node1_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n3", "n4"};
+ actual_set = {node1_layer1.neighbours.begin(),
node1_layer1.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ auto s29 = node3_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s29.IsOK());
+ node3_layer1_meta = s29.GetValue();
+ EXPECT_EQ(node3_layer1_meta.num_neighbours, 2);
+ node3_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n1", "n5"};
+ actual_set = {node3_layer1.neighbours.begin(),
node3_layer1.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ auto s30 = node4_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s30.IsOK());
+ node4_layer1_meta = s30.GetValue();
+ EXPECT_EQ(node4_layer1_meta.num_neighbours, 2);
+ node4_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n1", "n5"};
+ actual_set = {node4_layer1.neighbours.begin(),
node4_layer1.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ auto s31 = node5_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s31.IsOK());
+ node5_layer1_meta = s31.GetValue();
+ EXPECT_EQ(node5_layer1_meta.num_neighbours, 2);
+ node5_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n3", "n4"};
+ actual_set = {node5_layer1.neighbours.begin(),
node5_layer1.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ auto s32 = node1_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s32.IsOK());
+ node1_layer0_meta = s32.GetValue();
+ EXPECT_EQ(node1_layer0_meta.num_neighbours, 3);
+ node1_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n3", "n4", "n5"};
+ actual_set = {node1_layer0.neighbours.begin(),
node1_layer0.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ redis::HnswNode node3_layer0(key3, 0);
+ auto s33 = node3_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s33.IsOK());
+ auto node3_layer0_meta = s33.GetValue();
+ EXPECT_EQ(node3_layer0_meta.num_neighbours, 3);
+ node3_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n1", "n4", "n5"};
+ actual_set = {node3_layer0.neighbours.begin(),
node3_layer0.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ auto s34 = node4_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s34.IsOK());
+ node4_layer0_meta = s34.GetValue();
+ EXPECT_EQ(node4_layer0_meta.num_neighbours, 3);
+ node4_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n1", "n3", "n5"};
+ actual_set = {node4_layer0.neighbours.begin(),
node4_layer0.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+
+ auto s35 = node5_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s35.IsOK());
+ node5_layer0_meta = s35.GetValue();
+ EXPECT_EQ(node5_layer0_meta.num_neighbours, 3);
+ node5_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ expected_set = {"n1", "n3", "n4"};
+ actual_set = {node5_layer0.neighbours.begin(),
node5_layer0.neighbours.end()};
+ EXPECT_EQ(actual_set, expected_set);
+}
diff --git a/tests/cppunit/hnsw_node_test.cc b/tests/cppunit/hnsw_node_test.cc
new file mode 100644
index 00000000..5fadf992
--- /dev/null
+++ b/tests/cppunit/hnsw_node_test.cc
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include <encoding.h>
+#include <gtest/gtest.h>
+#include <test_base.h>
+
+#include <iostream>
+#include <memory>
+#include <unordered_set>
+
+#include "search/hnsw_indexer.h"
+#include "search/indexer.h"
+#include "search/search_encoding.h"
+#include "storage/storage.h"
+
// Fixture for HnswNode unit tests. TestBase supplies `storage_` (a
// rocksdb-backed engine instance); `search_key` identifies the
// (namespace, index, field) triple the test nodes are stored under.
struct NodeTest : public TestBase {
  std::string ns = "hnsw_node_test_ns";         // storage namespace
  std::string idx_name = "hnsw_node_test_idx";  // index name
  std::string key = "vector";                   // vector field name
  redis::SearchKey search_key;

  // search_key is constructed from the members above, which is why they are
  // declared (and thus initialized) before it.
  NodeTest() : search_key(ns, idx_name, key) {}

  void TearDown() override {}
};
+
+TEST_F(NodeTest, PutAndDecodeMetadata) {
+ uint16_t layer = 0;
+ redis::HnswNode node1("node1", layer);
+ redis::HnswNode node2("node2", layer);
+ redis::HnswNode node3("node3", layer);
+
+ redis::HnswNodeFieldMetadata metadata1(0, {1, 2, 3});
+ redis::HnswNodeFieldMetadata metadata2(0, {4, 5, 6});
+ redis::HnswNodeFieldMetadata metadata3(0, {7, 8, 9});
+
+ auto batch = storage_->GetWriteBatchBase();
+ node1.PutMetadata(&metadata1, search_key, storage_.get(), batch.Get());
+ node2.PutMetadata(&metadata2, search_key, storage_.get(), batch.Get());
+ node3.PutMetadata(&metadata3, search_key, storage_.get(), batch.Get());
+ auto s = storage_->Write(storage_->DefaultWriteOptions(),
batch->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ auto decoded_metadata1 = node1.DecodeMetadata(search_key, storage_.get());
+ ASSERT_TRUE(decoded_metadata1.IsOK());
+ ASSERT_EQ(decoded_metadata1.GetValue().num_neighbours, 0);
+ ASSERT_EQ(decoded_metadata1.GetValue().vector, std::vector<double>({1, 2,
3}));
+
+ auto decoded_metadata2 = node2.DecodeMetadata(search_key, storage_.get());
+ ASSERT_TRUE(decoded_metadata2.IsOK());
+ ASSERT_EQ(decoded_metadata2.GetValue().num_neighbours, 0);
+ ASSERT_EQ(decoded_metadata2.GetValue().vector, std::vector<double>({4, 5,
6}));
+
+ auto decoded_metadata3 = node3.DecodeMetadata(search_key, storage_.get());
+ ASSERT_TRUE(decoded_metadata3.IsOK());
+ ASSERT_EQ(decoded_metadata3.GetValue().num_neighbours, 0);
+ ASSERT_EQ(decoded_metadata3.GetValue().vector, std::vector<double>({7, 8,
9}));
+
+ // Prepare edges between node1 and node2
+ batch = storage_->GetWriteBatchBase();
+ auto edge1 = search_key.ConstructHnswEdge(layer, "node1", "node2");
+ auto edge2 = search_key.ConstructHnswEdge(layer, "node2", "node1");
+ auto edge3 = search_key.ConstructHnswEdge(layer, "node2", "node3");
+ auto edge4 = search_key.ConstructHnswEdge(layer, "node3", "node2");
+
+ batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge1, Slice());
+ batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge2, Slice());
+ batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge3, Slice());
+ batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge4, Slice());
+ s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ node1.DecodeNeighbours(search_key, storage_.get());
+ EXPECT_EQ(node1.neighbours.size(), 1);
+ EXPECT_EQ(node1.neighbours[0], "node2");
+
+ node2.DecodeNeighbours(search_key, storage_.get());
+ EXPECT_EQ(node2.neighbours.size(), 2);
+ std::unordered_set<std::string> expected_neighbours = {"node1", "node3"};
+ std::unordered_set<std::string> actual_neighbours(node2.neighbours.begin(),
node2.neighbours.end());
+ EXPECT_EQ(actual_neighbours, expected_neighbours);
+
+ node3.DecodeNeighbours(search_key, storage_.get());
+ EXPECT_EQ(node3.neighbours.size(), 1);
+ EXPECT_EQ(node3.neighbours[0], "node2");
+}
+
+TEST_F(NodeTest, ModifyNeighbours) {
+ uint16_t layer = 1;
+ redis::HnswNode node1("node1", layer);
+ redis::HnswNode node2("node2", layer);
+ redis::HnswNode node3("node3", layer);
+ redis::HnswNode node4("node4", layer);
+
+ redis::HnswNodeFieldMetadata metadata1(0, {1, 2, 3});
+ redis::HnswNodeFieldMetadata metadata2(0, {4, 5, 6});
+ redis::HnswNodeFieldMetadata metadata3(0, {7, 8, 9});
+ redis::HnswNodeFieldMetadata metadata4(0, {10, 11, 12});
+
+ // Add Nodes
+ auto batch1 = storage_->GetWriteBatchBase();
+ node1.PutMetadata(&metadata1, search_key, storage_.get(), batch1.Get());
+ node2.PutMetadata(&metadata2, search_key, storage_.get(), batch1.Get());
+ node3.PutMetadata(&metadata3, search_key, storage_.get(), batch1.Get());
+ node4.PutMetadata(&metadata4, search_key, storage_.get(), batch1.Get());
+ auto s = storage_->Write(storage_->DefaultWriteOptions(),
batch1->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ // Add Edges
+ auto batch2 = storage_->GetWriteBatchBase();
+ auto s1 = node1.AddNeighbour("node2", search_key, storage_.get(),
batch2.Get());
+ ASSERT_TRUE(s1.IsOK());
+ auto s2 = node2.AddNeighbour("node1", search_key, storage_.get(),
batch2.Get());
+ ASSERT_TRUE(s2.IsOK());
+ auto s3 = node2.AddNeighbour("node3", search_key, storage_.get(),
batch2.Get());
+ ASSERT_TRUE(s3.IsOK());
+ auto s4 = node3.AddNeighbour("node2", search_key, storage_.get(),
batch2.Get());
+ ASSERT_TRUE(s4.IsOK());
+ s = storage_->Write(storage_->DefaultWriteOptions(),
batch2->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ node1.DecodeNeighbours(search_key, storage_.get());
+ EXPECT_EQ(node1.neighbours.size(), 1);
+ EXPECT_EQ(node1.neighbours[0], "node2");
+
+ node2.DecodeNeighbours(search_key, storage_.get());
+ EXPECT_EQ(node2.neighbours.size(), 2);
+ std::unordered_set<std::string> expected_neighbours = {"node1", "node3"};
+ std::unordered_set<std::string> actual_neighbours(node2.neighbours.begin(),
node2.neighbours.end());
+ EXPECT_EQ(actual_neighbours, expected_neighbours);
+
+ node3.DecodeNeighbours(search_key, storage_.get());
+ EXPECT_EQ(node3.neighbours.size(), 1);
+ EXPECT_EQ(node3.neighbours[0], "node2");
+
+ // Remove Edges
+ auto batch3 = storage_->GetWriteBatchBase();
+ auto s5 = node2.RemoveNeighbour("node3", search_key, storage_.get(),
batch3.Get());
+ ASSERT_TRUE(s5.IsOK());
+
+ s = storage_->Write(storage_->DefaultWriteOptions(),
batch3->GetWriteBatch());
+ ASSERT_TRUE(s.ok());
+
+ node2.DecodeNeighbours(search_key, storage_.get());
+ EXPECT_EQ(node2.neighbours.size(), 1);
+ EXPECT_EQ(node2.neighbours[0], "node1");
+}