This is an automated email from the ASF dual-hosted git repository.
beihao pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new b8cf118b fix(test): use fixed seed to improve reproducibility. (#2557)
b8cf118b is described below
commit b8cf118bbe56984c7112c0929402f79d81d8a0d7
Author: Edward Xu <[email protected]>
AuthorDate: Fri Sep 27 18:35:30 2024 +0800
fix(test): use fixed seed to improve reproducibility. (#2557)
---
src/search/hnsw_indexer.cc | 9 ++++-----
src/search/hnsw_indexer.h | 3 ++-
tests/cppunit/hnsw_index_test.cc | 3 ++-
3 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/src/search/hnsw_indexer.cc b/src/search/hnsw_indexer.cc
index bb2918d5..de0f519a 100644
--- a/src/search/hnsw_indexer.cc
+++ b/src/search/hnsw_indexer.cc
@@ -172,14 +172,13 @@ StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem& rig
}
}
-HnswIndex::HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage)
+HnswIndex::HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage,
+ std::random_device::result_type seed)
: search_key(search_key),
metadata(vector),
storage(storage),
- m_level_normalization_factor(1.0 / std::log(metadata->m)) {
- std::random_device rand_dev;
- generator = std::mt19937(rand_dev());
-}
+ generator(std::mt19937(seed)),
+ m_level_normalization_factor(1.0 / std::log(metadata->m)) {}
uint16_t HnswIndex::RandomizeLayer() {
std::uniform_real_distribution<double> level_dist(0.0, 1.0);
diff --git a/src/search/hnsw_indexer.h b/src/search/hnsw_indexer.h
index cfe352ff..579352a8 100644
--- a/src/search/hnsw_indexer.h
+++ b/src/search/hnsw_indexer.h
@@ -92,7 +92,8 @@ struct HnswIndex {
std::mt19937 generator;
double m_level_normalization_factor;
- HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage);
+ HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage,
+ std::random_device::result_type seed = std::random_device()());
static StatusOr<std::vector<VectorItem>> DecodeNodesToVectorItems(engine::Context& ctx,
const std::vector<NodeKey>& node_key,
diff --git a/tests/cppunit/hnsw_index_test.cc b/tests/cppunit/hnsw_index_test.cc
index 332c1582..022f2a73 100644
--- a/tests/cppunit/hnsw_index_test.cc
+++ b/tests/cppunit/hnsw_index_test.cc
@@ -66,6 +66,7 @@ struct HnswIndexTest : TestBase {
std::string idx_name = "hnsw_test_idx";
std::string key = "vector";
std::unique_ptr<redis::HnswIndex> hnsw_index;
+ const std::random_device::result_type seed = 14863; // fixed seed for reproducibility
HnswIndexTest() {
metadata.vector_type = redis::VectorType::FLOAT64;
@@ -73,7 +74,7 @@ struct HnswIndexTest : TestBase {
metadata.m = 3;
metadata.distance_metric = redis::DistanceMetric::L2;
auto search_key = redis::SearchKey(ns, idx_name, key);
- hnsw_index = std::make_unique<redis::HnswIndex>(search_key, &metadata, storage_.get());
+ hnsw_index = std::make_unique<redis::HnswIndex>(search_key, &metadata, storage_.get(), seed);
}
void TearDown() override { hnsw_index.reset(); }