This is an automated email from the ASF dual-hosted git repository.

beihao pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new b8cf118b fix(test): use fixed seed to improve reproducibility. (#2557)
b8cf118b is described below

commit b8cf118bbe56984c7112c0929402f79d81d8a0d7
Author: Edward Xu <[email protected]>
AuthorDate: Fri Sep 27 18:35:30 2024 +0800

    fix(test): use fixed seed to improve reproducibility. (#2557)
---
 src/search/hnsw_indexer.cc       | 9 ++++-----
 src/search/hnsw_indexer.h        | 3 ++-
 tests/cppunit/hnsw_index_test.cc | 3 ++-
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/search/hnsw_indexer.cc b/src/search/hnsw_indexer.cc
index bb2918d5..de0f519a 100644
--- a/src/search/hnsw_indexer.cc
+++ b/src/search/hnsw_indexer.cc
@@ -172,14 +172,13 @@ StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem& rig
   }
 }
 
-HnswIndex::HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage)
+HnswIndex::HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage,
+                     std::random_device::result_type seed)
     : search_key(search_key),
       metadata(vector),
       storage(storage),
-      m_level_normalization_factor(1.0 / std::log(metadata->m)) {
-  std::random_device rand_dev;
-  generator = std::mt19937(rand_dev());
-}
+      generator(std::mt19937(seed)),
+      m_level_normalization_factor(1.0 / std::log(metadata->m)) {}
 
 uint16_t HnswIndex::RandomizeLayer() {
   std::uniform_real_distribution<double> level_dist(0.0, 1.0);
diff --git a/src/search/hnsw_indexer.h b/src/search/hnsw_indexer.h
index cfe352ff..579352a8 100644
--- a/src/search/hnsw_indexer.h
+++ b/src/search/hnsw_indexer.h
@@ -92,7 +92,8 @@ struct HnswIndex {
   std::mt19937 generator;
   double m_level_normalization_factor;
 
-  HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage);
+  HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage,
+            std::random_device::result_type seed = std::random_device()());
 
   static StatusOr<std::vector<VectorItem>> DecodeNodesToVectorItems(engine::Context& ctx,
                                                                     const std::vector<NodeKey>& node_key,
diff --git a/tests/cppunit/hnsw_index_test.cc b/tests/cppunit/hnsw_index_test.cc
index 332c1582..022f2a73 100644
--- a/tests/cppunit/hnsw_index_test.cc
+++ b/tests/cppunit/hnsw_index_test.cc
@@ -66,6 +66,7 @@ struct HnswIndexTest : TestBase {
   std::string idx_name = "hnsw_test_idx";
   std::string key = "vector";
   std::unique_ptr<redis::HnswIndex> hnsw_index;
+  const std::random_device::result_type seed = 14863;  // fixed seed for reproducibility
 
   HnswIndexTest() {
     metadata.vector_type = redis::VectorType::FLOAT64;
@@ -73,7 +74,7 @@ struct HnswIndexTest : TestBase {
     metadata.m = 3;
     metadata.distance_metric = redis::DistanceMetric::L2;
     auto search_key = redis::SearchKey(ns, idx_name, key);
-    hnsw_index = std::make_unique<redis::HnswIndex>(search_key, &metadata, storage_.get());
+    hnsw_index = std::make_unique<redis::HnswIndex>(search_key, &metadata, storage_.get(), seed);
   }
 
   void TearDown() override { hnsw_index.reset(); }

Reply via email to