This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 79a740cf feat(search): Hnsw Vector Search Plan Operator & Executor (#2434)
79a740cf is described below
commit 79a740cf8c7aad157c7b3dd32581c5071d7d78b6
Author: Rebecca Zhou <[email protected]>
AuthorDate: Wed Jul 24 00:21:08 2024 -0700
feat(search): Hnsw Vector Search Plan Operator & Executor (#2434)
Co-authored-by: Twice <[email protected]>
---
.../hnsw_vector_field_knn_scan_executor.h | 76 ++++
.../hnsw_vector_field_range_scan_executor.h | 86 +++++
src/search/hnsw_indexer.cc | 99 ++++-
src/search/hnsw_indexer.h | 11 +
src/search/ir_plan.h | 37 ++
src/search/plan_executor.cc | 16 +
tests/cppunit/hnsw_index_test.cc | 418 +++++++++------------
tests/cppunit/indexer_test.cc | 44 +++
tests/cppunit/plan_executor_test.cc | 100 ++++-
9 files changed, 634 insertions(+), 253 deletions(-)
diff --git a/src/search/executors/hnsw_vector_field_knn_scan_executor.h b/src/search/executors/hnsw_vector_field_knn_scan_executor.h
new file mode 100644
index 00000000..5002fba1
--- /dev/null
+++ b/src/search/executors/hnsw_vector_field_knn_scan_executor.h
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <string>
+
+#include "db_util.h"
+#include "encoding.h"
+#include "search/hnsw_indexer.h"
+#include "search/plan_executor.h"
+#include "search/search_encoding.h"
+#include "storage/redis_db.h"
+#include "storage/redis_metadata.h"
+#include "storage/storage.h"
+#include "string_util.h"
+
+namespace kqir {
+
+// TODO(Beihao): Add DB context to improve consistency and isolation - see #2332
+struct HnswVectorFieldKnnScanExecutor : ExecutorNode {
+ HnswVectorFieldKnnScan *scan;
+ redis::LatestSnapShot ss;
+ bool initialized = false;
+
+ IndexInfo *index;
+ redis::SearchKey search_key;
+ redis::HnswVectorFieldMetadata field_metadata;
+ redis::HnswIndex hnsw_index;
+ std::vector<redis::KeyWithDistance> row_keys;
+ decltype(row_keys)::iterator row_keys_iter;
+
+ HnswVectorFieldKnnScanExecutor(ExecutorContext *ctx, HnswVectorFieldKnnScan
*scan)
+ : ExecutorNode(ctx),
+ scan(scan),
+ ss(ctx->storage),
+ index(scan->field->info->index),
+ search_key(index->ns, index->name, scan->field->name),
+
field_metadata(*(scan->field->info->MetadataAs<redis::HnswVectorFieldMetadata>())),
+ hnsw_index(redis::HnswIndex(search_key, &field_metadata,
ctx->storage)) {}
+
+ StatusOr<Result> Next() override {
+ if (!initialized) {
+ row_keys = GET_OR_RET(hnsw_index.KnnSearch(scan->vector, scan->k));
+ row_keys_iter = row_keys.begin();
+ initialized = true;
+ }
+
+ if (row_keys_iter == row_keys.end()) {
+ return end;
+ }
+
+ auto key_str = row_keys_iter->second;
+ row_keys_iter++;
+ return RowType{key_str, {}, scan->field->info->index};
+ }
+};
+
+} // namespace kqir
diff --git a/src/search/executors/hnsw_vector_field_range_scan_executor.h b/src/search/executors/hnsw_vector_field_range_scan_executor.h
new file mode 100644
index 00000000..afaf0297
--- /dev/null
+++ b/src/search/executors/hnsw_vector_field_range_scan_executor.h
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <string>
+
+#include "db_util.h"
+#include "encoding.h"
+#include "search/hnsw_indexer.h"
+#include "search/plan_executor.h"
+#include "search/search_encoding.h"
+#include "storage/redis_db.h"
+#include "storage/redis_metadata.h"
+#include "storage/storage.h"
+#include "string_util.h"
+
+namespace kqir {
+
+// TODO(Beihao): Add DB context to improve consistency and isolation - see #2332
+struct HnswVectorFieldRangeScanExecutor : ExecutorNode {
+ HnswVectorFieldRangeScan *scan;
+ redis::LatestSnapShot ss;
+ bool initialized = false;
+
+ IndexInfo *index;
+ redis::SearchKey search_key;
+ redis::HnswVectorFieldMetadata field_metadata;
+ redis::HnswIndex hnsw_index;
+ std::vector<redis::KeyWithDistance> row_keys;
+ std::unordered_set<std::string> visited;
+ decltype(row_keys)::iterator row_keys_iter;
+
+ HnswVectorFieldRangeScanExecutor(ExecutorContext *ctx,
HnswVectorFieldRangeScan *scan)
+ : ExecutorNode(ctx),
+ scan(scan),
+ ss(ctx->storage),
+ index(scan->field->info->index),
+ search_key(index->ns, index->name, scan->field->name),
+
field_metadata(*(scan->field->info->MetadataAs<redis::HnswVectorFieldMetadata>())),
+ hnsw_index(redis::HnswIndex(search_key, &field_metadata,
ctx->storage)) {}
+
+ StatusOr<Result> Next() override {
+ if (!initialized) {
+ row_keys = GET_OR_RET(hnsw_index.KnnSearch(scan->vector,
field_metadata.ef_runtime));
+ row_keys_iter = row_keys.begin();
+ initialized = true;
+ }
+
+ auto effective_range = scan->range * (1 + field_metadata.epsilon);
+ if (row_keys_iter == row_keys.end() || row_keys_iter->first >
abs(effective_range) ||
+ row_keys_iter->first < -abs(effective_range)) {
+ row_keys = GET_OR_RET(hnsw_index.ExpandSearchScope(scan->vector,
std::move(row_keys), visited));
+ if (row_keys.empty()) return end;
+ row_keys_iter = row_keys.begin();
+ }
+
+ if (row_keys_iter->first > abs(effective_range) || row_keys_iter->first <
-abs(effective_range)) {
+ return end;
+ }
+
+ auto key_str = row_keys_iter->second;
+ row_keys_iter++;
+ visited.insert(key_str);
+ return RowType{key_str, {}, scan->field->info->index};
+ }
+};
+
+} // namespace kqir
diff --git a/src/search/hnsw_indexer.cc b/src/search/hnsw_indexer.cc
index f03e4c95..3618ad89 100644
--- a/src/search/hnsw_indexer.cc
+++ b/src/search/hnsw_indexer.cc
@@ -275,14 +275,13 @@ StatusOr<std::vector<VectorItem>>
HnswIndex::SelectNeighbors(const VectorItem& v
return selected_vs;
}
-StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const
VectorItem& target_vector,
- uint32_t ef_runtime,
- const
std::vector<NodeKey>& entry_points) const {
- std::vector<VectorItem> candidates;
+StatusOr<std::vector<VectorItemWithDistance>> HnswIndex::SearchLayerInternal(
+ uint16_t level, const VectorItem& target_vector, uint32_t ef_runtime,
+ const std::vector<NodeKey>& entry_points) const {
+ std::vector<VectorItemWithDistance> result;
std::unordered_set<NodeKey> visited;
- std::priority_queue<std::pair<double, VectorItem>,
std::vector<std::pair<double, VectorItem>>, std::greater<>>
- explore_heap;
- std::priority_queue<std::pair<double, VectorItem>> result_heap;
+ std::priority_queue<VectorItemWithDistance,
std::vector<VectorItemWithDistance>, std::greater<>> explore_heap;
+ std::priority_queue<VectorItemWithDistance> result_heap;
for (const auto& entry_point_key : entry_points) {
HnswNode entry_node = HnswNode(entry_point_key, level);
@@ -330,13 +329,25 @@ StatusOr<std::vector<VectorItem>>
HnswIndex::SearchLayer(uint16_t level, const V
}
}
+ result.resize(result_heap.size());
+ auto idx = result_heap.size() - 1;
while (!result_heap.empty()) {
- candidates.push_back(result_heap.top().second);
+ result[idx] = result_heap.top();
result_heap.pop();
+ idx--;
}
+ return result;
+}
- std::reverse(candidates.begin(), candidates.end());
- return candidates;
+StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const
VectorItem& target_vector,
+ uint32_t ef_runtime,
+ const
std::vector<NodeKey>& entry_points) const {
+ std::vector<VectorItem> result;
+ auto result_with_distance = GET_OR_RET(SearchLayerInternal(level,
target_vector, ef_runtime, entry_points));
+ for (auto& [_, vector_item] : result_with_distance) {
+ result.push_back(std::move(vector_item));
+ }
+ return result;
}
Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const
kqir::NumericArray& vector,
@@ -549,4 +560,72 @@ Status HnswIndex::DeleteVectorEntry(std::string_view key,
ObserverOrUniquePtr<ro
return Status::OK();
}
+StatusOr<std::vector<KeyWithDistance>> HnswIndex::KnnSearch(const
kqir::NumericArray& query_vector, uint32_t k) const {
+ VectorItem query_vector_item;
+ GET_OR_RET(VectorItem::Create({}, query_vector, metadata,
&query_vector_item));
+
+ if (metadata->num_levels == 0) {
+ return {Status::NotFound, fmt::format("No vector found in the HNSW
index")};
+ }
+
+ auto level = metadata->num_levels - 1;
+ auto default_entry_node = GET_OR_RET(DefaultEntryPoint(level));
+ std::vector<NodeKey> entry_points{default_entry_node};
+ std::vector<VectorItem> nearest_vec_items;
+
+ for (; level > 0; level--) {
+ nearest_vec_items = GET_OR_RET(SearchLayer(level, query_vector_item,
metadata->ef_runtime, entry_points));
+ entry_points = {nearest_vec_items[0].key};
+ }
+
+ uint32_t effective_ef = std::max(metadata->ef_runtime, k); // Ensure
ef_runtime is at least k
+ auto nearest_vec_with_distance = GET_OR_RET(SearchLayerInternal(0,
query_vector_item, effective_ef, entry_points));
+
+ uint32_t result_length = std::min(k,
static_cast<uint32_t>(nearest_vec_with_distance.size()));
+ std::vector<KeyWithDistance> nearest_neighbours;
+ for (uint32_t result_idx = 0; result_idx < result_length; result_idx++) {
+
nearest_neighbours.emplace_back(nearest_vec_with_distance[result_idx].first,
+
std::move(nearest_vec_with_distance[result_idx].second.key));
+ }
+ return nearest_neighbours;
+}
+
+StatusOr<std::vector<KeyWithDistance>> HnswIndex::ExpandSearchScope(const
kqir::NumericArray& query_vector,
+
std::vector<redis::KeyWithDistance>&& initial_keys,
+
std::unordered_set<std::string>& visited) const {
+ constexpr uint16_t level = 0;
+ VectorItem query_vector_item;
+ GET_OR_RET(VectorItem::Create({}, query_vector, metadata,
&query_vector_item));
+ std::vector<KeyWithDistance> result;
+
+ while (!initial_keys.empty()) {
+ auto current_key = initial_keys.front().second;
+ initial_keys.erase(initial_keys.begin());
+
+ auto current_node = HnswNode(current_key, level);
+ current_node.DecodeNeighbours(search_key, storage);
+
+ for (const auto& neighbour_key : current_node.neighbours) {
+ if (visited.find(neighbour_key) != visited.end()) {
+ continue;
+ }
+ visited.insert(neighbour_key);
+
+ auto neighbour_node = HnswNode(neighbour_key, level);
+ auto neighbour_node_metadata =
GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));
+
+ VectorItem neighbour_node_vector;
+ GET_OR_RET(VectorItem::Create(neighbour_key,
std::move(neighbour_node_metadata.vector), metadata,
+ &neighbour_node_vector));
+
+ auto dist = GET_OR_RET(ComputeSimilarity(query_vector_item,
neighbour_node_vector));
+ result.emplace_back(dist, neighbour_key);
+ }
+ }
+ std::sort(result.begin(), result.end(),
+ [](const KeyWithDistance& a, const KeyWithDistance& b) { return
a.first < b.first; });
+
+ return result;
+}
+
} // namespace redis
diff --git a/src/search/hnsw_indexer.h b/src/search/hnsw_indexer.h
index 0f9dbaae..72072039 100644
--- a/src/search/hnsw_indexer.h
+++ b/src/search/hnsw_indexer.h
@@ -78,6 +78,10 @@ struct VectorItem {
StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem&
right);
+using VectorItemWithDistance = std::pair<double, VectorItem>;
+using KeyWithDistance = std::pair<double, std::string>;
+
+// TODO(Beihao): Add DB context to improve consistency and isolation - see #2332
struct HnswIndex {
using NodeKey = HnswNode::NodeKey;
@@ -103,6 +107,9 @@ struct HnswIndex {
StatusOr<std::vector<VectorItem>> SelectNeighbors(const VectorItem& vec,
const std::vector<VectorItem>& vectors,
uint16_t layer) const;
+ StatusOr<std::vector<VectorItemWithDistance>> SearchLayerInternal(uint16_t
level, const VectorItem& target_vector,
+ uint32_t
ef_runtime,
+ const
std::vector<NodeKey>& entry_points) const;
StatusOr<std::vector<VectorItem>> SearchLayer(uint16_t level, const
VectorItem& target_vector, uint32_t ef_runtime,
const std::vector<NodeKey>&
entry_points) const;
Status InsertVectorEntryInternal(std::string_view key, const
kqir::NumericArray& vector,
@@ -110,6 +117,10 @@ struct HnswIndex {
Status InsertVectorEntry(std::string_view key, const kqir::NumericArray&
vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>&
batch);
Status DeleteVectorEntry(std::string_view key,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
+ StatusOr<std::vector<KeyWithDistance>> KnnSearch(const kqir::NumericArray&
query_vector, uint32_t k) const;
+ StatusOr<std::vector<KeyWithDistance>> ExpandSearchScope(const
kqir::NumericArray& query_vector,
+
std::vector<redis::KeyWithDistance>&& initial_keys,
+
std::unordered_set<std::string>& visited) const;
};
} // namespace redis
diff --git a/src/search/ir_plan.h b/src/search/ir_plan.h
index f93199b5..94e8b589 100644
--- a/src/search/ir_plan.h
+++ b/src/search/ir_plan.h
@@ -26,6 +26,7 @@
#include "ir.h"
#include "search/interval.h"
#include "search/ir_sema_checker.h"
+#include "search/value.h"
#include "string_util.h"
namespace kqir {
@@ -96,6 +97,42 @@ struct TagFieldScan : FieldScan {
}
};
+struct HnswVectorFieldKnnScan : FieldScan {
+ kqir::NumericArray vector;
+ uint16_t k;
+
+ HnswVectorFieldKnnScan(std::unique_ptr<FieldRef> field, kqir::NumericArray
vector, uint16_t k)
+ : FieldScan(std::move(field)), vector(std::move(vector)), k(k) {}
+
+ std::string_view Name() const override { return "HnswVectorFieldKnnScan"; };
+ std::string Content() const override {
+ return fmt::format("[{}], {}", util::StringJoin(vector, [](auto v) {
return std::to_string(v); }), k);
+ };
+ std::string Dump() const override { return fmt::format("hnsw-vector-knn-scan
{}, {}", field->name, Content()); }
+
+ std::unique_ptr<Node> Clone() const override {
+ return
std::make_unique<HnswVectorFieldKnnScan>(field->CloneAs<FieldRef>(), vector, k);
+ }
+};
+
+struct HnswVectorFieldRangeScan : FieldScan {
+ kqir::NumericArray vector;
+ uint32_t range;
+
+ HnswVectorFieldRangeScan(std::unique_ptr<FieldRef> field, kqir::NumericArray
vector, uint32_t range)
+ : FieldScan(std::move(field)), vector(std::move(vector)), range(range) {}
+
+ std::string_view Name() const override { return "HnswVectorFieldRangeScan";
};
+ std::string Content() const override {
+ return fmt::format("[{}], {}", util::StringJoin(vector, [](auto v) {
return std::to_string(v); }), range);
+ };
+ std::string Dump() const override { return
fmt::format("hnsw-vector-range-scan {}, {}", field->name, Content()); }
+
+ std::unique_ptr<Node> Clone() const override {
+ return
std::make_unique<HnswVectorFieldRangeScan>(field->CloneAs<FieldRef>(), vector,
range);
+ }
+};
+
struct Filter : PlanOperator {
std::unique_ptr<PlanOperator> source;
std::unique_ptr<QueryExpr> filter_expr;
diff --git a/src/search/plan_executor.cc b/src/search/plan_executor.cc
index 9140587e..59fcd416 100644
--- a/src/search/plan_executor.cc
+++ b/src/search/plan_executor.cc
@@ -24,6 +24,8 @@
#include "search/executors/filter_executor.h"
#include "search/executors/full_index_scan_executor.h"
+#include "search/executors/hnsw_vector_field_knn_scan_executor.h"
+#include "search/executors/hnsw_vector_field_range_scan_executor.h"
#include "search/executors/limit_executor.h"
#include "search/executors/merge_executor.h"
#include "search/executors/mock_executor.h"
@@ -84,6 +86,14 @@ struct ExecutorContextVisitor {
return Visit(v);
}
+ if (auto v = dynamic_cast<HnswVectorFieldKnnScan *>(op)) {
+ return Visit(v);
+ }
+
+ if (auto v = dynamic_cast<HnswVectorFieldRangeScan *>(op)) {
+ return Visit(v);
+ }
+
if (auto v = dynamic_cast<Mock *>(op)) {
return Visit(v);
}
@@ -129,6 +139,12 @@ struct ExecutorContextVisitor {
void Visit(TagFieldScan *op) { ctx->nodes[op] =
std::make_unique<TagFieldScanExecutor>(ctx, op); }
+ void Visit(HnswVectorFieldKnnScan *op) { ctx->nodes[op] =
std::make_unique<HnswVectorFieldKnnScanExecutor>(ctx, op); }
+
+ void Visit(HnswVectorFieldRangeScan *op) {
+ ctx->nodes[op] = std::make_unique<HnswVectorFieldRangeScanExecutor>(ctx,
op);
+ }
+
void Visit(Mock *op) { ctx->nodes[op] = std::make_unique<MockExecutor>(ctx,
op); }
};
diff --git a/tests/cppunit/hnsw_index_test.cc b/tests/cppunit/hnsw_index_test.cc
index e09e9830..3162c88f 100644
--- a/tests/cppunit/hnsw_index_test.cc
+++ b/tests/cppunit/hnsw_index_test.cc
@@ -31,6 +31,35 @@
#include "search/value.h"
#include "storage/storage.h"
+auto GetVectorKeys(const std::vector<redis::KeyWithDistance>& keys_by_dist) ->
std::vector<std::string> {
+ std::vector<std::string> result;
+ result.reserve(keys_by_dist.size());
+ for (const auto& [dist, key] : keys_by_dist) {
+ result.push_back(key);
+ }
+ return result;
+}
+
+void InsertEntryIntoHnswIndex(std::string_view key, const kqir::NumericArray&
vector, uint16_t target_level,
+ redis::HnswIndex* hnsw_index, engine::Storage*
storage) {
+ auto batch = storage->GetWriteBatchBase();
+ auto s = hnsw_index->InsertVectorEntryInternal(key, vector, batch,
target_level);
+ ASSERT_TRUE(s.IsOK());
+ auto status = storage->Write(storage->DefaultWriteOptions(),
batch->GetWriteBatch());
+ ASSERT_TRUE(status.ok());
+}
+
+void VerifyNodeMetadataAndNeighbours(redis::HnswNode* node, redis::HnswIndex*
hnsw_index,
+ const std::unordered_set<std::string>&
expected_set) {
+ auto s = node->DecodeMetadata(hnsw_index->search_key, hnsw_index->storage);
+ ASSERT_TRUE(s.IsOK());
+ auto node_meta = s.GetValue();
+ EXPECT_EQ(node_meta.num_neighbours,
static_cast<uint16_t>(expected_set.size()));
+ node->DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
+ std::unordered_set<std::string> actual_set = {(node->neighbours).begin(),
(node->neighbours).end()};
+ EXPECT_EQ(actual_set, expected_set);
+}
+
struct HnswIndexTest : TestBase {
redis::HnswVectorFieldMetadata metadata;
std::string ns = "hnsw_test_ns";
@@ -344,42 +373,27 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) {
std::string key4 = "n4";
std::string key5 = "n5";
- // Insert n1 into layer 1
+ // Insert
uint16_t target_level = 1;
- auto batch = storage_->GetWriteBatchBase();
- auto s1 = hnsw_index->InsertVectorEntryInternal(key1, vec1, batch,
target_level);
- ASSERT_TRUE(s1.IsOK());
- auto s = storage_->Write(storage_->DefaultWriteOptions(),
batch->GetWriteBatch());
- ASSERT_TRUE(s.ok());
+ InsertEntryIntoHnswIndex(key1, vec1, target_level, hnsw_index.get(),
storage_.get());
rocksdb::PinnableSlice value;
auto index_meta_key = hnsw_index->search_key.ConstructFieldMeta();
- s = storage_->Get(rocksdb::ReadOptions(),
hnsw_index->storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key,
- &value);
+ auto s = storage_->Get(rocksdb::ReadOptions(),
hnsw_index->storage->GetCFHandle(ColumnFamilyID::Search),
+ index_meta_key, &value);
ASSERT_TRUE(s.ok());
redis::HnswVectorFieldMetadata decoded_metadata;
decoded_metadata.Decode(&value);
ASSERT_TRUE(decoded_metadata.num_levels == 2);
redis::HnswNode node1_layer0(key1, 0);
- auto s2 = node1_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s2.IsOK());
- redis::HnswNodeFieldMetadata node1_layer0_meta = s2.GetValue();
- EXPECT_EQ(node1_layer0_meta.num_neighbours, 0);
-
+ VerifyNodeMetadataAndNeighbours(&node1_layer0, hnsw_index.get(), {});
redis::HnswNode node1_layer1(key1, 1);
- auto s3 = node1_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s3.IsOK());
- redis::HnswNodeFieldMetadata node1_layer1_meta = s2.GetValue();
- EXPECT_EQ(node1_layer1_meta.num_neighbours, 0);
+ VerifyNodeMetadataAndNeighbours(&node1_layer1, hnsw_index.get(), {});
- // Insert n2 into layer 3
- batch = storage_->GetWriteBatchBase();
+ // Insert
target_level = 3;
- auto s4 = hnsw_index->InsertVectorEntryInternal(key2, vec2, batch,
target_level);
- ASSERT_TRUE(s4.IsOK());
- s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
- ASSERT_TRUE(s.ok());
+ InsertEntryIntoHnswIndex(key2, vec2, target_level, hnsw_index.get(),
storage_.get());
index_meta_key = hnsw_index->search_key.ConstructFieldMeta();
s = storage_->Get(rocksdb::ReadOptions(),
hnsw_index->storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key,
@@ -388,43 +402,23 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) {
decoded_metadata.Decode(&value);
ASSERT_TRUE(decoded_metadata.num_levels == 4);
- node1_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- EXPECT_EQ(node1_layer0.neighbours.size(), 1);
- EXPECT_EQ(node1_layer0.neighbours[0], "n2");
-
- node1_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- EXPECT_EQ(node1_layer1.neighbours.size(), 1);
- EXPECT_EQ(node1_layer1.neighbours[0], "n2");
+ VerifyNodeMetadataAndNeighbours(&node1_layer0, hnsw_index.get(), {"n2"});
+ VerifyNodeMetadataAndNeighbours(&node1_layer1, hnsw_index.get(), {"n2"});
redis::HnswNode node2_layer0(key2, 0);
- node2_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- EXPECT_EQ(node2_layer0.neighbours.size(), 1);
- EXPECT_EQ(node2_layer0.neighbours[0], "n1");
+ VerifyNodeMetadataAndNeighbours(&node2_layer0, hnsw_index.get(), {"n1"});
redis::HnswNode node2_layer1(key2, 1);
- node2_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- EXPECT_EQ(node2_layer1.neighbours.size(), 1);
- EXPECT_EQ(node2_layer1.neighbours[0], "n1");
+ VerifyNodeMetadataAndNeighbours(&node2_layer1, hnsw_index.get(), {"n1"});
redis::HnswNode node2_layer2(key2, 2);
- auto s5 = node2_layer2.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s5.IsOK());
- redis::HnswNodeFieldMetadata node2_layer2_meta = s5.GetValue();
- EXPECT_EQ(node2_layer2_meta.num_neighbours, 0);
-
+ VerifyNodeMetadataAndNeighbours(&node2_layer2, hnsw_index.get(), {});
redis::HnswNode node2_layer3(key2, 3);
- auto s6 = node2_layer3.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s6.IsOK());
- redis::HnswNodeFieldMetadata node2_layer3_meta = s6.GetValue();
- EXPECT_EQ(node2_layer3_meta.num_neighbours, 0);
+ VerifyNodeMetadataAndNeighbours(&node2_layer3, hnsw_index.get(), {});
- // Insert n3 into layer 2
- batch = storage_->GetWriteBatchBase();
+ // Insert
target_level = 2;
- auto s7 = hnsw_index->InsertVectorEntryInternal(key3, vec3, batch,
target_level);
- ASSERT_TRUE(s7.IsOK());
- s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
- ASSERT_TRUE(s.ok());
+ InsertEntryIntoHnswIndex(key3, vec3, target_level, hnsw_index.get(),
storage_.get());
index_meta_key = hnsw_index->search_key.ConstructFieldMeta();
s = storage_->Get(rocksdb::ReadOptions(),
hnsw_index->storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key,
@@ -434,134 +428,41 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) {
ASSERT_TRUE(decoded_metadata.num_levels == 4);
redis::HnswNode node3_layer2(key3, target_level);
- auto s8 = node3_layer2.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s8.IsOK());
- redis::HnswNodeFieldMetadata node3_layer2_meta = s8.GetValue();
- EXPECT_EQ(node3_layer2_meta.num_neighbours, 1);
- node3_layer2.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- EXPECT_EQ(node3_layer2.neighbours.size(), 1);
- EXPECT_EQ(node3_layer2.neighbours[0], "n2");
-
+ VerifyNodeMetadataAndNeighbours(&node3_layer2, hnsw_index.get(), {"n2"});
redis::HnswNode node3_layer1(key3, 1);
- auto s9 = node3_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s9.IsOK());
- redis::HnswNodeFieldMetadata node3_layer1_meta = s9.GetValue();
- EXPECT_EQ(node3_layer1_meta.num_neighbours, 2);
- node3_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- EXPECT_EQ(node3_layer1.neighbours.size(), 2);
- std::unordered_set<std::string> expected_set = {"n1", "n2"};
- std::unordered_set<std::string> actual_set{node3_layer1.neighbours.begin(),
node3_layer1.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
+ VerifyNodeMetadataAndNeighbours(&node3_layer1, hnsw_index.get(), {"n1",
"n2"});
- // Insert n4 into layer 1
- batch = storage_->GetWriteBatchBase();
+ // Insert
target_level = 1;
- auto s10 = hnsw_index->InsertVectorEntryInternal(key4, vec4, batch,
target_level);
- ASSERT_TRUE(s10.IsOK());
- s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
- ASSERT_TRUE(s.ok());
+ InsertEntryIntoHnswIndex(key4, vec4, target_level, hnsw_index.get(),
storage_.get());
redis::HnswNode node4_layer0(key4, 0);
- auto s11 = node4_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s11.IsOK());
- redis::HnswNodeFieldMetadata node4_layer0_meta = s11.GetValue();
+ auto s1 = node4_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ ASSERT_TRUE(s1.IsOK());
+ redis::HnswNodeFieldMetadata node4_layer0_meta = s1.GetValue();
EXPECT_EQ(node4_layer0_meta.num_neighbours, 3);
- auto s12 = node1_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s12.IsOK());
- node1_layer1_meta = s12.GetValue();
- EXPECT_EQ(node1_layer1_meta.num_neighbours, 3);
- node1_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n2", "n3", "n4"};
- actual_set = {node1_layer1.neighbours.begin(),
node1_layer1.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
-
- auto s13 = node2_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s13.IsOK());
- auto node2_layer1_meta = s13.GetValue();
- EXPECT_EQ(node2_layer1_meta.num_neighbours, 3);
- node2_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n1", "n3", "n4"};
- actual_set = {node2_layer1.neighbours.begin(),
node2_layer1.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
-
- auto s14 = node3_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s14.IsOK());
- node3_layer1_meta = s14.GetValue();
- EXPECT_EQ(node3_layer1_meta.num_neighbours, 3);
- node3_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n1", "n2", "n4"};
- actual_set = {node3_layer1.neighbours.begin(),
node3_layer1.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
+ VerifyNodeMetadataAndNeighbours(&node1_layer1, hnsw_index.get(), {"n2",
"n3", "n4"});
+ VerifyNodeMetadataAndNeighbours(&node2_layer1, hnsw_index.get(), {"n1",
"n3", "n4"});
+ VerifyNodeMetadataAndNeighbours(&node3_layer1, hnsw_index.get(), {"n1",
"n2", "n4"});
// Insert n5 into layer 1
- batch = storage_->GetWriteBatchBase();
- auto s15 = hnsw_index->InsertVectorEntryInternal(key5, vec5, batch,
target_level);
- ASSERT_TRUE(s15.IsOK());
- s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
- ASSERT_TRUE(s.ok());
-
- auto s16 = node2_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s16.IsOK());
- node2_layer1_meta = s16.GetValue();
- EXPECT_EQ(node2_layer1_meta.num_neighbours, 3);
- node2_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n1", "n4", "n5"};
- actual_set = {node2_layer1.neighbours.begin(),
node2_layer1.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
-
- auto s17 = node3_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s17.IsOK());
- node3_layer1_meta = s17.GetValue();
- EXPECT_EQ(node3_layer1_meta.num_neighbours, 2);
- node3_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n1", "n5"};
- actual_set = {node3_layer1.neighbours.begin(),
node3_layer1.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
+ InsertEntryIntoHnswIndex(key5, vec5, target_level, hnsw_index.get(),
storage_.get());
+ VerifyNodeMetadataAndNeighbours(&node2_layer1, hnsw_index.get(), {"n1",
"n4", "n5"});
+ VerifyNodeMetadataAndNeighbours(&node3_layer1, hnsw_index.get(), {"n1",
"n5"});
redis::HnswNode node4_layer1(key4, 1);
- auto s18 = node4_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s18.IsOK());
- auto node4_layer1_meta = s18.GetValue();
- EXPECT_EQ(node4_layer1_meta.num_neighbours, 3);
- node4_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n1", "n2", "n5"};
- actual_set = {node4_layer1.neighbours.begin(),
node4_layer1.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
-
+ VerifyNodeMetadataAndNeighbours(&node4_layer1, hnsw_index.get(), {"n1",
"n2", "n5"});
redis::HnswNode node5_layer1(key5, 1);
- auto s19 = node5_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s19.IsOK());
- auto node5_layer1_meta = s19.GetValue();
- EXPECT_EQ(node5_layer1_meta.num_neighbours, 3);
- node5_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n2", "n3", "n4"};
- actual_set = {node5_layer1.neighbours.begin(),
node5_layer1.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
-
- auto s20 = node1_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s20.IsOK());
- node1_layer0_meta = s20.GetValue();
- EXPECT_EQ(node1_layer0_meta.num_neighbours, 4);
- node1_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n2", "n3", "n4", "n5"};
- actual_set = {node1_layer0.neighbours.begin(),
node1_layer0.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
-
+ VerifyNodeMetadataAndNeighbours(&node5_layer1, hnsw_index.get(), {"n2",
"n3", "n4"});
+ VerifyNodeMetadataAndNeighbours(&node1_layer0, hnsw_index.get(), {"n2",
"n3", "n4", "n5"});
redis::HnswNode node5_layer0(key5, 0);
- auto s21 = node5_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s21.IsOK());
- auto node5_layer0_meta = s21.GetValue();
- EXPECT_EQ(node5_layer0_meta.num_neighbours, 4);
- node5_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n1", "n2", "n3", "n4"};
- actual_set = {node5_layer0.neighbours.begin(),
node5_layer0.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
+ VerifyNodeMetadataAndNeighbours(&node5_layer0, hnsw_index.get(), {"n1",
"n2", "n3", "n4"});
// Delete n2
- batch = storage_->GetWriteBatchBase();
- auto s22 = hnsw_index->DeleteVectorEntry(key2, batch);
- ASSERT_TRUE(s22.IsOK());
+ auto batch = storage_->GetWriteBatchBase();
+ auto s2 = hnsw_index->DeleteVectorEntry(key2, batch);
+ ASSERT_TRUE(s2.IsOK());
s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
ASSERT_TRUE(s.ok());
@@ -572,93 +473,126 @@ TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) {
decoded_metadata.Decode(&value);
ASSERT_TRUE(decoded_metadata.num_levels == 3);
- auto s23 = node2_layer3.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- EXPECT_TRUE(!s23.IsOK());
-
- auto s24 = node2_layer2.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- EXPECT_TRUE(!s24.IsOK());
+ auto s3 = node2_layer3.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ EXPECT_TRUE(!s3.IsOK());
+ auto s4 = node2_layer2.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ EXPECT_TRUE(!s4.IsOK());
+ auto s5 = node2_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ EXPECT_TRUE(!s5.IsOK());
+ auto s6 = node2_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
+ EXPECT_TRUE(!s6.IsOK());
+
+ VerifyNodeMetadataAndNeighbours(&node3_layer2, hnsw_index.get(), {});
+ VerifyNodeMetadataAndNeighbours(&node1_layer1, hnsw_index.get(), {"n3",
"n4"});
+ VerifyNodeMetadataAndNeighbours(&node3_layer1, hnsw_index.get(), {"n1",
"n5"});
+ VerifyNodeMetadataAndNeighbours(&node4_layer1, hnsw_index.get(), {"n1",
"n5"});
+ VerifyNodeMetadataAndNeighbours(&node5_layer1, hnsw_index.get(), {"n3",
"n4"});
+ VerifyNodeMetadataAndNeighbours(&node1_layer0, hnsw_index.get(), {"n3",
"n4", "n5"});
+ redis::HnswNode node3_layer0(key3, 0);
+ VerifyNodeMetadataAndNeighbours(&node3_layer0, hnsw_index.get(), {"n1",
"n4", "n5"});
+ VerifyNodeMetadataAndNeighbours(&node4_layer0, hnsw_index.get(), {"n1",
"n3", "n5"});
+ VerifyNodeMetadataAndNeighbours(&node5_layer0, hnsw_index.get(), {"n1",
"n3", "n4"});
+}
- auto s25 = node2_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- EXPECT_TRUE(!s25.IsOK());
+TEST_F(HnswIndexTest, SearchKnnAndRange) {
+ hnsw_index->metadata->m = 3;
+ std::vector<double> query_vector = {31.0, 32.0, 23.0};
+ uint32_t k = 3;
+ auto s1 = hnsw_index->KnnSearch(query_vector, k);
+ ASSERT_FALSE(s1.IsOK());
+ EXPECT_EQ(s1.GetCode(), Status::NotFound);
- auto s26 = node2_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- EXPECT_TRUE(!s26.IsOK());
+ std::vector<double> vec1 = {11.0, 12.0, 13.0};
+ std::vector<double> vec2 = {14.0, 15.0, 16.0};
+ std::vector<double> vec3 = {17.0, 18.0, 19.0};
+ std::vector<double> vec4 = {12.0, 13.0, 14.0};
+ std::vector<double> vec5 = {30.0, 40.0, 35.0};
+ std::vector<double> vec6 = {10.0, 9.0, 8.0};
+ std::vector<double> vec7 = {7.0, 6.0, 5.0};
+ std::vector<double> vec8 = {36.0, 37.0, 38.0};
+ std::vector<double> vec9 = {39.0, 40.0, 41.0};
+ std::vector<double> vec10 = {42.0, 43.0, 44.0};
+ std::vector<double> vec11 = {2.0, 3.0, 4.0};
+ std::vector<double> vec12 = {4.0, 5.0, 6.0};
+
+ std::string key1 = "key1";
+ std::string key2 = "key2";
+ std::string key3 = "key3";
+ std::string key4 = "key4";
+ std::string key5 = "key5";
+ std::string key6 = "key6";
+ std::string key7 = "key7";
+ std::string key8 = "key8";
+ std::string key9 = "key9";
+ std::string key10 = "key10";
+ std::string key11 = "key11";
+ std::string key12 = "key12";
- auto s27 = node3_layer2.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s27.IsOK());
- node3_layer2_meta = s27.GetValue();
- EXPECT_EQ(node3_layer2_meta.num_neighbours, 0);
+ uint16_t target_level = 1;
+ InsertEntryIntoHnswIndex(key1, vec1, target_level, hnsw_index.get(),
storage_.get());
- auto s28 = node1_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s28.IsOK());
- node1_layer1_meta = s28.GetValue();
- EXPECT_EQ(node1_layer1_meta.num_neighbours, 2);
- node1_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n3", "n4"};
- actual_set = {node1_layer1.neighbours.begin(),
node1_layer1.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
+ // Search when HNSW graph contains less than k nodes
+ auto s2 = hnsw_index->KnnSearch(query_vector, k);
+ ASSERT_TRUE(s2.IsOK());
+ auto key_strs = GetVectorKeys(s2.GetValue());
+ std::vector<std::string> expected = {"key1"};
+ EXPECT_EQ(key_strs, expected);
- auto s29 = node3_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s29.IsOK());
- node3_layer1_meta = s29.GetValue();
- EXPECT_EQ(node3_layer1_meta.num_neighbours, 2);
- node3_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n1", "n5"};
- actual_set = {node3_layer1.neighbours.begin(),
node3_layer1.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
+ target_level = 2;
+ InsertEntryIntoHnswIndex(key2, vec2, target_level, hnsw_index.get(),
storage_.get());
+ target_level = 0;
+ InsertEntryIntoHnswIndex(key3, vec3, target_level, hnsw_index.get(),
storage_.get());
- auto s30 = node4_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s30.IsOK());
- node4_layer1_meta = s30.GetValue();
- EXPECT_EQ(node4_layer1_meta.num_neighbours, 2);
- node4_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n1", "n5"};
- actual_set = {node4_layer1.neighbours.begin(),
node4_layer1.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
+ // Search when HNSW graph contains exactly k nodes
+ auto s3 = hnsw_index->KnnSearch(query_vector, k);
+ ASSERT_TRUE(s3.IsOK());
+ key_strs = GetVectorKeys(s3.GetValue());
+ expected = {"key3", "key2", "key1"};
+ EXPECT_EQ(key_strs, expected);
- auto s31 = node5_layer1.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s31.IsOK());
- node5_layer1_meta = s31.GetValue();
- EXPECT_EQ(node5_layer1_meta.num_neighbours, 2);
- node5_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n3", "n4"};
- actual_set = {node5_layer1.neighbours.begin(),
node5_layer1.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
+ target_level = 1;
+ InsertEntryIntoHnswIndex(key4, vec4, target_level, hnsw_index.get(),
storage_.get());
+ target_level = 0;
+ InsertEntryIntoHnswIndex(key5, vec5, target_level, hnsw_index.get(),
storage_.get());
- auto s32 = node1_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s32.IsOK());
- node1_layer0_meta = s32.GetValue();
- EXPECT_EQ(node1_layer0_meta.num_neighbours, 3);
- node1_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n3", "n4", "n5"};
- actual_set = {node1_layer0.neighbours.begin(),
node1_layer0.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
+ // Search when HNSW graph contains more than k nodes
+ auto s4 = hnsw_index->KnnSearch(query_vector, k);
+ ASSERT_TRUE(s4.IsOK());
+ key_strs = GetVectorKeys(s4.GetValue());
+ expected = {"key5", "key3", "key2"};
+ EXPECT_EQ(key_strs, expected);
- redis::HnswNode node3_layer0(key3, 0);
- auto s33 = node3_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s33.IsOK());
- auto node3_layer0_meta = s33.GetValue();
- EXPECT_EQ(node3_layer0_meta.num_neighbours, 3);
- node3_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n1", "n4", "n5"};
- actual_set = {node3_layer0.neighbours.begin(),
node3_layer0.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
+ // Edge case: If ef_runtime is smaller than k, enlarge ef_runtime equal to k
+ hnsw_index->metadata->ef_runtime = 1;
+ auto s5 = hnsw_index->KnnSearch(query_vector, k);
+ ASSERT_TRUE(s5.IsOK());
+ auto result = s5.GetValue();
+ key_strs = GetVectorKeys(result);
+ expected = {"key5", "key3", "key2"};
+ EXPECT_EQ(key_strs, expected);
+
+ hnsw_index->metadata->ef_runtime = 5;
+ InsertEntryIntoHnswIndex(key6, vec6, target_level, hnsw_index.get(),
storage_.get());
+ InsertEntryIntoHnswIndex(key7, vec7, target_level, hnsw_index.get(),
storage_.get());
+ InsertEntryIntoHnswIndex(key8, vec8, target_level, hnsw_index.get(),
storage_.get());
+ InsertEntryIntoHnswIndex(key9, vec9, target_level, hnsw_index.get(),
storage_.get());
+ target_level = 1;
+ InsertEntryIntoHnswIndex(key10, vec10, target_level, hnsw_index.get(),
storage_.get());
+ InsertEntryIntoHnswIndex(key11, vec11, target_level, hnsw_index.get(),
storage_.get());
+ target_level = 2;
+ InsertEntryIntoHnswIndex(key12, vec12, target_level, hnsw_index.get(),
storage_.get());
- auto s34 = node4_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s34.IsOK());
- node4_layer0_meta = s34.GetValue();
- EXPECT_EQ(node4_layer0_meta.num_neighbours, 3);
- node4_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n1", "n3", "n5"};
- actual_set = {node4_layer0.neighbours.begin(),
node4_layer0.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
+ std::unordered_set<std::string> visited{key_strs.begin(), key_strs.end()};
+ auto s6 = hnsw_index->ExpandSearchScope(query_vector, std::move(result),
visited);
+ ASSERT_TRUE(s6.IsOK());
+ result = s6.GetValue();
+ key_strs = GetVectorKeys(result);
+ expected = {"key8", "key9", "key10", "key4", "key1", "key6", "key7",
"key12"};
+ EXPECT_EQ(key_strs, expected);
- auto s35 = node5_layer0.DecodeMetadata(hnsw_index->search_key,
hnsw_index->storage);
- ASSERT_TRUE(s35.IsOK());
- node5_layer0_meta = s35.GetValue();
- EXPECT_EQ(node5_layer0_meta.num_neighbours, 3);
- node5_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage);
- expected_set = {"n1", "n3", "n4"};
- actual_set = {node5_layer0.neighbours.begin(),
node5_layer0.neighbours.end()};
- EXPECT_EQ(actual_set, expected_set);
+ auto s7 = hnsw_index->ExpandSearchScope(query_vector, std::move(result),
visited);
+ ASSERT_TRUE(s7.IsOK());
+ key_strs = GetVectorKeys(s7.GetValue());
+ expected = {"key11"};
+ EXPECT_EQ(key_strs, expected);
}
diff --git a/tests/cppunit/indexer_test.cc b/tests/cppunit/indexer_test.cc
index c3a7769e..f6e1e64d 100644
--- a/tests/cppunit/indexer_test.cc
+++ b/tests/cppunit/indexer_test.cc
@@ -56,6 +56,11 @@ struct IndexerTest : TestBase {
auto json_info = std::make_unique<kqir::IndexInfo>("jsontest",
json_field_meta, ns);
json_info->Add(kqir::FieldInfo("$.x",
std::make_unique<redis::TagFieldMetadata>()));
json_info->Add(kqir::FieldInfo("$.y",
std::make_unique<redis::NumericFieldMetadata>()));
+ auto hnsw_field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
+ hnsw_field_meta->vector_type = redis::VectorType::FLOAT64;
+ hnsw_field_meta->dim = 3;
+ hnsw_field_meta->distance_metric = redis::DistanceMetric::L2;
+ json_info->Add(kqir::FieldInfo("$.z", std::move(hnsw_field_meta)));
json_info->prefixes.prefixes.emplace_back("idxtestjson");
map.emplace("jsontest", std::move(json_info));
@@ -280,3 +285,42 @@ TEST_F(IndexerTest, JsonTagBuildIndex) {
ASSERT_EQ(val, "");
}
}
+
+TEST_F(IndexerTest, JsonHnswVector) {
+ redis::Json db(storage_.get(), ns);
+ auto cfhandler = storage_->GetCFHandle(ColumnFamilyID::Search);
+
+ {
+ auto s = indexer.Record("no_exist", ns);
+ ASSERT_TRUE(s.Is<Status::NoPrefixMatched>());
+ }
+
+ auto key3 = "idxtestjson:k3";
+ auto idxname = "jsontest";
+
+ {
+ auto s = indexer.Record(key3, ns);
+ ASSERT_TRUE(s);
+ ASSERT_EQ(s->updater.info->name, idxname);
+ ASSERT_TRUE(s->fields.empty());
+
+ auto s_set = db.Set(key3, "$", R"({"z": [1,2,3]})");
+ ASSERT_TRUE(s_set.ok());
+
+ auto s2 = indexer.Update(*s);
+ EXPECT_EQ(s2.Msg(), Status::ok_msg);
+
+ auto search_key = redis::SearchKey(ns, idxname,
"$.z").ConstructHnswNode(0, key3);
+
+ std::string val;
+ auto s3 = storage_->Get(storage_->DefaultMultiGetOptions(), cfhandler,
search_key, &val);
+ ASSERT_TRUE(s3.ok());
+
+ redis::HnswNodeFieldMetadata node_meta;
+ Slice input(val);
+ node_meta.Decode(&input);
+ EXPECT_EQ(node_meta.num_neighbours, 0);
+ std::vector<double> expected = {1, 2, 3};
+ EXPECT_EQ(expected, node_meta.vector);
+ }
+}
diff --git a/tests/cppunit/plan_executor_test.cc
b/tests/cppunit/plan_executor_test.cc
index 00e0c162..1b80329d 100644
--- a/tests/cppunit/plan_executor_test.cc
+++ b/tests/cppunit/plan_executor_test.cc
@@ -42,6 +42,13 @@ static IndexMap MakeIndexMap() {
auto f1 = FieldInfo("f1", std::make_unique<redis::TagFieldMetadata>());
auto f2 = FieldInfo("f2", std::make_unique<redis::NumericFieldMetadata>());
auto f3 = FieldInfo("f3", std::make_unique<redis::NumericFieldMetadata>());
+
+ auto hnsw_field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
+ hnsw_field_meta->vector_type = redis::VectorType::FLOAT64;
+ hnsw_field_meta->dim = 3;
+ hnsw_field_meta->distance_metric = redis::DistanceMetric::L2;
+ auto f4 = FieldInfo("f4", std::move(hnsw_field_meta));
+
auto ia = std::make_unique<IndexInfo>("ia", redis::IndexMetadata(),
"search_ns");
ia->metadata.on_data_type = redis::IndexOnDataType::JSON;
ia->prefixes.prefixes.emplace_back("test2:");
@@ -49,6 +56,7 @@ static IndexMap MakeIndexMap() {
ia->Add(std::move(f1));
ia->Add(std::move(f2));
ia->Add(std::move(f3));
+ ia->Add(std::move(f4));
IndexMap res;
res.Insert(std::move(ia));
@@ -416,4 +424,94 @@ TEST_F(PlanExecutorTestC, TagFieldScan) {
ASSERT_EQ(NextRow(ctx).key, "test2:e");
ASSERT_EQ(ctx.Next().GetValue(), exe_end);
}
-}
\ No newline at end of file
+}
+
+TEST_F(PlanExecutorTestC, HnswVectorFieldScans) {
+ redis::GlobalIndexer indexer(storage_.get());
+ indexer.Add(redis::IndexUpdater(IndexI()));
+
+ {
+ auto updates = ScopedUpdates(indexer,
+ {"test2:a", "test2:b", "test2:c", "test2:d",
"test2:e", "test2:f", "test2:g",
+ "test2:h", "test2:i", "test2:j", "test2:k",
"test2:l", "test2:m", "test2:n"},
+ "search_ns");
+ json_->Set("test2:a", "$", "{\"f4\": [1,2,3]}");
+ json_->Set("test2:b", "$", "{\"f4\": [4,5,6]}");
+ json_->Set("test2:c", "$", "{\"f4\": [7,8,9]}");
+ json_->Set("test2:d", "$", "{\"f4\": [10,11,12]}");
+ json_->Set("test2:e", "$", "{\"f4\": [13,14,15]}");
+ json_->Set("test2:f", "$", "{\"f4\": [23,24,25]}");
+ json_->Set("test2:g", "$", "{\"f4\": [26,27,28]}");
+ json_->Set("test2:h", "$", "{\"f4\": [77,78,79]}");
+ json_->Set("test2:i", "$", "{\"f4\": [80,81,82]}");
+ json_->Set("test2:j", "$", "{\"f4\": [83,84,85]}");
+ json_->Set("test2:k", "$", "{\"f4\": [86,87,88]}");
+ json_->Set("test2:l", "$", "{\"f4\": [89,90,91]}");
+ json_->Set("test2:m", "$", "{\"f4\": [1026,1027,1028]}");
+ json_->Set("test2:n", "$", "{\"f4\": [2226,2227,2228]}");
+ }
+
+ {
+ std::vector<double> target_vector = {14, 15, 16};
+ auto op =
+
std::make_unique<HnswVectorFieldKnnScan>(std::make_unique<FieldRef>("f4",
FieldI("f4")), target_vector, 5);
+
+ auto ctx = ExecutorContext(op.get(), storage_.get());
+ ASSERT_EQ(NextRow(ctx).key, "test2:e");
+ ASSERT_EQ(NextRow(ctx).key, "test2:d");
+ ASSERT_EQ(NextRow(ctx).key, "test2:c");
+ ASSERT_EQ(NextRow(ctx).key, "test2:f");
+ ASSERT_EQ(NextRow(ctx).key, "test2:b");
+ ASSERT_EQ(ctx.Next().GetValue(), exe_end);
+ }
+
+ {
+ std::vector<double> target_vector = {24, 25, 26};
+ auto op =
+
std::make_unique<HnswVectorFieldKnnScan>(std::make_unique<FieldRef>("f4",
FieldI("f4")), target_vector, 3);
+
+ auto ctx = ExecutorContext(op.get(), storage_.get());
+ ASSERT_EQ(NextRow(ctx).key, "test2:f");
+ ASSERT_EQ(NextRow(ctx).key, "test2:g");
+ ASSERT_EQ(NextRow(ctx).key, "test2:e");
+ ASSERT_EQ(ctx.Next().GetValue(), exe_end);
+ }
+
+ {
+ std::vector<double> query_vector = {11, 12, 13};
+ auto op =
+
std::make_unique<HnswVectorFieldRangeScan>(std::make_unique<FieldRef>("f4",
FieldI("f4")), query_vector, 25);
+
+ auto ctx = ExecutorContext(op.get(), storage_.get());
+ ASSERT_EQ(NextRow(ctx).key, "test2:d");
+ ASSERT_EQ(NextRow(ctx).key, "test2:e");
+ ASSERT_EQ(NextRow(ctx).key, "test2:c");
+ ASSERT_EQ(NextRow(ctx).key, "test2:b");
+ ASSERT_EQ(NextRow(ctx).key, "test2:a");
+ ASSERT_EQ(NextRow(ctx).key, "test2:f");
+ ASSERT_EQ(ctx.Next().GetValue(), exe_end);
+ }
+
+ {
+ std::vector<double> query_vector = {12, 13, 14};
+ auto op =
+
std::make_unique<HnswVectorFieldRangeScan>(std::make_unique<FieldRef>("f4",
FieldI("f4")), query_vector, 5000);
+
+ auto ctx = ExecutorContext(op.get(), storage_.get());
+ ASSERT_EQ(NextRow(ctx).key, "test2:e");
+ ASSERT_EQ(NextRow(ctx).key, "test2:d");
+ ASSERT_EQ(NextRow(ctx).key, "test2:c");
+ ASSERT_EQ(NextRow(ctx).key, "test2:b");
+ ASSERT_EQ(NextRow(ctx).key, "test2:a");
+ ASSERT_EQ(NextRow(ctx).key, "test2:f");
+ ASSERT_EQ(NextRow(ctx).key, "test2:g");
+ ASSERT_EQ(NextRow(ctx).key, "test2:h");
+ ASSERT_EQ(NextRow(ctx).key, "test2:i");
+ ASSERT_EQ(NextRow(ctx).key, "test2:j");
+ ASSERT_EQ(NextRow(ctx).key, "test2:k");
+ ASSERT_EQ(NextRow(ctx).key, "test2:l");
+ ASSERT_EQ(NextRow(ctx).key, "test2:m");
+ ASSERT_EQ(NextRow(ctx).key, "test2:n");
+ ASSERT_EQ(ctx.Next().GetValue(), exe_end);
+ }
+}