This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new 8e170e7 Add IVF_HNSW_FLAT index and disk reader (#19)
8e170e7 is described below
commit 8e170e772625be0dae2eba7418589928f7cc8641
Author: Jingsong Lee <[email protected]>
AuthorDate: Tue Jun 9 14:06:04 2026 +0800
Add IVF_HNSW_FLAT index and disk reader (#19)
---
core/benches/recall_bench.rs | 70 +-
core/src/distance.rs | 29 +
core/src/hnsw.rs | 767 ++++++++++++++++++++++
core/src/ivfflat.rs | 10 +-
core/src/ivfhnswflat.rs | 355 ++++++++++
core/src/ivfhnswflat_io.rs | 1460 ++++++++++++++++++++++++++++++++++++++++++
core/src/lib.rs | 4 +
core/src/topk.rs | 90 +++
8 files changed, 2773 insertions(+), 12 deletions(-)
diff --git a/core/benches/recall_bench.rs b/core/benches/recall_bench.rs
index 03415be..1755402 100644
--- a/core/benches/recall_bench.rs
+++ b/core/benches/recall_bench.rs
@@ -1,5 +1,7 @@
use paimon_vindex_core::distance::{fvec_distance, MetricType};
+use paimon_vindex_core::hnsw::HnswBuildParams;
use paimon_vindex_core::ivfflat::IVFFlatIndex;
+use paimon_vindex_core::ivfhnswflat::IVFHNSWFlatIndex;
use paimon_vindex_core::ivfpq::IVFPQIndex;
use std::collections::HashSet;
use std::time::Instant;
@@ -14,6 +16,8 @@ fn main() {
nlist: 64,
pq_m: 8,
nprobes: &[1, 4, 8, 16, 32, 64],
+ hnsw_build_ef: 80,
+ hnsw_search_efs: &[80],
});
println!();
@@ -27,6 +31,8 @@ fn main() {
nlist: 8,
pq_m: 8,
nprobes: &[1, 2, 4, 8],
+ hnsw_build_ef: 200,
+ hnsw_search_efs: &[80, 160, 320],
});
}
@@ -39,6 +45,8 @@ struct Scenario<'a> {
nlist: usize,
pq_m: usize,
nprobes: &'a [usize],
+ hnsw_build_ef: usize,
+ hnsw_search_efs: &'a [usize],
}
fn run_scenario(s: Scenario<'_>) {
@@ -75,9 +83,28 @@ fn run_scenario(s: Scenario<'_>) {
ivfflat.add(&data, &ids, s.n);
println!("build IVF-FLAT: {:.2}s", start.elapsed().as_secs_f64());
+ let start = Instant::now();
+ let mut ivfhnswflat = IVFHNSWFlatIndex::new(
+ s.d,
+ s.nlist,
+ MetricType::L2,
+ HnswBuildParams {
+ m: 16,
+ ef_construction: s.hnsw_build_ef,
+ max_level: 7,
+ },
+ );
+ ivfhnswflat.train(&data, s.n);
+ ivfhnswflat.add(&data, &ids, s.n);
+ ivfhnswflat.build_graphs().unwrap();
+ println!("build IVF-HNSW-FLAT: {:.2}s", start.elapsed().as_secs_f64());
+
println!();
- println!("index nprobe recall@{} query_ms us/query", s.k);
- println!("--------- ------ --------- -------- --------");
+ println!(
+ "index nprobe ef recall@{} query_ms us/query",
+ s.k
+ );
+ println!("--------- ------ ------ --------- -------- --------");
for &nprobe in s.nprobes {
let mut distances = vec![0.0f32; s.nq * s.k];
@@ -88,6 +115,7 @@ fn run_scenario(s: Scenario<'_>) {
print_row(
"IVF-PQ",
nprobe,
+ None,
recall_at_k(&labels, &ground_truth, s.nq, s.k),
elapsed,
s.nq,
@@ -101,19 +129,53 @@ fn run_scenario(s: Scenario<'_>) {
print_row(
"IVF-FLAT",
nprobe,
+ None,
recall_at_k(&labels, &ground_truth, s.nq, s.k),
elapsed,
s.nq,
);
+
+ for &ef_search in s.hnsw_search_efs {
+ let mut distances = vec![0.0f32; s.nq * s.k];
+ let mut labels = vec![0i64; s.nq * s.k];
+ let start = Instant::now();
+ ivfhnswflat.search(
+ queries,
+ s.nq,
+ s.k,
+ nprobe,
+ ef_search,
+ &mut distances,
+ &mut labels,
+ );
+ let elapsed = start.elapsed();
+ print_row(
+ "IVF-HNSW",
+ nprobe,
+ Some(ef_search),
+ recall_at_k(&labels, &ground_truth, s.nq, s.k),
+ elapsed,
+ s.nq,
+ );
+ }
}
}
-fn print_row(index: &str, nprobe: usize, recall: f64, elapsed:
std::time::Duration, nq: usize) {
+fn print_row(
+ index: &str,
+ nprobe: usize,
+ ef: Option<usize>,
+ recall: f64,
+ elapsed: std::time::Duration,
+ nq: usize,
+) {
let ms = elapsed.as_secs_f64() * 1000.0;
+ let ef = ef.map(|v| v.to_string()).unwrap_or_else(|| "-".to_string());
println!(
- "{:<9} {:>6} {:>8.2}% {:>8.2} {:>8.1}",
+ "{:<9} {:>6} {:>6} {:>8.2}% {:>8.2} {:>8.1}",
index,
nprobe,
+ ef,
recall * 100.0,
ms,
ms * 1000.0 / nq as f64
diff --git a/core/src/distance.rs b/core/src/distance.rs
index 6c86d3c..4cbadd6 100644
--- a/core/src/distance.rs
+++ b/core/src/distance.rs
@@ -105,6 +105,35 @@ pub fn fvec_distance(query: &[f32], vector: &[f32],
metric: MetricType) -> f32 {
}
}
+pub fn preprocess_vectors(data: &[f32], n: usize, d: usize, metric:
MetricType) -> Vec<f32> {
+ let mut processed = data[..n * d].to_vec();
+ if metric == MetricType::Cosine {
+ for i in 0..n {
+ fvec_normalize(&mut processed[i * d..(i + 1) * d]);
+ }
+ }
+ processed
+}
+
+#[cfg(test)]
+mod preprocess_tests {
+ use super::*;
+
+ #[test]
+ fn test_preprocess_vectors_normalizes_cosine_only() {
+ let data = vec![3.0, 4.0, 1.0, 2.0];
+
+ assert_eq!(
+ preprocess_vectors(&data, 1, 2, MetricType::L2),
+ vec![3.0, 4.0]
+ );
+ assert_eq!(
+ preprocess_vectors(&data, 1, 2, MetricType::Cosine),
+ vec![0.6, 0.8]
+ );
+ }
+}
+
/// Compute result[i] = a[i] + bf * b[i]. Used for precomputed table merging.
/// Aligned with Faiss's fvec_madd.
pub fn fvec_madd(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
diff --git a/core/src/hnsw.rs b/core/src/hnsw.rs
new file mode 100644
index 0000000..4e4683a
--- /dev/null
+++ b/core/src/hnsw.rs
@@ -0,0 +1,767 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{fvec_distance, MetricType};
+use std::cmp::Reverse;
+use std::collections::BinaryHeap;
+use std::io;
+
+/// Tuning knobs used while constructing an HNSW graph.
+#[derive(Debug, Clone, Copy)]
+pub struct HnswBuildParams {
+    /// Neighbor budget per node on upper layers; layer 0 keeps up to `2 * m`.
+    pub m: usize,
+    /// Size of the candidate list explored per layer while inserting a node.
+    pub ef_construction: usize,
+    /// Exclusive upper bound on the level a node can be assigned.
+    pub max_level: usize,
+}
+
+impl Default for HnswBuildParams {
+    /// Returns the stock configuration: `m = 20`, `ef_construction = 150`,
+    /// `max_level = 7`.
+    fn default() -> Self {
+        let (m, ef_construction, max_level) = (20, 150, 7);
+        HnswBuildParams {
+            m,
+            ef_construction,
+            max_level,
+        }
+    }
+}
+
+impl HnswBuildParams {
+    /// Returns a copy in which every parameter is raised to at least 1.
+    pub fn sanitized(self) -> Self {
+        let at_least_one = |value: usize| value.max(1);
+        let HnswBuildParams {
+            m,
+            ef_construction,
+            max_level,
+        } = self;
+        HnswBuildParams {
+            m: at_least_one(m),
+            ef_construction: at_least_one(ef_construction),
+            max_level: at_least_one(max_level),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct HnswGraph {
+ d: usize,
+ metric: MetricType,
+ vectors: Vec<f32>,
+ levels: Vec<usize>,
+ neighbors: Vec<Vec<Vec<usize>>>,
+ entry_point: usize,
+ max_observed_level: usize,
+ params: HnswBuildParams,
+}
+
+impl HnswGraph {
+ pub fn build(
+ vectors: &[f32],
+ n: usize,
+ d: usize,
+ metric: MetricType,
+ params: HnswBuildParams,
+ ) -> io::Result<Self> {
+ let expected_len = n.checked_mul(d).ok_or_else(|| {
+ io::Error::new(io::ErrorKind::InvalidInput, "n * dimension
overflows usize")
+ })?;
+ if vectors.len() < expected_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!(
+ "vector data length {} is shorter than n*d {}",
+ vectors.len(),
+ expected_len
+ ),
+ ));
+ }
+
+ let params = params.sanitized();
+ let mut graph = HnswGraph {
+ d,
+ metric,
+ vectors: vectors[..n * d].to_vec(),
+ levels: Vec::with_capacity(n),
+ neighbors: Vec::with_capacity(n),
+ entry_point: 0,
+ max_observed_level: 0,
+ params,
+ };
+
+ for node in 0..n {
+ graph.insert(node);
+ }
+ Ok(graph)
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub(crate) fn from_parts(
+ vectors: Vec<f32>,
+ n: usize,
+ d: usize,
+ metric: MetricType,
+ levels: Vec<usize>,
+ neighbors: Vec<Vec<Vec<usize>>>,
+ entry_point: usize,
+ max_observed_level: usize,
+ params: HnswBuildParams,
+ ) -> io::Result<Self> {
+ let expected_len = n.checked_mul(d).ok_or_else(|| {
+ io::Error::new(io::ErrorKind::InvalidInput, "n * dimension
overflows usize")
+ })?;
+ if vectors.len() != expected_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "graph vector length {} does not match n*d {}",
+ vectors.len(),
+ expected_len
+ ),
+ ));
+ }
+ if levels.len() != n || neighbors.len() != n {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "graph level metadata does not match vector count",
+ ));
+ }
+ if n == 0 {
+ return Ok(Self {
+ d,
+ metric,
+ vectors,
+ levels,
+ neighbors,
+ entry_point: 0,
+ max_observed_level: 0,
+ params: params.sanitized(),
+ });
+ }
+ if entry_point >= n {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("graph entry point {} out of range {}", entry_point,
n),
+ ));
+ }
+ let observed = levels.iter().copied().max().unwrap_or(0);
+ if max_observed_level != observed {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "graph max level {} does not match observed {}",
+ max_observed_level, observed
+ ),
+ ));
+ }
+ if levels[entry_point] < max_observed_level {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "graph entry point does not reach max observed level",
+ ));
+ }
+ for node in 0..n {
+ if neighbors[node].len() != levels[node] + 1 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("graph node {} has invalid level adjacency", node),
+ ));
+ }
+ for (level, level_neighbors) in neighbors[node].iter().enumerate()
{
+ for &neighbor in level_neighbors {
+ if neighbor >= n || levels[neighbor] < level {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "graph edge {} -> {} at level {} is invalid",
+ node, neighbor, level
+ ),
+ ));
+ }
+ }
+ }
+ }
+ Ok(Self {
+ d,
+ metric,
+ vectors,
+ levels,
+ neighbors,
+ entry_point,
+ max_observed_level,
+ params: params.sanitized(),
+ })
+ }
+
+ pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<(usize,
f32)> {
+ if self.levels.is_empty() || k == 0 {
+ return Vec::new();
+ }
+
+ let mut ep = self.entry_point;
+ let mut ep_dist = self.distance_to_query(query, ep);
+ for level in (1..=self.max_observed_level).rev() {
+ let (next, dist) = self.greedy_search_query(query, ep, ep_dist,
level);
+ ep = next;
+ ep_dist = dist;
+ }
+
+ let mut visited = vec![0usize; self.levels.len()];
+ let candidates = self.search_layer_query(query, ep, ef.max(k), 0, &mut
visited, 1);
+ candidates
+ .into_iter()
+ .take(k)
+ .map(|n| (n.id, n.dist))
+ .collect()
+ }
+
+ pub fn len(&self) -> usize {
+ self.levels.len()
+ }
+
+ pub fn is_empty(&self) -> bool {
+ self.levels.is_empty()
+ }
+
+ pub fn max_degree(&self) -> usize {
+ self.neighbors
+ .iter()
+ .flat_map(|levels| levels.iter().map(Vec::len))
+ .max()
+ .unwrap_or(0)
+ }
+
+ pub(crate) fn vectors(&self) -> &[f32] {
+ &self.vectors
+ }
+
+ pub(crate) fn levels(&self) -> &[usize] {
+ &self.levels
+ }
+
+ pub(crate) fn neighbors(&self) -> &[Vec<Vec<usize>>] {
+ &self.neighbors
+ }
+
+ pub(crate) fn entry_point(&self) -> usize {
+ self.entry_point
+ }
+
+ pub(crate) fn max_observed_level(&self) -> usize {
+ self.max_observed_level
+ }
+
+ fn insert(&mut self, node: usize) {
+ let level = random_level(node, self.params.m, self.params.max_level);
+ self.levels.push(level);
+ self.neighbors.push(vec![Vec::new(); level + 1]);
+
+ if node == 0 {
+ self.entry_point = 0;
+ self.max_observed_level = level;
+ return;
+ }
+
+ let mut ep = self.entry_point;
+ let mut ep_dist = self.distance_between(node, ep);
+
+ for layer in ((level + 1)..=self.max_observed_level).rev() {
+ let (next, dist) = self.greedy_search_node(node, ep, ep_dist,
layer);
+ ep = next;
+ ep_dist = dist;
+ }
+
+ let mut visited = vec![0usize; self.levels.len()];
+ let mut visit_mark = 1usize;
+ for layer in (0..=level.min(self.max_observed_level)).rev() {
+ let candidates = self.search_layer_node(
+ node,
+ ep,
+ self.params.ef_construction,
+ layer,
+ &mut visited,
+ visit_mark,
+ );
+ visit_mark = advance_visit_mark(&mut visited, visit_mark);
+ let selected = self.select_neighbors(candidates,
self.max_neighbors(layer));
+ for neighbor in selected {
+ self.connect(node, neighbor.id, layer);
+ }
+ if let Some(best) = self.neighbors[node][layer]
+ .iter()
+ .copied()
+ .min_by(|&a, &b| {
+ self.distance_between(node, a)
+ .total_cmp(&self.distance_between(node, b))
+ })
+ {
+ ep = best;
+ }
+ }
+
+ if level > self.max_observed_level {
+ self.entry_point = node;
+ self.max_observed_level = level;
+ }
+ }
+
+ fn connect(&mut self, a: usize, b: usize, level: usize) {
+ if !self.neighbors[a][level].contains(&b) {
+ self.neighbors[a][level].push(b);
+ }
+ if level < self.neighbors[b].len() &&
!self.neighbors[b][level].contains(&a) {
+ self.neighbors[b][level].push(a);
+ let pruned = self.pruned_neighbors(b, level,
self.max_neighbors(level));
+ self.neighbors[b][level] = pruned;
+ }
+ let pruned = self.pruned_neighbors(a, level,
self.max_neighbors(level));
+ self.neighbors[a][level] = pruned;
+ }
+
+ fn pruned_neighbors(&self, node: usize, level: usize, max_neighbors:
usize) -> Vec<usize> {
+ let mut ranked: Vec<ScoredNode> = self.neighbors[node][level]
+ .iter()
+ .map(|&id| ScoredNode {
+ id,
+ dist: self.distance_between(node, id),
+ })
+ .collect();
+ ranked.sort_by(|a, b| a.dist.total_cmp(&b.dist));
+ ranked
+ .into_iter()
+ .take(max_neighbors)
+ .map(|node| node.id)
+ .collect()
+ }
+
+    /// Picks up to `max_neighbors` entries out of the scored `candidates`.
+    ///
+    /// Candidates are visited in ascending `dist` order. One is accepted right
+    /// away unless it is closer to an already accepted entry than its own
+    /// `dist`; such candidates are set aside and only used afterwards, in the
+    /// same order, to fill whatever slots remain.
+    fn select_neighbors(
+        &self,
+        mut candidates: Vec<ScoredNode>,
+        max_neighbors: usize,
+    ) -> Vec<ScoredNode> {
+        candidates.sort_by(|lhs, rhs| lhs.dist.total_cmp(&rhs.dist));
+
+        let mut chosen: Vec<ScoredNode> = Vec::with_capacity(max_neighbors);
+        let mut deferred: Vec<ScoredNode> = Vec::new();
+
+        // First pass: diversified selection, stopping once the budget is used up.
+        let mut remaining = candidates.into_iter();
+        while chosen.len() < max_neighbors {
+            let candidate = match remaining.next() {
+                Some(candidate) => candidate,
+                None => break,
+            };
+            let shadowed = chosen
+                .iter()
+                .any(|kept| self.distance_between(candidate.id, kept.id) < candidate.dist);
+            if shadowed {
+                deferred.push(candidate);
+            } else {
+                chosen.push(candidate);
+            }
+        }
+
+        // Second pass: backfill from the deferred candidates, never repeating an id.
+        let mut spare = deferred.into_iter();
+        while chosen.len() < max_neighbors {
+            match spare.next() {
+                Some(candidate) => {
+                    if chosen.iter().all(|kept| kept.id != candidate.id) {
+                        chosen.push(candidate);
+                    }
+                }
+                None => break,
+            }
+        }
+        chosen
+    }
+
+    /// Greedy descent on `level` towards an external `query` vector.
+    ///
+    /// Starts at `current`, whose distance to the query is `current_dist`, and
+    /// returns the node where the descent stops together with its distance.
+    fn greedy_search_query(
+        &self,
+        query: &[f32],
+        current: usize,
+        current_dist: f32,
+        level: usize,
+    ) -> (usize, f32) {
+        self.greedy_search(current, current_dist, level, |id| {
+            self.distance_to_query(query, id)
+        })
+    }
+
+    /// Greedy descent on `level` towards the stored vector of `node`.
+    ///
+    /// Starts at `current`, whose distance to `node` is `current_dist`, and
+    /// returns the node where the descent stops together with its distance.
+    fn greedy_search_node(
+        &self,
+        node: usize,
+        current: usize,
+        current_dist: f32,
+        level: usize,
+    ) -> (usize, f32) {
+        self.greedy_search(current, current_dist, level, |id| {
+            self.distance_between(node, id)
+        })
+    }
+
+    /// Shared greedy walk used by the query- and node-targeted variants.
+    ///
+    /// On each step every neighbor of the current node on `level` is scored with
+    /// `distance`, and the walk moves to the best strictly closer one. It stops
+    /// as soon as no neighbor improves on the current distance, so the distance
+    /// decreases on every move and the loop terminates.
+    fn greedy_search(
+        &self,
+        mut current: usize,
+        mut current_dist: f32,
+        level: usize,
+        mut distance: impl FnMut(usize) -> f32,
+    ) -> (usize, f32) {
+        loop {
+            let mut best = current;
+            let mut best_dist = current_dist;
+            for &neighbor in self.neighbors_at(current, level) {
+                let dist = distance(neighbor);
+                if dist < best_dist {
+                    best = neighbor;
+                    best_dist = dist;
+                }
+            }
+            if best == current {
+                return (current, current_dist);
+            }
+            current = best;
+            current_dist = best_dist;
+        }
+    }
+
+ fn search_layer_query(
+ &self,
+ query: &[f32],
+ entry: usize,
+ ef: usize,
+ level: usize,
+ visited: &mut [usize],
+ visit_mark: usize,
+ ) -> Vec<ScoredNode> {
+ self.search_layer(entry, ef, level, visited, visit_mark, |id| {
+ self.distance_to_query(query, id)
+ })
+ }
+
+ fn search_layer_node(
+ &self,
+ node: usize,
+ entry: usize,
+ ef: usize,
+ level: usize,
+ visited: &mut [usize],
+ visit_mark: usize,
+ ) -> Vec<ScoredNode> {
+ self.search_layer(entry, ef, level, visited, visit_mark, |id| {
+ self.distance_between(node, id)
+ })
+ }
+
+ fn search_layer(
+ &self,
+ entry: usize,
+ ef: usize,
+ level: usize,
+ visited: &mut [usize],
+ visit_mark: usize,
+ mut distance: impl FnMut(usize) -> f32,
+ ) -> Vec<ScoredNode> {
+ let entry_dist = distance(entry);
+ visited[entry] = visit_mark;
+
+ let mut candidates = BinaryHeap::new();
+ candidates.push(Reverse(HeapNode {
+ id: entry,
+ dist: entry_dist,
+ }));
+
+ let mut results = BinaryHeap::new();
+ results.push(HeapNode {
+ id: entry,
+ dist: entry_dist,
+ });
+
+ while let Some(Reverse(current)) = candidates.pop() {
+ let worst = results
+ .peek()
+ .map(|node| node.dist)
+ .unwrap_or(f32::INFINITY);
+ if current.dist > worst && results.len() >= ef {
+ break;
+ }
+
+ for &neighbor in self.neighbors_at(current.id, level) {
+ if visited[neighbor] == visit_mark {
+ continue;
+ }
+ visited[neighbor] = visit_mark;
+ let dist = distance(neighbor);
+ let worst = results
+ .peek()
+ .map(|node| node.dist)
+ .unwrap_or(f32::INFINITY);
+ if results.len() < ef || dist < worst {
+ candidates.push(Reverse(HeapNode { id: neighbor, dist }));
+ results.push(HeapNode { id: neighbor, dist });
+ if results.len() > ef {
+ results.pop();
+ }
+ }
+ }
+ }
+
+ let mut result: Vec<ScoredNode> = results
+ .into_iter()
+ .map(|node| ScoredNode {
+ id: node.id,
+ dist: node.dist,
+ })
+ .collect();
+ result.sort_by(|a, b| a.dist.total_cmp(&b.dist));
+ result
+ }
+
+ fn max_neighbors(&self, level: usize) -> usize {
+ if level == 0 {
+ self.params.m * 2
+ } else {
+ self.params.m
+ }
+ }
+
+ fn neighbors_at(&self, node: usize, level: usize) -> &[usize] {
+ self.neighbors
+ .get(node)
+ .and_then(|levels| levels.get(level))
+ .map(Vec::as_slice)
+ .unwrap_or(&[])
+ }
+
+ fn distance_between(&self, a: usize, b: usize) -> f32 {
+ let va = &self.vectors[a * self.d..(a + 1) * self.d];
+ let vb = &self.vectors[b * self.d..(b + 1) * self.d];
+ fvec_distance(va, vb, self.metric)
+ }
+
+ fn distance_to_query(&self, query: &[f32], id: usize) -> f32 {
+ let vector = &self.vectors[id * self.d..(id + 1) * self.d];
+ fvec_distance(query, vector, self.metric)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+struct ScoredNode {
+ id: usize,
+ dist: f32,
+}
+
+/// Heap entry pairing a node id with its distance to the search target.
+///
+/// Ordering and equality both look at `dist` only, compared with
+/// `f32::total_cmp`. Equality is defined through `cmp` so that `PartialEq`/`Eq`
+/// agree with `Ord` and stay reflexive for NaN distances; a derived
+/// `PartialEq` would compare `id` as well and treat NaN as unequal to itself.
+#[derive(Debug, Clone, Copy)]
+struct HeapNode {
+    id: usize,
+    dist: f32,
+}
+
+impl PartialEq for HeapNode {
+    fn eq(&self, other: &Self) -> bool {
+        self.cmp(other) == std::cmp::Ordering::Equal
+    }
+}
+
+impl Eq for HeapNode {}
+
+impl PartialOrd for HeapNode {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for HeapNode {
+    /// Total order on `dist`; `id` does not take part in the comparison.
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        self.dist.total_cmp(&other.dist)
+    }
+}
+
+/// Deterministically derives the graph level for the node at index `node`.
+///
+/// The level comes from a `splitmix64` stream seeded by the node index: each
+/// draw below `u64::MAX / max(m, 2)` promotes the node by one level, and the
+/// result always stays below `max_level`. Node 0 and graphs with
+/// `max_level <= 1` are pinned to level 0.
+fn random_level(node: usize, m: usize, max_level: usize) -> usize {
+    if node == 0 || max_level <= 1 {
+        // Keep the first insertion deterministic. Later higher-level nodes
+        // replace the entry point as they appear, while tiny lists naturally
+        // stay flat.
+        return 0;
+    }
+    // The seed only needs mixing, so wrap instead of overflowing: a plain `+`
+    // panics in overflow-checked builds once `node` is large enough.
+    let mut x = splitmix64((node as u64).wrapping_add(0x9E37_79B9_7F4A_7C15));
+    let mut level = 0;
+    let threshold = (u64::MAX / m.max(2) as u64).max(1);
+    while level + 1 < max_level && x < threshold {
+        level += 1;
+        x = splitmix64(x);
+    }
+    level
+}
+
+/// Returns the next visit mark to use with `visited`.
+///
+/// Normally this is `visit_mark + 1`. If that would overflow, `visited` is
+/// reset to all zeros and numbering restarts at 1 so old marks cannot collide.
+fn advance_visit_mark(visited: &mut [usize], visit_mark: usize) -> usize {
+    match visit_mark.checked_add(1) {
+        Some(next) => next,
+        None => {
+            visited.fill(0);
+            1
+        }
+    }
+}
+
+/// One step of the SplitMix64 mixer: advances `x` by the golden-ratio
+/// increment and returns the scrambled 64-bit output.
+fn splitmix64(mut x: u64) -> u64 {
+    x = x.wrapping_add(0x9E37_79B9_7F4A_7C15);
+    let mut z = x;
+    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
+    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
+    z ^ (z >> 31)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::distance::MetricType;
+
+ #[test]
+ fn test_hnsw_recalls_query_vector_on_single_partition() {
+ let d = 4;
+ let n = 128;
+ let data: Vec<f32> = (0..n)
+ .flat_map(|i| [i as f32 * 0.01, 1.0, 2.0, 3.0])
+ .collect();
+ let params = HnswBuildParams {
+ m: 8,
+ ef_construction: 32,
+ max_level: 6,
+ };
+
+ let graph = HnswGraph::build(&data, n, d, MetricType::L2,
params).unwrap();
+ let query_id = 17;
+ let results = graph.search(&data[query_id * d..(query_id + 1) * d], 5,
32);
+
+ assert_eq!(results[0].0, query_id);
+ assert_eq!(results[0].1, 0.0);
+ }
+
+ #[test]
+ fn test_hnsw_empty_graph_returns_no_results() {
+ let graph =
+ HnswGraph::build(&[], 0, 4, MetricType::L2,
HnswBuildParams::default()).unwrap();
+
+ assert!(graph.search(&[0.0, 0.0, 0.0, 0.0], 10, 20).is_empty());
+ assert!(graph.is_empty());
+ assert_eq!(graph.len(), 0);
+ assert_eq!(graph.max_degree(), 0);
+ }
+
+ #[test]
+ fn test_hnsw_build_rejects_short_vector_input() {
+ let err = HnswGraph::build(
+ &[0.0, 1.0, 2.0],
+ 2,
+ 2,
+ MetricType::L2,
+ HnswBuildParams::default(),
+ )
+ .unwrap_err();
+
+ assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+ assert!(err.to_string().contains("shorter than n*d"));
+ }
+
+ #[test]
+ fn test_hnsw_respects_neighbor_degree_bound() {
+ let d = 8;
+ let n = 512;
+ let data = generate_clustered_data(n, d, 16);
+ let params = HnswBuildParams {
+ m: 12,
+ ef_construction: 100,
+ max_level: 6,
+ };
+
+ let graph = HnswGraph::build(&data, n, d, MetricType::L2,
params).unwrap();
+
+ assert_eq!(graph.len(), n);
+ assert!(graph.max_degree() <= params.m * 2);
+ }
+
+ #[test]
+ fn test_hnsw_large_partition_recall_tracks_exact_search() {
+ let d = 16;
+ let n = 4096;
+ let nq = 32;
+ let k = 10;
+ let data = generate_clustered_data(n, d, 32);
+ let params = HnswBuildParams {
+ m: 16,
+ ef_construction: 200,
+ max_level: 7,
+ };
+
+ let graph = HnswGraph::build(&data, n, d, MetricType::L2,
params).unwrap();
+ let mut hits = 0usize;
+ for qi in 0..nq {
+ let query = &data[qi * d..(qi + 1) * d];
+ let expected = exact_topk(&data, n, d, query, k);
+ let actual = graph.search(query, k, 200);
+ hits += actual
+ .iter()
+ .filter(|(id, _)| expected.contains(id))
+ .count();
+ }
+
+ let recall = hits as f32 / (nq * k) as f32;
+ assert!(recall >= 0.95, "recall={}", recall);
+ }
+
+ #[test]
+ fn test_hnsw_neighbor_selection_backfills_after_diversification() {
+ let d = 1;
+ let data = vec![0.0, 1.0, 2.0, 3.0];
+ let graph = HnswGraph::build(
+ &data,
+ 4,
+ d,
+ MetricType::L2,
+ HnswBuildParams {
+ m: 2,
+ ef_construction: 4,
+ max_level: 1,
+ },
+ )
+ .unwrap();
+ let candidates = vec![
+ ScoredNode { id: 1, dist: 1.0 },
+ ScoredNode { id: 2, dist: 2.0 },
+ ScoredNode { id: 3, dist: 3.0 },
+ ];
+
+ let selected = graph.select_neighbors(candidates, 3);
+
+ assert_eq!(selected.len(), 3);
+ }
+
+ #[test]
+ fn test_hnsw_greedy_search_chooses_best_improving_neighbor() {
+ let graph = HnswGraph::from_parts(
+ vec![0.0, 5.0, 2.0],
+ 3,
+ 1,
+ MetricType::L2,
+ vec![0, 0, 0],
+ vec![vec![vec![1, 2]], vec![vec![]], vec![vec![]]],
+ 0,
+ 0,
+ HnswBuildParams::default(),
+ )
+ .unwrap();
+
+ let (next, dist) = graph.greedy_search_query(&[2.0], 0, 4.0, 0);
+
+ assert_eq!(next, 2);
+ assert_eq!(dist, 0.0);
+ }
+
+ fn exact_topk(data: &[f32], n: usize, d: usize, query: &[f32], k: usize)
-> Vec<usize> {
+ let mut distances: Vec<(f32, usize)> = (0..n)
+ .map(|i| {
+ let vector = &data[i * d..(i + 1) * d];
+ (fvec_distance(query, vector, MetricType::L2), i)
+ })
+ .collect();
+ distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
+ distances[..k].iter().map(|&(_, id)| id).collect()
+ }
+
+ fn generate_clustered_data(n: usize, d: usize, num_clusters: usize) ->
Vec<f32> {
+ let mut data = vec![0.0f32; n * d];
+ for i in 0..n {
+ let cluster = i % num_clusters;
+ for j in 0..d {
+ data[i * d + j] = cluster as f32 * 20.0 + j as f32 * 0.01 + i
as f32 * 0.0001;
+ }
+ }
+ data
+ }
+}
diff --git a/core/src/ivfflat.rs b/core/src/ivfflat.rs
index 185bdf1..601bb3d 100644
--- a/core/src/ivfflat.rs
+++ b/core/src/ivfflat.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use crate::distance::{fvec_distance, fvec_normalize, MetricType};
+use crate::distance::{fvec_distance, preprocess_vectors, MetricType};
use crate::ivfpq::RowIdFilter;
use crate::kmeans::{self, KMeansConfig};
@@ -134,13 +134,7 @@ impl IVFFlatIndex {
}
pub(crate) fn preprocess_vectors(&self, data: &[f32], n: usize) ->
Vec<f32> {
- let mut processed = data[..n * self.d].to_vec();
- if self.metric == MetricType::Cosine {
- for i in 0..n {
- fvec_normalize(&mut processed[i * self.d..(i + 1) * self.d]);
- }
- }
- processed
+ preprocess_vectors(data, n, self.d, self.metric)
}
}
diff --git a/core/src/ivfhnswflat.rs b/core/src/ivfhnswflat.rs
new file mode 100644
index 0000000..e2f7ccf
--- /dev/null
+++ b/core/src/ivfhnswflat.rs
@@ -0,0 +1,355 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{fvec_distance, MetricType};
+use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::ivfflat::IVFFlatIndex;
+use crate::ivfpq::RowIdFilter;
+use crate::kmeans;
+use crate::topk::TopKHeap;
+use std::io;
+
+pub struct IVFHNSWFlatIndex {
+ /// Exposed to match the existing core index structs. Mutating `flat`
+ /// directly can leave `graphs` stale; call `build_graphs` again before
+ /// running an HNSW search.
+ pub flat: IVFFlatIndex,
+ pub graphs: Vec<Option<HnswGraph>>,
+ pub hnsw_params: HnswBuildParams,
+}
+
+impl IVFHNSWFlatIndex {
+ pub fn new(d: usize, nlist: usize, metric: MetricType, hnsw_params:
HnswBuildParams) -> Self {
+ IVFHNSWFlatIndex {
+ flat: IVFFlatIndex::new(d, nlist, metric),
+ graphs: vec![None; nlist],
+ hnsw_params,
+ }
+ }
+
+ pub fn train(&mut self, data: &[f32], n: usize) {
+ self.flat.train(data, n);
+ }
+
+ pub fn add(&mut self, data: &[f32], ids: &[i64], n: usize) {
+ self.flat.add(data, ids, n);
+ self.graphs.fill(None);
+ }
+
+ pub fn build_graphs(&mut self) -> io::Result<()> {
+ for list_id in 0..self.flat.nlist {
+ let count = self.flat.ids[list_id].len();
+ self.graphs[list_id] = if count == 0 {
+ None
+ } else {
+ Some(HnswGraph::build(
+ &self.flat.vectors[list_id],
+ count,
+ self.flat.d,
+ self.flat.metric,
+ self.hnsw_params,
+ )?)
+ };
+ }
+ Ok(())
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub fn search(
+ &self,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+ ef_search: usize,
+ result_distances: &mut [f32],
+ result_labels: &mut [i64],
+ ) {
+ self.search_with_filter(
+ queries,
+ nq,
+ k,
+ nprobe,
+ ef_search,
+ None,
+ result_distances,
+ result_labels,
+ );
+ }
+
+    /// Runs a top-`k` search for `nq` queries, optionally restricted by `filter`.
+    ///
+    /// For each query the inverted lists returned by `kmeans::find_topk_batch`
+    /// for `nprobe` are probed. A list is searched through its HNSW graph when
+    /// a graph exists and has the same size as the list; otherwise the list is
+    /// scanned exhaustively. With a filter, every probed list is scanned
+    /// exhaustively up front when at most `max(ef_search, k)` rows pass it, and
+    /// an exhaustive backfill runs when graph results leave fewer than `k`
+    /// matches.
+    ///
+    /// Results for query `qi` are written to positions `qi * k..(qi + 1) * k` of
+    /// `result_distances` and `result_labels`, nearest first. Unused slots are
+    /// set to `f32::MAX` and `-1`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if either output slice holds fewer than `nq * k` elements.
+    #[allow(clippy::too_many_arguments)]
+    pub fn search_with_filter(
+        &self,
+        queries: &[f32],
+        nq: usize,
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+        filter: Option<&dyn RowIdFilter>,
+        result_distances: &mut [f32],
+        result_labels: &mut [i64],
+    ) {
+        let d = self.flat.d;
+        let graph_k = ef_search.max(k);
+        let processed_queries = self.flat.preprocess_vectors(queries, nq);
+        let (all_probe_indices, _) = kmeans::find_topk_batch(
+            &processed_queries,
+            nq,
+            &self.flat.quantizer_centroids,
+            self.flat.nlist,
+            d,
+            nprobe,
+        );
+
+        for qi in 0..nq {
+            let query = &processed_queries[qi * d..(qi + 1) * d];
+            let probes = &all_probe_indices[qi];
+            let mut heap = TopKHeap::new(k);
+            // With a very selective filter the graph would mostly return rows
+            // that get filtered out, so score the few matching rows exactly.
+            let force_flat_scan = filter
+                .map(|f| self.count_filtered(probes, f) <= graph_k)
+                .unwrap_or(false);
+            let mut used_graph = false;
+
+            for &list_id in probes {
+                let list_ids = &self.flat.ids[list_id];
+                // A missing graph, or one whose size differs from the list
+                // (`flat` was mutated without `build_graphs`), cannot be
+                // trusted; an exhaustive scan stays correct in both cases.
+                let graph = if force_flat_scan {
+                    None
+                } else {
+                    self.graphs
+                        .get(list_id)
+                        .and_then(Option::as_ref)
+                        .filter(|graph| graph.len() == list_ids.len())
+                };
+                match graph {
+                    Some(graph) => {
+                        used_graph = true;
+                        for (local_id, dist) in graph.search(query, graph_k, graph_k) {
+                            let row_id = list_ids[local_id];
+                            if filter.map_or(true, |f| f.contains(row_id)) {
+                                heap.push(dist, row_id);
+                            }
+                        }
+                    }
+                    None => self.scan_flat_list(query, list_id, filter, &mut heap),
+                }
+            }
+
+            // Graph results are approximate, so filtering them can leave fewer
+            // than `k` matches; backfill with an exact scan. Skipped when every
+            // probed list was already scanned exhaustively, because a second
+            // pass over the same rows cannot add anything.
+            if filter.is_some() && used_graph && heap.len() < k {
+                for &list_id in probes {
+                    self.scan_flat_list(query, list_id, filter, &mut heap);
+                }
+            }
+
+            let sorted = heap.into_sorted();
+            let out_base = qi * k;
+            for (i, &(dist, id)) in sorted.iter().enumerate() {
+                result_distances[out_base + i] = dist;
+                result_labels[out_base + i] = id;
+            }
+            // Pad the tail so callers can tell missing results from real ones.
+            for i in sorted.len()..k {
+                result_distances[out_base + i] = f32::MAX;
+                result_labels[out_base + i] = -1;
+            }
+        }
+    }
+
+    /// Counts the row ids stored in the lists named by `probe_indices` that
+    /// are accepted by `filter`.
+    fn count_filtered(&self, probe_indices: &[usize], filter: &dyn RowIdFilter) -> usize {
+        probe_indices
+            .iter()
+            .flat_map(|&list_id| self.flat.ids[list_id].iter())
+            .filter(|&&row_id| filter.contains(row_id))
+            .count()
+    }
+
+    /// Scores every row of inverted list `list_id` against `query` and offers
+    /// it to `heap`, skipping rows rejected by `filter` when one is supplied.
+    fn scan_flat_list(
+        &self,
+        query: &[f32],
+        list_id: usize,
+        filter: Option<&dyn RowIdFilter>,
+        heap: &mut TopKHeap,
+    ) {
+        let d = self.flat.d;
+        for (local_id, &row_id) in self.flat.ids[list_id].iter().enumerate() {
+            let accepted = filter.map_or(true, |f| f.contains(row_id));
+            if !accepted {
+                continue;
+            }
+            // Rows are stored back to back, `d` floats each.
+            let vector = &self.flat.vectors[list_id][local_id * d..(local_id + 1) * d];
+            let dist = fvec_distance(query, vector, self.flat.metric);
+            heap.push(dist, row_id);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::distance::MetricType;
+ use crate::hnsw::HnswBuildParams;
+
+ #[test]
+ fn test_ivfhnswflat_recalls_query_vector() {
+ let d = 4;
+ let nlist = 4;
+ let n = 128;
+ let data: Vec<f32> = (0..n)
+ .flat_map(|i| {
+ let cluster = (i % nlist) as f32 * 100.0;
+ [cluster + i as f32 * 0.01, 1.0, 2.0, 3.0]
+ })
+ .collect();
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2,
HnswBuildParams::default());
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+ index.build_graphs().unwrap();
+
+ let query_id = 23;
+ let mut distances = vec![0.0; 5];
+ let mut labels = vec![0; 5];
+ index.search(
+ &data[query_id * d..(query_id + 1) * d],
+ 1,
+ 5,
+ nlist,
+ 32,
+ &mut distances,
+ &mut labels,
+ );
+
+ assert_eq!(labels[0], ids[query_id]);
+ assert_eq!(distances[0], 0.0);
+ }
+
+ #[test]
+ fn test_ivfhnswflat_without_built_graphs_falls_back_to_flat_scan() {
+ let d = 4;
+ let nlist = 4;
+ let n = 128;
+ let data: Vec<f32> = (0..n)
+ .flat_map(|i| {
+ let cluster = (i % nlist) as f32 * 100.0;
+ [cluster + i as f32 * 0.01, 1.0, 2.0, 3.0]
+ })
+ .collect();
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2,
HnswBuildParams::default());
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+
+ let query_id = 23;
+ let mut distances = vec![0.0; 5];
+ let mut labels = vec![0; 5];
+ index.search(
+ &data[query_id * d..(query_id + 1) * d],
+ 1,
+ 5,
+ nlist,
+ 32,
+ &mut distances,
+ &mut labels,
+ );
+
+ assert_eq!(labels[0], ids[query_id]);
+ assert_eq!(distances[0], 0.0);
+ }
+
+ #[test]
+ fn test_ivfhnswflat_selective_filter_uses_exact_results() {
+ use std::collections::HashSet;
+
+ let d = 2;
+ let nlist = 1;
+ let n = 64;
+ let mut data = Vec::with_capacity(n * d);
+ for i in 0..n {
+ data.push(i as f32);
+ data.push(0.0);
+ }
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2,
HnswBuildParams::default());
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+ index.build_graphs().unwrap();
+
+ let filter: HashSet<i64> = [63].into_iter().collect();
+ let mut distances = vec![0.0; 1];
+ let mut labels = vec![0; 1];
+ index.search_with_filter(
+ &[0.0, 0.0],
+ 1,
+ 1,
+ 1,
+ 4,
+ Some(&filter),
+ &mut distances,
+ &mut labels,
+ );
+
+ assert_eq!(labels[0], 63);
+ assert_eq!(distances[0], 63.0 * 63.0);
+ }
+
+ #[test]
+ fn test_ivfhnswflat_filter_backfills_when_graph_returns_too_few_matches() {
+ use std::collections::HashSet;
+
+ let d = 2;
+ let nlist = 1;
+ let n = 128;
+ let mut data = Vec::with_capacity(n * d);
+ for i in 0..n {
+ data.push(i as f32);
+ data.push(0.0);
+ }
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2,
HnswBuildParams::default());
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+ index.build_graphs().unwrap();
+
+ let filter: HashSet<i64> = (0..n as i64).filter(|id| id % 2 ==
0).collect();
+ let mut distances = vec![0.0; 10];
+ let mut labels = vec![0; 10];
+ index.search_with_filter(
+ &[127.0, 0.0],
+ 1,
+ 10,
+ 1,
+ 1,
+ Some(&filter),
+ &mut distances,
+ &mut labels,
+ );
+
+ assert_eq!(
+ labels,
+ vec![126, 124, 122, 120, 118, 116, 114, 112, 110, 108]
+ );
+ assert!(labels.iter().all(|id| id % 2 == 0));
+ }
+
+ #[test]
+ fn test_topk_heap_keeps_closest_duplicate_id() {
+ let mut heap = TopKHeap::new(2);
+
+ heap.push(10.0, 7);
+ heap.push(5.0, 8);
+ heap.push(1.0, 7);
+
+ assert_eq!(heap.into_sorted(), vec![(1.0, 7), (5.0, 8)]);
+ }
+}
diff --git a/core/src/ivfhnswflat_io.rs b/core/src/ivfhnswflat_io.rs
new file mode 100644
index 0000000..d504d43
--- /dev/null
+++ b/core/src/ivfhnswflat_io.rs
@@ -0,0 +1,1460 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{fvec_distance, preprocess_vectors, MetricType};
+use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::io::{SeekRead, SeekWrite};
+use crate::ivfhnswflat::IVFHNSWFlatIndex;
+use crate::ivfpq::RowIdFilter;
+use crate::kmeans;
+use crate::topk::TopKHeap;
+use roaring::RoaringTreemap;
+use std::io;
+
+pub const IVF_HNSW_FLAT_MAGIC: u32 = 0x4948464C; // "IHFL"
+pub const IVF_HNSW_FLAT_VERSION: u32 = 1;
+pub const IVF_HNSW_FLAT_HEADER_SIZE: usize = 64;
+
+/// Serializes `index` to `out` in the IVF_HNSW_FLAT on-disk format.
+///
+/// Sections are written in this order:
+/// 1. A 64-byte header: magic, version, `d`, `nlist`, metric code, total vector
+///    count, the sanitized HNSW parameters (`m`, `ef_construction`, `max_level`)
+///    and 24 reserved zero bytes.
+/// 2. The quantizer centroids as little-endian `f32` values.
+/// 3. An offset table with one 24-byte entry per list: `i64` payload offset,
+///    `i32` vector count, `i32` encoded-graph length, `i64` reserved zero.
+/// 4. For every non-empty list: its row ids, its vectors, then its encoded graph.
+///
+/// Payload offsets are derived from `out.pos()`, so they are only meaningful to a
+/// reader that observes the same positions the writer reported.
+///
+/// # Errors
+/// Fails when `validate_index_shape` rejects the index, when a length or offset
+/// does not fit its on-disk integer type, or when writing to `out` fails.
+pub fn write_ivfhnswflat_index(
+    index: &IVFHNSWFlatIndex,
+    out: &mut dyn SeekWrite,
+) -> io::Result<()> {
+    validate_index_shape(index)?;
+    let d = index.flat.d;
+    let nlist = index.flat.nlist;
+    let total_vectors = index.flat.ids.iter().try_fold(0i64, |sum, ids| {
+        let count = usize_to_i64(ids.len(), "total vector count")?;
+        sum.checked_add(count).ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "total vector count exceeds i64 length limit",
+            )
+        })
+    })?;
+    // Graphs are encoded up front because their byte lengths go into the offset
+    // table, which precedes the list payloads in the file.
+    let graph_bytes: Vec<Vec<u8>> = (0..nlist)
+        .map(|list_id| {
+            if index.flat.ids[list_id].is_empty() {
+                Ok(Vec::new())
+            } else {
+                encode_graph(index.graphs[list_id].as_ref())
+            }
+        })
+        .collect::<io::Result<_>>()?;
+
+    // Fixed-size header.
+    write_u32_le(out, IVF_HNSW_FLAT_MAGIC)?;
+    write_u32_le(out, IVF_HNSW_FLAT_VERSION)?;
+    write_i32_le(out, usize_to_i32(d, "dimension")?)?;
+    write_i32_le(out, usize_to_i32(nlist, "nlist")?)?;
+    write_u32_le(out, index.flat.metric as u32)?;
+    write_i64_le(out, total_vectors)?;
+    let params = index.hnsw_params.sanitized();
+    write_i32_le(out, usize_to_i32(params.m, "hnsw m")?)?;
+    write_i32_le(
+        out,
+        usize_to_i32(params.ef_construction, "hnsw ef_construction")?,
+    )?;
+    write_i32_le(out, usize_to_i32(params.max_level, "hnsw max_level")?)?;
+    out.write_all(&[0u8; 24])?;
+
+    write_f32_slice(out, &index.flat.quantizer_centroids)?;
+
+    let offset_table_size = nlist.checked_mul(24).ok_or_else(|| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "IVF-HNSW-FLAT offset table size overflow",
+        )
+    })?;
+    // The first payload begins right after the offset table that is about to be written.
+    let data_start = out
+        .pos()
+        .checked_add(offset_table_size as u64)
+        .ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "IVF-HNSW-FLAT data start offset overflow",
+            )
+        })?;
+    let mut list_offsets = vec![0i64; nlist];
+    let mut list_counts = vec![0i32; nlist];
+    let mut list_graph_bytes_lens = vec![0i32; nlist];
+    let mut current_offset = data_start;
+
+    for list_id in 0..nlist {
+        list_offsets[list_id] = u64_to_i64(current_offset, "list offset")?;
+        let count = index.flat.ids[list_id].len();
+        list_counts[list_id] = usize_to_i32(count, "list count")?;
+        list_graph_bytes_lens[list_id] = usize_to_i32(graph_bytes[list_id].len(), "graph bytes")?;
+        // Empty lists occupy no payload bytes, so the running offset does not move.
+        if count > 0 {
+            let payload_len = list_payload_len(count, d, graph_bytes[list_id].len())?;
+            current_offset = current_offset
+                .checked_add(payload_len as u64)
+                .ok_or_else(|| {
+                    io::Error::new(io::ErrorKind::InvalidInput, "IVF-HNSW-FLAT offset overflow")
+                })?;
+        }
+    }
+
+    for list_id in 0..nlist {
+        write_i64_le(out, list_offsets[list_id])?;
+        write_i32_le(out, list_counts[list_id])?;
+        write_i32_le(out, list_graph_bytes_lens[list_id])?;
+        write_i64_le(out, 0)?;
+    }
+
+    for list_id in 0..nlist {
+        if index.flat.ids[list_id].is_empty() {
+            continue;
+        }
+        // Encode the whole id column first so the sink receives one write per list
+        // instead of one eight-byte write per row id.
+        let id_bytes: Vec<u8> = index.flat.ids[list_id]
+            .iter()
+            .flat_map(|id| id.to_le_bytes())
+            .collect();
+        out.write_all(&id_bytes)?;
+        write_f32_slice(out, &index.flat.vectors[list_id])?;
+        out.write_all(&graph_bytes[list_id])?;
+    }
+
+    Ok(())
+}
+
+/// Reader for an index serialized by `write_ivfhnswflat_index`.
+///
+/// `open` parses only the header; the centroid and per-list tables stay empty
+/// until `ensure_loaded` has run.
+pub struct IVFHNSWFlatIndexReader<R: SeekRead> {
+    /// Underlying byte source.
+    reader: R,
+    /// Vector dimension read from the header.
+    pub d: usize,
+    /// Number of inverted lists read from the header.
+    pub nlist: usize,
+    /// Distance metric decoded from the header's metric code.
+    pub metric: MetricType,
+    /// Total vector count as stored in the header.
+    pub total_vectors: i64,
+    /// HNSW parameters from the header, after `sanitized()`.
+    pub hnsw_params: HnswBuildParams,
+    /// Quantizer centroids; empty until `ensure_loaded` succeeds.
+    pub quantizer_centroids: Vec<f32>,
+    /// Payload offset of each list; empty until `ensure_loaded` succeeds.
+    pub list_offsets: Vec<i64>,
+    /// Vector count of each list; empty until `ensure_loaded` succeeds.
+    pub list_counts: Vec<i32>,
+    /// Encoded-graph byte length of each list; empty until `ensure_loaded` succeeds.
+    pub list_graph_bytes_lens: Vec<i32>,
+    /// Set once the centroid and offset tables have been read.
+    loaded: bool,
+}
+
+impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
+    /// Opens an index by parsing and validating the fixed 64-byte header.
+    ///
+    /// Only the header is read here; the centroids and the per-list table are
+    /// loaded later by `ensure_loaded`.
+    ///
+    /// # Errors
+    /// Returns `InvalidData` for a wrong magic number, an unsupported version, a
+    /// non-positive `d`, `nlist` or HNSW parameter, an unknown metric code, or a
+    /// negative total vector count. I/O errors from `reader` are propagated.
+    pub fn open(mut reader: R) -> io::Result<Self> {
+        reader.seek(0)?;
+
+        let magic = read_u32_le(&mut reader)?;
+        if magic != IVF_HNSW_FLAT_MAGIC {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("Invalid IVF_HNSW_FLAT magic: 0x{:08X}", magic),
+            ));
+        }
+        let version = read_u32_le(&mut reader)?;
+        if version != IVF_HNSW_FLAT_VERSION {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("Unsupported IVF_HNSW_FLAT version: {}", version),
+            ));
+        }
+
+        let d = validate_positive_i32(read_i32_le(&mut reader)?, "d")? as usize;
+        let nlist = validate_positive_i32(read_i32_le(&mut reader)?, "nlist")? as usize;
+        let metric_code = read_u32_le(&mut reader)?;
+        let metric = MetricType::from_code(metric_code).ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("Unknown metric type: {}", metric_code),
+            )
+        })?;
+        let total_vectors = read_i64_le(&mut reader)?;
+        // The writer only ever stores a sum of list lengths, so a negative value
+        // can only come from a corrupt or foreign file.
+        if total_vectors < 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!(
+                    "invalid header field total_vectors: {} (must be non-negative)",
+                    total_vectors
+                ),
+            ));
+        }
+        let hnsw_params = HnswBuildParams {
+            m: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw m")? as usize,
+            ef_construction: validate_positive_i32(
+                read_i32_le(&mut reader)?,
+                "hnsw ef_construction",
+            )? as usize,
+            max_level: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw max_level")? as usize,
+        }
+        .sanitized();
+        // Skip the reserved tail of the header.
+        let mut reserved = [0u8; 24];
+        reader.read_exact(&mut reserved)?;
+
+        Ok(Self {
+            reader,
+            d,
+            nlist,
+            metric,
+            total_vectors,
+            hnsw_params,
+            quantizer_centroids: Vec::new(),
+            list_offsets: Vec::new(),
+            list_counts: Vec::new(),
+            list_graph_bytes_lens: Vec::new(),
+            loaded: false,
+        })
+    }
+
+    /// Loads the quantizer centroids and the per-list table on first use.
+    ///
+    /// Later calls return immediately. The parsed sections are stored on `self`
+    /// only after the whole table has been read and checked, so a failed load
+    /// leaves the public tables untouched and the call can be retried.
+    ///
+    /// # Errors
+    /// Returns `InvalidData` when `nlist * d` exceeds the section limit or when a
+    /// list entry carries a negative count or graph length. I/O errors from the
+    /// underlying reader are propagated.
+    pub fn ensure_loaded(&mut self) -> io::Result<()> {
+        if self.loaded {
+            return Ok(());
+        }
+
+        self.reader.seek(IVF_HNSW_FLAT_HEADER_SIZE as u64)?;
+        let quantizer_centroids =
+            read_f32_vec(&mut self.reader, checked_section_size(self.nlist, self.d)?)?;
+        let mut list_offsets = Vec::with_capacity(self.nlist);
+        let mut list_counts = Vec::with_capacity(self.nlist);
+        let mut list_graph_bytes_lens = Vec::with_capacity(self.nlist);
+        for list_id in 0..self.nlist {
+            list_offsets.push(read_i64_le(&mut self.reader)?);
+            let count = read_i32_le(&mut self.reader)?;
+            if count < 0 {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    format!("negative list count {} at list {}", count, list_id),
+                ));
+            }
+            list_counts.push(count);
+            let graph_bytes_len = read_i32_le(&mut self.reader)?;
+            if graph_bytes_len < 0 {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    format!(
+                        "negative graph_bytes_len {} at list {}",
+                        graph_bytes_len, list_id
+                    ),
+                ));
+            }
+            list_graph_bytes_lens.push(graph_bytes_len);
+            // Each entry ends with a reserved i64 that is read and discarded.
+            let _reserved = read_i64_le(&mut self.reader)?;
+        }
+
+        self.quantizer_centroids = quantizer_centroids;
+        self.list_offsets = list_offsets;
+        self.list_counts = list_counts;
+        self.list_graph_bytes_lens = list_graph_bytes_lens;
+        self.loaded = true;
+        Ok(())
+    }
+
+    /// Reads one inverted list and returns its row ids, a copy of its vectors and
+    /// its decoded graph. An empty list yields two empty vectors and `None`.
+    ///
+    /// # Errors
+    /// Propagates any error reported by `read_graph_list`.
+    pub fn read_inverted_list(
+        &mut self,
+        list_id: usize,
+    ) -> io::Result<(Vec<i64>, Vec<f32>, Option<HnswGraph>)> {
+        match self.read_graph_list(list_id)? {
+            Some(GraphList { ids, graph }) => {
+                // The graph keeps its own vector storage, so hand out a copy.
+                let vectors = graph.vectors().to_vec();
+                Ok((ids, vectors, Some(graph)))
+            }
+            None => Ok((Vec::new(), Vec::new(), None)),
+        }
+    }
+
+    /// Reads and decodes the payload of list `list_id`.
+    ///
+    /// Returns `Ok(None)` for a list whose stored count is zero. Otherwise the
+    /// payload (row ids, vectors, encoded graph) is fetched with a single
+    /// positioned read and split into its three sections.
+    ///
+    /// # Errors
+    /// Returns `InvalidInput` when `list_id >= nlist`, and `InvalidData` when the
+    /// stored offset is negative, the list has vectors but no graph bytes, a size
+    /// computation overflows, or the graph section fails to decode. I/O errors
+    /// are propagated.
+    fn read_graph_list(&mut self, list_id: usize) -> io::Result<Option<GraphList>> {
+        self.ensure_loaded()?;
+        if list_id >= self.nlist {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!("list_id {} out of range (nlist={})", list_id, self.nlist),
+            ));
+        }
+        let count = self.list_counts[list_id] as usize;
+        if count == 0 {
+            return Ok(None);
+        }
+
+        let offset = checked_list_offset(self.list_offsets[list_id], list_id)?;
+        let graph_bytes_len = self.list_graph_bytes_lens[list_id] as usize;
+        if graph_bytes_len == 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("list {} is missing HNSW graph", list_id),
+            ));
+        }
+        let payload_len = list_payload_len(count, self.d, graph_bytes_len)?;
+        let mut payload = vec![0u8; payload_len];
+        self.reader.pread(offset, &mut payload)?;
+
+        let ids_bytes_len = checked_list_bytes(count, 8)?;
+        let vector_bytes_len = checked_list_bytes(
+            count,
+            self.d.checked_mul(4).ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "IVF-HNSW-FLAT bytes per vector overflow",
+                )
+            })?,
+        )?;
+        let ids = payload[..ids_bytes_len]
+            .chunks_exact(8)
+            .map(|c| i64::from_le_bytes(c.try_into().unwrap()))
+            .collect();
+        let vectors = bytes_to_f32_vec(&payload[ids_bytes_len..ids_bytes_len + vector_bytes_len])?;
+        // `vectors` is not needed again, so it is moved into the graph rather than
+        // cloned; this avoids copying the list's whole vector block on every read.
+        let graph = decode_graph(
+            &payload[ids_bytes_len + vector_bytes_len..],
+            vectors,
+            count,
+            self.d,
+            self.metric,
+            self.hnsw_params,
+        )?;
+        let graph = graph.ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("list {} is missing HNSW graph", list_id),
+            )
+        })?;
+        Ok(Some(GraphList { ids, graph }))
+    }
+
+    /// Unfiltered single-query search; forwards to `search_with_filter` with no
+    /// filter and returns its `(labels, distances)` result unchanged.
+    pub fn search(
+        &mut self,
+        query: &[f32],
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+    ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        self.search_with_filter(query, k, nprobe, ef_search, None)
+    }
+
+    /// Searches the `nprobe` lists closest to `query` and returns the best `k`
+    /// rows as `(labels, distances)`.
+    ///
+    /// Both returned vectors have exactly `k` entries; unused slots are padded
+    /// with label `-1` and distance `f32::MAX`.
+    ///
+    /// With a filter, the probed lists are scanned exhaustively when the number
+    /// of admitted rows does not exceed `max(ef_search, k)`. Otherwise the HNSW
+    /// graphs are searched, and an exhaustive scan is used as a backfill if the
+    /// graph pass leaves fewer than `k` admitted rows.
+    ///
+    /// # Errors
+    /// Returns `InvalidInput` when `query.len() != d`, `k == 0` or `nprobe == 0`,
+    /// and propagates errors from loading the index sections or the probed lists.
+    pub fn search_with_filter(
+        &mut self,
+        query: &[f32],
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+        filter: Option<&dyn RowIdFilter>,
+    ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        self.ensure_loaded()?;
+        if query.len() != self.d {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!(
+                    "query length {} does not match index dimension {}",
+                    query.len(),
+                    self.d
+                ),
+            ));
+        }
+        if k == 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "k must be greater than 0",
+            ));
+        }
+        if nprobe == 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "nprobe must be greater than 0",
+            ));
+        }
+
+        let q = preprocess_vectors(query, 1, self.d, self.metric);
+        let (probe_indices, _) =
+            kmeans::find_topk(&q, &self.quantizer_centroids, self.nlist, self.d, nprobe);
+        let mut loaded_lists = Vec::with_capacity(probe_indices.len());
+        for list_id in probe_indices {
+            if let Some(list) = self.read_graph_list(list_id)? {
+                loaded_lists.push(list);
+            }
+        }
+        let mut heap = TopKHeap::new(k);
+        // Effective beam width: never narrower than the number of results wanted.
+        let ef = ef_search.max(k);
+        let force_flat_scan = filter
+            .map(|f| count_filtered(&loaded_lists, f) <= ef)
+            .unwrap_or(false);
+
+        for list in &loaded_lists {
+            if force_flat_scan {
+                scan_flat_list(
+                    &q,
+                    &list.ids,
+                    list.graph.vectors(),
+                    self.d,
+                    self.metric,
+                    filter,
+                    &mut heap,
+                );
+            } else {
+                let local_results = list.graph.search(&q, ef, ef);
+                for (local_id, dist) in local_results {
+                    let row_id = list.ids[local_id];
+                    if filter.map(|f| f.contains(row_id)).unwrap_or(true) {
+                        heap.push(dist, row_id);
+                    }
+                }
+            }
+        }
+        // The graph pass can surface fewer than k admitted rows; backfill with an
+        // exact scan of the probed lists. This is skipped when those lists were
+        // already scanned exhaustively above, since a second pass would only push
+        // the same rows again.
+        if filter.is_some() && !force_flat_scan && heap.len() < k {
+            for list in &loaded_lists {
+                scan_flat_list(
+                    &q,
+                    &list.ids,
+                    list.graph.vectors(),
+                    self.d,
+                    self.metric,
+                    filter,
+                    &mut heap,
+                );
+            }
+        }
+
+        let sorted = heap.into_sorted();
+        let mut labels: Vec<i64> = sorted.iter().map(|&(_, id)| id).collect();
+        let mut distances: Vec<f32> = sorted.iter().map(|&(dist, _)| dist).collect();
+        // Pad to exactly k entries.
+        labels.resize(k, -1);
+        distances.resize(k, f32::MAX);
+        Ok((labels, distances))
+    }
+
+    /// Same as `search_with_filter`, with the filter supplied as serialized
+    /// `RoaringTreemap` bytes.
+    ///
+    /// # Errors
+    /// Returns the error from `decode_roaring_filter` when the bytes cannot be
+    /// deserialized, plus every error `search_with_filter` can return.
+    pub fn search_with_roaring_filter(
+        &mut self,
+        query: &[f32],
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+        roaring_filter_bytes: &[u8],
+    ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        let filter = decode_roaring_filter(roaring_filter_bytes)?;
+        self.search_with_filter(query, k, nprobe, ef_search, Some(&filter))
+    }
+}
+
+/// Unfiltered batch search; forwards to `search_batch_ivfhnswflat_reader_filter`
+/// with no filter and returns its result unchanged.
+pub fn search_batch_ivfhnswflat_reader<R: SeekRead>(
+    reader: &mut IVFHNSWFlatIndexReader<R>,
+    queries: &[f32],
+    nq: usize,
+    k: usize,
+    nprobe: usize,
+    ef_search: usize,
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+    search_batch_ivfhnswflat_reader_filter(reader, queries, nq, k, nprobe, ef_search, None)
+}
+
+/// Runs `nq` queries against `reader`, reading every probed list only once.
+///
+/// `queries` holds `nq` vectors of `reader.d` values back to back. The result is
+/// `(labels, distances)`, each of length `nq * k` with query `qi` occupying
+/// `[qi * k, (qi + 1) * k)`; unused slots hold label `-1` and distance `f32::MAX`.
+///
+/// With a filter, a query whose probed lists admit at most `max(ef_search, k)`
+/// rows is answered by an exhaustive scan of those lists. Other queries use the
+/// HNSW graphs and are backfilled with an exhaustive scan only if the graph pass
+/// leaves them with fewer than `k` rows.
+///
+/// # Errors
+/// Returns `InvalidInput` when `validate_search_inputs` rejects the arguments or
+/// `nq * k` overflows `usize`, and propagates errors from loading the index
+/// sections or the probed lists.
+pub fn search_batch_ivfhnswflat_reader_filter<R: SeekRead>(
+    reader: &mut IVFHNSWFlatIndexReader<R>,
+    queries: &[f32],
+    nq: usize,
+    k: usize,
+    nprobe: usize,
+    ef_search: usize,
+    filter: Option<&dyn RowIdFilter>,
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+    reader.ensure_loaded()?;
+    validate_search_inputs(queries, nq, reader.d, k, nprobe)?;
+    // Size the output before doing any work so an impossible request fails fast.
+    let result_len = nq.checked_mul(k).ok_or_else(|| {
+        io::Error::new(io::ErrorKind::InvalidInput, "nq * k overflows usize")
+    })?;
+    // Effective beam width: never narrower than the number of results wanted.
+    let ef = ef_search.max(k);
+
+    let processed = preprocess_vectors(queries, nq, reader.d, reader.metric);
+    let (all_probe_indices, _) = kmeans::find_topk_batch(
+        &processed,
+        nq,
+        &reader.quantizer_centroids,
+        reader.nlist,
+        reader.d,
+        nprobe,
+    );
+
+    // Invert "query -> probed lists" into "list -> interested queries" so that
+    // each list is read from storage at most once.
+    let mut list_to_queries = vec![Vec::new(); reader.nlist];
+    let mut unique_lists = Vec::new();
+    for (qi, probe_indices) in all_probe_indices.iter().enumerate() {
+        for &list_id in probe_indices {
+            if list_to_queries[list_id].is_empty() {
+                unique_lists.push(list_id);
+            }
+            list_to_queries[list_id].push(qi);
+        }
+    }
+
+    let mut heaps: Vec<TopKHeap> = (0..nq).map(|_| TopKHeap::new(k)).collect();
+    // Per query: number of filter-admitted rows across all of its probed lists.
+    let mut query_filtered_counts = vec![0usize; nq];
+    let mut loaded_lists = Vec::with_capacity(unique_lists.len());
+    for list_id in unique_lists {
+        let Some(list) = reader.read_graph_list(list_id)? else {
+            continue;
+        };
+        if let Some(f) = filter {
+            let filtered = list.ids.iter().filter(|&&id| f.contains(id)).count();
+            for &qi in &list_to_queries[list_id] {
+                query_filtered_counts[qi] = query_filtered_counts[qi]
+                    .checked_add(filtered)
+                    .ok_or_else(|| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidData,
+                            "filtered vector count overflows usize",
+                        )
+                    })?;
+            }
+        }
+        loaded_lists.push(LoadedBatchList {
+            query_ids: std::mem::take(&mut list_to_queries[list_id]),
+            ids: list.ids,
+            graph: list.graph,
+        });
+    }
+
+    for list in &loaded_lists {
+        for &qi in &list.query_ids {
+            let query = &processed[qi * reader.d..(qi + 1) * reader.d];
+            let force_flat_scan = filter
+                .map(|_| query_filtered_counts[qi] <= ef)
+                .unwrap_or(false);
+            if force_flat_scan {
+                scan_flat_list(
+                    query,
+                    &list.ids,
+                    list.graph.vectors(),
+                    reader.d,
+                    reader.metric,
+                    filter,
+                    &mut heaps[qi],
+                );
+            } else {
+                let local_results = list.graph.search(query, ef, ef);
+                for (local_id, dist) in local_results {
+                    let row_id = list.ids[local_id];
+                    if filter.map(|f| f.contains(row_id)).unwrap_or(true) {
+                        heaps[qi].push(dist, row_id);
+                    }
+                }
+            }
+        }
+    }
+    if filter.is_some() {
+        for list in &loaded_lists {
+            for &qi in &list.query_ids {
+                // Skip queries that are already full, and queries the first pass
+                // answered exhaustively: rescanning would only push the same rows.
+                if heaps[qi].len() >= k || query_filtered_counts[qi] <= ef {
+                    continue;
+                }
+                let query = &processed[qi * reader.d..(qi + 1) * reader.d];
+                scan_flat_list(
+                    query,
+                    &list.ids,
+                    list.graph.vectors(),
+                    reader.d,
+                    reader.metric,
+                    filter,
+                    &mut heaps[qi],
+                );
+            }
+        }
+    }
+
+    let mut result_ids = vec![-1i64; result_len];
+    let mut result_dists = vec![f32::MAX; result_len];
+    for (qi, heap) in heaps.into_iter().enumerate() {
+        let base = qi * k;
+        for (i, (dist, id)) in heap.into_sorted().into_iter().enumerate() {
+            result_ids[base + i] = id;
+            result_dists[base + i] = dist;
+        }
+    }
+
+    Ok((result_ids, result_dists))
+}
+
+/// Same as `search_batch_ivfhnswflat_reader_filter`, with the filter supplied as
+/// serialized `RoaringTreemap` bytes.
+///
+/// # Errors
+/// Returns the error from `decode_roaring_filter` when the bytes cannot be
+/// deserialized, plus every error the filtered batch search can return.
+pub fn search_batch_ivfhnswflat_reader_roaring_filter<R: SeekRead>(
+    reader: &mut IVFHNSWFlatIndexReader<R>,
+    queries: &[f32],
+    nq: usize,
+    k: usize,
+    nprobe: usize,
+    ef_search: usize,
+    roaring_filter_bytes: &[u8],
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+    let filter = decode_roaring_filter(roaring_filter_bytes)?;
+    search_batch_ivfhnswflat_reader_filter(reader, queries, nq, k, nprobe, ef_search, Some(&filter))
+}
+
+/// One decoded inverted list: its row ids plus the HNSW graph built over it.
+struct GraphList {
+    /// Row ids, indexed by the graph's local node id.
+    ids: Vec<i64>,
+    /// Decoded graph; it also holds the list's vectors.
+    graph: HnswGraph,
+}
+
+/// A decoded inverted list together with the batch queries that probe it.
+struct LoadedBatchList {
+    /// Indices of the queries (within the batch) that selected this list.
+    query_ids: Vec<usize>,
+    /// Row ids, indexed by the graph's local node id.
+    ids: Vec<i64>,
+    /// Decoded graph; it also holds the list's vectors.
+    graph: HnswGraph,
+}
+
+/// Counts the row ids across `lists` that `filter` admits.
+fn count_filtered(lists: &[GraphList], filter: &dyn RowIdFilter) -> usize {
+    let mut admitted = 0usize;
+    for list in lists {
+        admitted += list.ids.iter().filter(|&&id| filter.contains(id)).count();
+    }
+    admitted
+}
+
+/// Exhaustively scores the rows of one list against `query` and pushes every
+/// row admitted by `filter` (all rows when `filter` is `None`) onto `heap`.
+///
+/// `vectors` is read as consecutive `d`-wide rows, row `i` belonging to `ids[i]`.
+fn scan_flat_list(
+    query: &[f32],
+    ids: &[i64],
+    vectors: &[f32],
+    d: usize,
+    metric: MetricType,
+    filter: Option<&dyn RowIdFilter>,
+    heap: &mut TopKHeap,
+) {
+    for (local_id, &row_id) in ids.iter().enumerate() {
+        let admitted = match filter {
+            Some(f) => f.contains(row_id),
+            None => true,
+        };
+        if admitted {
+            let row = &vectors[local_id * d..(local_id + 1) * d];
+            heap.push(fvec_distance(query, row, metric), row_id);
+        }
+    }
+}
+
+/// Validates the arguments of a batch search.
+///
+/// Checks, in order: `nq > 0`, `nq * d` fits in `usize`, `queries.len()` equals
+/// `nq * d`, `k > 0`, `nprobe > 0`. The first failing check is reported as an
+/// `InvalidInput` error.
+fn validate_search_inputs(
+    queries: &[f32],
+    nq: usize,
+    d: usize,
+    k: usize,
+    nprobe: usize,
+) -> io::Result<()> {
+    let reject = |message: String| Err(io::Error::new(io::ErrorKind::InvalidInput, message));
+
+    if nq == 0 {
+        return reject("nq must be greater than 0".to_string());
+    }
+    let expected_query_len = match nq.checked_mul(d) {
+        Some(len) => len,
+        None => return reject("nq * dimension overflows usize".to_string()),
+    };
+    if queries.len() != expected_query_len {
+        return reject(format!(
+            "queries length {} does not match nq * dimension {}",
+            queries.len(),
+            expected_query_len
+        ));
+    }
+    if k == 0 {
+        return reject("k must be greater than 0".to_string());
+    }
+    if nprobe == 0 {
+        return reject("nprobe must be greater than 0".to_string());
+    }
+    Ok(())
+}
+
+/// Encodes an HNSW graph as a flat sequence of little-endian `u32` words.
+///
+/// Word order: node count, entry point, max observed level, one level per node,
+/// then for every node and each of its levels the neighbor count followed by
+/// the neighbor ids. `None` encodes to an empty buffer.
+///
+/// # Errors
+/// Returns the error from `write_u32_vec` when any value exceeds `u32::MAX`.
+fn encode_graph(graph: Option<&HnswGraph>) -> io::Result<Vec<u8>> {
+    let Some(graph) = graph else {
+        return Ok(Vec::new());
+    };
+    let mut buf = Vec::new();
+    write_u32_vec(&mut buf, graph.len())?;
+    write_u32_vec(&mut buf, graph.entry_point())?;
+    write_u32_vec(&mut buf, graph.max_observed_level())?;
+    for &level in graph.levels() {
+        write_u32_vec(&mut buf, level)?;
+    }
+    for node_levels in graph.neighbors() {
+        for level_neighbors in node_levels {
+            write_u32_vec(&mut buf, level_neighbors.len())?;
+            for &neighbor in level_neighbors {
+                write_u32_vec(&mut buf, neighbor)?;
+            }
+        }
+    }
+    Ok(buf)
+}
+
+/// Decodes a graph section produced by `encode_graph`.
+///
+/// Returns `Ok(None)` for an empty section. Otherwise the section must hold
+/// exactly: node count (equal to `count`), entry point, max observed level, one
+/// level per node, and for every node and level a neighbor count followed by
+/// that many neighbor ids. The decoded parts and `vectors` are handed to
+/// `HnswGraph::from_parts`.
+///
+/// The section is treated as untrusted: every pre-allocation is capped by the
+/// number of words the section can still supply, so an inflated level or degree
+/// value ends in a "truncated" error instead of a huge allocation.
+///
+/// # Errors
+/// Returns `InvalidData` when the node count differs from `count`, a level is
+/// not below `hnsw_params.max_level`, a degree exceeds the limit for its level
+/// (`2 * m` on level 0, `m` above), the section is truncated or has trailing
+/// bytes. Errors from `HnswGraph::from_parts` are propagated.
+fn decode_graph(
+    bytes: &[u8],
+    vectors: Vec<f32>,
+    count: usize,
+    d: usize,
+    metric: MetricType,
+    hnsw_params: HnswBuildParams,
+) -> io::Result<Option<HnswGraph>> {
+    if bytes.is_empty() {
+        return Ok(None);
+    }
+    // Number of u32 words that can still be read starting at `pos`.
+    let words_left = |pos: usize| bytes.len().saturating_sub(pos) / 4;
+
+    let mut pos = 0usize;
+    let graph_count = read_u32_vec(bytes, &mut pos)? as usize;
+    if graph_count != count {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!(
+                "graph count {} does not match list count {}",
+                graph_count, count
+            ),
+        ));
+    }
+    let entry_point = read_u32_vec(bytes, &mut pos)? as usize;
+    let max_observed_level = read_u32_vec(bytes, &mut pos)? as usize;
+    let mut levels = Vec::with_capacity(count.min(words_left(pos)));
+    for node in 0..count {
+        let level = read_u32_vec(bytes, &mut pos)? as usize;
+        if level >= hnsw_params.max_level {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!(
+                    "graph node {} level {} exceeds max_level {}",
+                    node,
+                    level,
+                    // Saturating: a zero `max_level` must not underflow while
+                    // the error is being built.
+                    hnsw_params.max_level.saturating_sub(1)
+                ),
+            ));
+        }
+        levels.push(level);
+    }
+    let mut neighbors = Vec::with_capacity(count.min(words_left(pos)));
+    for (node, &level) in levels.iter().enumerate() {
+        let mut node_levels = Vec::with_capacity((level + 1).min(words_left(pos)));
+        for graph_level in 0..=level {
+            let degree = read_u32_vec(bytes, &mut pos)? as usize;
+            let max_degree = if graph_level == 0 {
+                hnsw_params.m.saturating_mul(2)
+            } else {
+                hnsw_params.m
+            };
+            if degree > max_degree {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    format!(
+                        "graph node {} degree {} at level {} exceeds max degree {}",
+                        node, degree, graph_level, max_degree
+                    ),
+                ));
+            }
+            let mut level_neighbors = Vec::with_capacity(degree.min(words_left(pos)));
+            for _ in 0..degree {
+                level_neighbors.push(read_u32_vec(bytes, &mut pos)? as usize);
+            }
+            node_levels.push(level_neighbors);
+        }
+        neighbors.push(node_levels);
+    }
+    if pos != bytes.len() {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "trailing bytes in HNSW graph section",
+        ));
+    }
+    Ok(Some(HnswGraph::from_parts(
+        vectors,
+        count,
+        d,
+        metric,
+        levels,
+        neighbors,
+        entry_point,
+        max_observed_level,
+        hnsw_params,
+    )?))
+}
+
+/// Appends `value` to `buf` as a little-endian `u32`.
+///
+/// # Errors
+/// Returns `InvalidInput` when `value` does not fit in a `u32`; `buf` is left
+/// unchanged in that case.
+fn write_u32_vec(buf: &mut Vec<u8>, value: usize) -> io::Result<()> {
+    match u32::try_from(value) {
+        Ok(word) => {
+            buf.extend_from_slice(&word.to_le_bytes());
+            Ok(())
+        }
+        Err(_) => Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("value {} exceeds u32 limit", value),
+        )),
+    }
+}
+
+/// Reads a little-endian `u32` from `bytes` at `*pos` and advances `*pos` by 4.
+///
+/// # Errors
+/// Returns `InvalidData` when `*pos + 4` overflows or runs past the end of
+/// `bytes`; `*pos` is left unchanged in both cases.
+fn read_u32_vec(bytes: &[u8], pos: &mut usize) -> io::Result<u32> {
+    let start = *pos;
+    let Some(end) = start.checked_add(4) else {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "graph section position overflow",
+        ));
+    };
+    match bytes.get(start..end) {
+        Some(word) => {
+            *pos = end;
+            Ok(u32::from_le_bytes([word[0], word[1], word[2], word[3]]))
+        }
+        None => Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "truncated HNSW graph section",
+        )),
+    }
+}
+
+/// Writes `v` to `out` as four little-endian bytes.
+fn write_u32_le(out: &mut dyn SeekWrite, v: u32) -> io::Result<()> {
+    out.write_all(&v.to_le_bytes())
+}
+
+/// Writes `v` to `out` as four little-endian bytes.
+fn write_i32_le(out: &mut dyn SeekWrite, v: i32) -> io::Result<()> {
+    out.write_all(&v.to_le_bytes())
+}
+
+/// Writes `v` to `out` as eight little-endian bytes.
+fn write_i64_le(out: &mut dyn SeekWrite, v: i64) -> io::Result<()> {
+    out.write_all(&v.to_le_bytes())
+}
+
+/// Writes every value of `data` to `out` as little-endian `f32`, using a single
+/// `write_all` call for the whole slice.
+fn write_f32_slice(out: &mut dyn SeekWrite, data: &[f32]) -> io::Result<()> {
+    let mut encoded = Vec::with_capacity(data.len() * 4);
+    for value in data {
+        encoded.extend_from_slice(&value.to_le_bytes());
+    }
+    out.write_all(&encoded)
+}
+
+/// Reads four bytes from `reader` and decodes them as a little-endian `u32`.
+fn read_u32_le(reader: &mut dyn SeekRead) -> io::Result<u32> {
+    let mut buf = [0u8; 4];
+    reader.read_exact(&mut buf)?;
+    Ok(u32::from_le_bytes(buf))
+}
+
+/// Reads four bytes from `reader` and decodes them as a little-endian `i32`.
+fn read_i32_le(reader: &mut dyn SeekRead) -> io::Result<i32> {
+    let mut buf = [0u8; 4];
+    reader.read_exact(&mut buf)?;
+    Ok(i32::from_le_bytes(buf))
+}
+
+/// Reads eight bytes from `reader` and decodes them as a little-endian `i64`.
+fn read_i64_le(reader: &mut dyn SeekRead) -> io::Result<i64> {
+    let mut buf = [0u8; 8];
+    reader.read_exact(&mut buf)?;
+    Ok(i64::from_le_bytes(buf))
+}
+
+/// Reads `count` little-endian `f32` values from `reader`.
+///
+/// # Errors
+/// Returns `InvalidData` when `count * 4` overflows `usize`, and propagates any
+/// error from reading the bytes.
+fn read_f32_vec(reader: &mut dyn SeekRead, count: usize) -> io::Result<Vec<f32>> {
+    let Some(byte_len) = count.checked_mul(4) else {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "f32 section byte length overflow",
+        ));
+    };
+    let mut raw = vec![0u8; byte_len];
+    reader.read_exact(&mut raw)?;
+    bytes_to_f32_vec(&raw)
+}
+
+/// Decodes `bytes` as consecutive little-endian `f32` values.
+///
+/// # Errors
+/// Returns `InvalidData` when the length of `bytes` is not a multiple of 4.
+fn bytes_to_f32_vec(bytes: &[u8]) -> io::Result<Vec<f32>> {
+    if bytes.len() % 4 != 0 {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "f32 byte section is not 4-byte aligned",
+        ));
+    }
+    let mut values = Vec::with_capacity(bytes.len() / 4);
+    for word in bytes.chunks_exact(4) {
+        values.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
+    }
+    Ok(values)
+}
+
+/// Returns `val` unchanged when it is strictly positive.
+///
+/// # Errors
+/// Returns `InvalidData`, naming `field`, when `val` is zero or negative.
+fn validate_positive_i32(val: i32, field: &str) -> io::Result<i32> {
+    if val > 0 {
+        Ok(val)
+    } else {
+        Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!("invalid header field {}: {} (must be positive)", field, val),
+        ))
+    }
+}
+
+/// Checks that an in-memory index is internally consistent before it is written.
+///
+/// Verifies, in order: `d > 0`; `nlist > 0`; the id, vector and graph containers
+/// each have `nlist` entries; the centroid buffer has `nlist * d` values; and,
+/// per list, that the vector buffer has `ids.len() * d` values and that a graph
+/// is present, has one node per id and holds the same vectors (an empty list
+/// may have no graph).
+///
+/// # Errors
+/// Returns `InvalidInput` for the first violated check. An oversized
+/// `nlist * d` is reported by `checked_section_size`.
+fn validate_index_shape(index: &IVFHNSWFlatIndex) -> io::Result<()> {
+    if index.flat.d == 0 {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "dimension must be greater than 0",
+        ));
+    }
+    if index.flat.nlist == 0 {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "nlist must be greater than 0",
+        ));
+    }
+    if index.flat.ids.len() != index.flat.nlist
+        || index.flat.vectors.len() != index.flat.nlist
+        || index.graphs.len() != index.flat.nlist
+    {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "IVF-HNSW-FLAT list storage does not match nlist",
+        ));
+    }
+    let centroid_len = checked_section_size(index.flat.nlist, index.flat.d)?;
+    if index.flat.quantizer_centroids.len() != centroid_len {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!(
+                "centroid length {} does not match nlist*d {}",
+                index.flat.quantizer_centroids.len(),
+                centroid_len
+            ),
+        ));
+    }
+    for list_id in 0..index.flat.nlist {
+        let count = index.flat.ids[list_id].len();
+        let expected_vector_len = count.checked_mul(index.flat.d).ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "IVF-HNSW-FLAT vector length overflow",
+            )
+        })?;
+        if index.flat.vectors[list_id].len() != expected_vector_len {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!(
+                    "list {} vector length {} does not match ids*d {}",
+                    list_id,
+                    index.flat.vectors[list_id].len(),
+                    expected_vector_len
+                ),
+            ));
+        }
+        match &index.graphs[list_id] {
+            // A graph is accepted only if it covers exactly this list's vectors.
+            Some(graph)
+                if graph.len() == count
+                    && graph.vectors() == index.flat.vectors[list_id].as_slice() => {}
+            Some(_) => {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!("list {} graph does not match vector storage", list_id),
+                ));
+            }
+            // Empty lists are allowed to have no graph.
+            None if count == 0 => {}
+            None => {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!("list {} is missing HNSW graph", list_id),
+                ));
+            }
+        }
+    }
+    Ok(())
+}
+
+/// Converts `value` to `i32`.
+///
+/// # Errors
+/// Returns `InvalidInput`, naming `field`, when `value` exceeds `i32::MAX`.
+fn usize_to_i32(value: usize, field: &str) -> io::Result<i32> {
+    i32::try_from(value).map_err(|_| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("{} exceeds i32 length limit: {}", field, value),
+        )
+    })
+}
+
+/// Converts `value` to `i64`.
+///
+/// # Errors
+/// Returns `InvalidInput`, naming `field`, when `value` exceeds `i64::MAX`.
+fn usize_to_i64(value: usize, field: &str) -> io::Result<i64> {
+    i64::try_from(value).map_err(|_| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("{} exceeds i64 length limit: {}", field, value),
+        )
+    })
+}
+
+/// Converts `value` to `i64`.
+///
+/// # Errors
+/// Returns `InvalidInput`, naming `field`, when `value` exceeds `i64::MAX`.
+fn u64_to_i64(value: u64, field: &str) -> io::Result<i64> {
+    i64::try_from(value).map_err(|_| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("{} exceeds i64 offset limit: {}", field, value),
+        )
+    })
+}
+
+/// Largest element count accepted for a single section (2^30).
+const MAX_SECTION_ELEMENTS: usize = 1 << 30;
+
+/// Returns `a * b` when the product neither overflows `usize` nor exceeds
+/// `MAX_SECTION_ELEMENTS`.
+///
+/// # Errors
+/// Returns `InvalidData` on overflow or when the product is above the limit.
+fn checked_section_size(a: usize, b: usize) -> io::Result<usize> {
+    match a.checked_mul(b) {
+        None => Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "section size overflow in IVF-HNSW-FLAT header",
+        )),
+        Some(elements) if elements > MAX_SECTION_ELEMENTS => Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!(
+                "section size {} exceeds maximum {}",
+                elements, MAX_SECTION_ELEMENTS
+            ),
+        )),
+        Some(elements) => Ok(elements),
+    }
+}
+
+/// Converts a stored list offset to `u64`.
+///
+/// # Errors
+/// Returns `InvalidData`, naming `list_id`, when `offset` is negative.
+fn checked_list_offset(offset: i64, list_id: usize) -> io::Result<u64> {
+    u64::try_from(offset).map_err(|_| {
+        io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!("negative list offset {} at list {}", offset, list_id),
+        )
+    })
+}
+
+/// Returns `count * bytes_per_entry`.
+///
+/// # Errors
+/// Returns `InvalidData` when the product overflows `usize`.
+fn checked_list_bytes(count: usize, bytes_per_entry: usize) -> io::Result<usize> {
+    match count.checked_mul(bytes_per_entry) {
+        Some(total) => Ok(total),
+        None => Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "IVF-HNSW-FLAT list byte size overflow",
+        )),
+    }
+}
+
+/// Computes the payload size in bytes of one list: `count` eight-byte row ids,
+/// `count` vectors of `d` four-byte floats, and `graph_bytes_len` graph bytes.
+///
+/// # Errors
+/// Returns `InvalidData` when any intermediate product or sum overflows `usize`.
+fn list_payload_len(count: usize, d: usize, graph_bytes_len: usize) -> io::Result<usize> {
+    let id_bytes = checked_list_bytes(count, 8)?;
+    let Some(bytes_per_vector) = d.checked_mul(4) else {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "IVF-HNSW-FLAT bytes per vector overflow",
+        ));
+    };
+    let vector_bytes = checked_list_bytes(count, bytes_per_vector)?;
+
+    if let Some(ids_and_vectors) = id_bytes.checked_add(vector_bytes) {
+        if let Some(total) = ids_and_vectors.checked_add(graph_bytes_len) {
+            return Ok(total);
+        }
+    }
+    Err(io::Error::new(
+        io::ErrorKind::InvalidData,
+        "IVF-HNSW-FLAT list payload overflow",
+    ))
+}
+
+/// Deserializes a `RoaringTreemap` row-id filter from `bytes`.
+///
+/// # Errors
+/// Returns `InvalidInput`, carrying the deserializer's message, when `bytes`
+/// cannot be deserialized.
+fn decode_roaring_filter(bytes: &[u8]) -> io::Result<RoaringTreemap> {
+    match RoaringTreemap::deserialize_from(bytes) {
+        Ok(filter) => Ok(filter),
+        Err(e) => Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("invalid RoaringTreemap filter: {}", e),
+        )),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::distance::MetricType;
+ use crate::hnsw::HnswBuildParams;
+ use crate::io::{PosWriter, SeekRead};
+ use crate::ivfhnswflat::IVFHNSWFlatIndex;
+ use crate::ivfhnswflat_io::{
+ search_batch_ivfhnswflat_reader,
search_batch_ivfhnswflat_reader_roaring_filter,
+ write_ivfhnswflat_index, IVFHNSWFlatIndexReader,
IVF_HNSW_FLAT_HEADER_SIZE,
+ };
+ use roaring::RoaringTreemap;
+ use std::io;
+ use std::io::Cursor;
+ use std::sync::atomic::{AtomicUsize, Ordering};
+ use std::sync::Arc;
+
+    /// An index written to a buffer and searched through the reader must return
+    /// the same labels and distances as the in-memory index.
+    #[test]
+    fn test_ivfhnswflat_write_read_search_roundtrip() {
+        let d = 4;
+        let nlist = 4;
+        let n = 128;
+        // First coordinate is `(i % nlist) * 100 + i * 0.01`; the rest are constant.
+        let data: Vec<f32> = (0..n)
+            .flat_map(|i| {
+                let cluster = (i % nlist) as f32 * 100.0;
+                [cluster + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect();
+        let ids: Vec<i64> = (1000..1000 + n as i64).collect();
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+        index.build_graphs().unwrap();
+
+        // Reference result from the in-memory index, using vector 7 as the query.
+        let query = &data[7 * d..8 * d];
+        let mut expected_distances = vec![0.0; 5];
+        let mut expected_labels = vec![0; 5];
+        index.search(
+            query,
+            1,
+            5,
+            nlist,
+            32,
+            &mut expected_distances,
+            &mut expected_labels,
+        );
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+
+        let mut reader = IVFHNSWFlatIndexReader::open(Cursor::new(buf)).unwrap();
+        let (labels, distances) = reader.search(query, 5, nlist, 32).unwrap();
+
+        assert_eq!(labels, expected_labels);
+        assert_eq!(distances, expected_distances);
+    }
+
+    /// Same round trip as the L2 test, but with the cosine metric and a query
+    /// that is not unit length.
+    #[test]
+    fn test_ivfhnswflat_write_read_search_roundtrip_cosine() {
+        let d = 3;
+        let nlist = 2;
+        // Four 3-d vectors: two near the x axis, two near the y axis.
+        let data = vec![1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 1.0, 0.0, 0.0, 0.9, 0.1];
+        let ids = vec![10, 11, 12, 13];
+
+        let mut index =
+            IVFHNSWFlatIndex::new(d, nlist, MetricType::Cosine, HnswBuildParams::default());
+        index.train(&data, 4);
+        index.add(&data, &ids, 4);
+        index.build_graphs().unwrap();
+
+        let query = [9.0, 1.0, 0.0];
+        let mut expected_distances = vec![0.0; 2];
+        let mut expected_labels = vec![0; 2];
+        index.search(
+            &query,
+            1,
+            2,
+            nlist,
+            8,
+            &mut expected_distances,
+            &mut expected_labels,
+        );
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+
+        let mut reader = IVFHNSWFlatIndexReader::open(Cursor::new(buf)).unwrap();
+        let (labels, distances) = reader.search(&query, 2, nlist, 8).unwrap();
+
+        assert_eq!(labels, expected_labels);
+        assert_eq!(distances, expected_distances);
+    }
+
+    /// A filtered reader search with `ef_search = 1` must still return the ten
+    /// nearest admitted rows, in order.
+    #[test]
+    fn test_ivfhnswflat_reader_filter_backfills_exact_results() {
+        use std::collections::HashSet;
+
+        let d = 2;
+        let nlist = 1;
+        let n = 128;
+        // Points (0,0), (1,0), ..., (127,0): the row id equals the x coordinate.
+        let mut data = Vec::with_capacity(n * d);
+        for i in 0..n {
+            data.push(i as f32);
+            data.push(0.0);
+        }
+        let ids: Vec<i64> = (0..n as i64).collect();
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+
+        // Only even row ids are admitted by the filter.
+        let filter: HashSet<i64> = (0..n as i64).filter(|id| id % 2 == 0).collect();
+        let mut reader = IVFHNSWFlatIndexReader::open(Cursor::new(buf)).unwrap();
+        // k = 10, nprobe = 1, ef_search = 1.
+        let (labels, _) = reader
+            .search_with_filter(&[127.0, 0.0], 10, 1, 1, Some(&filter))
+            .unwrap();
+
+        // The ten even ids closest to x = 127, nearest first.
+        assert_eq!(
+            labels,
+            vec![126, 124, 122, 120, 118, 116, 114, 112, 110, 108]
+        );
+    }
+
+    /// The batch entry point must return, per query, exactly what a single-query
+    /// reader search returns for the same parameters.
+    #[test]
+    fn test_ivfhnswflat_batch_reader_matches_single_reader_search() {
+        let d = 4;
+        let nlist = 4;
+        let n = 128;
+        // First coordinate is `(i % nlist) * 100 + i * 0.01`; the rest are constant.
+        let data: Vec<f32> = (0..n)
+            .flat_map(|i| {
+                let cluster = (i % nlist) as f32 * 100.0;
+                [cluster + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect();
+        let ids: Vec<i64> = (1000..1000 + n as i64).collect();
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+
+        // Two queries (vectors 7 and 63), stored back to back.
+        let queries = [&data[7 * d..8 * d], &data[63 * d..64 * d]].concat();
+        let k = 5;
+        let nprobe = 3;
+        let ef_search = 32;
+        let mut batch_reader = IVFHNSWFlatIndexReader::open(Cursor::new(buf.clone())).unwrap();
+        let (batch_labels, batch_distances) =
+            search_batch_ivfhnswflat_reader(&mut batch_reader, &queries, 2, k, nprobe, ef_search)
+                .unwrap();
+
+        for qi in 0..2 {
+            // A fresh reader per query keeps the single-query runs independent.
+            let mut single_reader = IVFHNSWFlatIndexReader::open(Cursor::new(buf.clone())).unwrap();
+            let query = &queries[qi * d..(qi + 1) * d];
+            let (single_labels, single_distances) =
+                single_reader.search(query, k, nprobe, ef_search).unwrap();
+            assert_eq!(&batch_labels[qi * k..(qi + 1) * k], single_labels);
+            assert_eq!(&batch_distances[qi * k..(qi + 1) * k], single_distances);
+        }
+    }
+
+    /// With a serialized `RoaringTreemap` that contains only id 12, each query in the batch is
+    /// expected to return id 12 first and fill the remaining slot with `-1` / `f32::MAX`.
+    #[test]
+    fn test_ivfhnswflat_batch_reader_search_with_roaring_filter_bytes() {
+        let (d, nlist) = (2, 1);
+        // Three 2-d rows: two next to the origin and one at (10, 10).
+        let data = vec![0.0, 0.0, 0.1, 0.0, 10.0, 10.0];
+        let ids = vec![10, 11, 12];
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, 3);
+        index.add(&data, &ids, 3);
+        index.build_graphs().unwrap();
+
+        let mut encoded = Vec::new();
+        let mut sink = PosWriter::new(&mut encoded);
+        write_ivfhnswflat_index(&index, &mut sink).unwrap();
+
+        // The filter is handed to the search as serialized bytes, not as a bitmap object.
+        let mut filter_bytes = Vec::new();
+        let mut allowed = RoaringTreemap::new();
+        allowed.insert(12);
+        allowed.serialize_into(&mut filter_bytes).unwrap();
+
+        let queries = vec![0.0, 0.0, 10.0, 10.0];
+        let mut reader = IVFHNSWFlatIndexReader::open(Cursor::new(encoded)).unwrap();
+        let outcome = search_batch_ivfhnswflat_reader_roaring_filter(
+            &mut reader,
+            &queries,
+            2,
+            2,
+            1,
+            4,
+            &filter_bytes,
+        );
+        let (labels, distances) = outcome.unwrap();
+
+        // Row (10, 10) is reported at 200.0 from the first query and 0.0 from the second.
+        assert_eq!(labels, vec![12, -1, 12, -1]);
+        assert_eq!(distances, vec![200.0, f32::MAX, 0.0, f32::MAX]);
+    }
+
+    /// `search_batch_ivfhnswflat_reader` must return an error, not a result, for each of four
+    /// malformed calls against a valid two-row index.
+    #[test]
+    fn test_ivfhnswflat_batch_reader_validates_inputs() {
+        let (d, nlist) = (2, 1);
+        let data = vec![0.0, 0.0, 1.0, 0.0];
+        let ids = vec![10, 11];
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, 2);
+        index.add(&data, &ids, 2);
+        index.build_graphs().unwrap();
+
+        let mut encoded = Vec::new();
+        let mut sink = PosWriter::new(&mut encoded);
+        write_ivfhnswflat_index(&index, &mut sink).unwrap();
+
+        // Every case gets a freshly opened reader over the same bytes.
+        let open_reader = || IVFHNSWFlatIndexReader::open(Cursor::new(encoded.clone())).unwrap();
+
+        // Empty query buffer with a query count of zero.
+        let mut reader = open_reader();
+        assert!(search_batch_ivfhnswflat_reader(&mut reader, &[], 0, 1, 1, 4).is_err());
+
+        // One query announced, but only a single float supplied for a 2-d index.
+        let mut reader = open_reader();
+        assert!(search_batch_ivfhnswflat_reader(&mut reader, &[0.0], 1, 1, 1, 4).is_err());
+
+        // Fourth argument (`k` in the other batch-search test) set to zero.
+        let mut reader = open_reader();
+        assert!(search_batch_ivfhnswflat_reader(&mut reader, &[0.0, 0.0], 1, 0, 1, 4).is_err());
+
+        // Fifth argument (`nprobe` in the other batch-search test) set to zero.
+        let mut reader = open_reader();
+        assert!(search_batch_ivfhnswflat_reader(&mut reader, &[0.0, 0.0], 1, 1, 0, 4).is_err());
+    }
+
+    /// Opening the reader and running one filtered search over a single-list index must issue
+    /// exactly one `pread` call on the underlying cursor.
+    #[test]
+    fn test_ivfhnswflat_reader_filter_reads_probed_list_once() {
+        use std::collections::HashSet;
+
+        let (d, nlist) = (2, 1);
+        let data = vec![0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0];
+        let ids = vec![10, 11, 12, 13];
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, 4);
+        index.add(&data, &ids, 4);
+        index.build_graphs().unwrap();
+
+        let mut encoded = Vec::new();
+        let mut sink = PosWriter::new(&mut encoded);
+        write_ivfhnswflat_index(&index, &mut sink).unwrap();
+
+        // The test keeps one handle to the counter; the cursor owns the other.
+        let pread_count = Arc::new(AtomicUsize::new(0));
+        let counting_cursor = CountingPreadCursor::new(encoded, Arc::clone(&pread_count));
+        let mut reader = IVFHNSWFlatIndexReader::open(counting_cursor).unwrap();
+
+        let filter: HashSet<i64> = HashSet::from([10, 12]);
+        reader
+            .search_with_filter(&[0.0, 0.0], 2, 1, 1, Some(&filter))
+            .unwrap();
+
+        let observed_preads = pread_count.load(Ordering::SeqCst);
+        assert_eq!(observed_preads, 1);
+    }
+
+    /// Dropping the final byte of a serialized index must not prevent `open`, but the search
+    /// that follows has to fail with `UnexpectedEof`.
+    #[test]
+    fn test_ivfhnswflat_reader_rejects_truncated_graph_section() {
+        let (d, nlist) = (2, 1);
+        let data = vec![0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0];
+        let ids = vec![10, 11, 12, 13];
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, 4);
+        index.add(&data, &ids, 4);
+        index.build_graphs().unwrap();
+
+        let mut encoded = Vec::new();
+        let mut sink = PosWriter::new(&mut encoded);
+        write_ivfhnswflat_index(&index, &mut sink).unwrap();
+
+        // Cut exactly one byte off the end of the file image.
+        encoded.truncate(encoded.len().saturating_sub(1));
+
+        let mut reader = IVFHNSWFlatIndexReader::open(Cursor::new(encoded)).unwrap();
+        let failure = reader.search(&[0.0, 0.0], 2, 1, 4).unwrap_err();
+
+        assert_eq!(failure.kind(), io::ErrorKind::UnexpectedEof);
+    }
+
+    /// Zeroing a 4-byte field of the serialized index must make the search fail with
+    /// `InvalidData` and a message that mentions "missing HNSW graph".
+    #[test]
+    fn test_ivfhnswflat_reader_rejects_missing_graph_section() {
+        let (d, nlist) = (2, 1);
+        let data = vec![0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0];
+        let ids = vec![10, 11, 12, 13];
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, 4);
+        index.add(&data, &ids, 4);
+        index.build_graphs().unwrap();
+
+        let mut encoded = Vec::new();
+        let mut sink = PosWriter::new(&mut encoded);
+        write_ivfhnswflat_index(&index, &mut sink).unwrap();
+
+        // NOTE(review): the offset skips the header, `nlist * d` 4-byte values and 8 + 4 more
+        // bytes; presumably it lands on the first list's graph-length field. Confirm against
+        // the writer's layout.
+        let graph_len_offset = IVF_HNSW_FLAT_HEADER_SIZE + nlist * d * 4 + 8 + 4;
+        let zero_len = 0i32.to_le_bytes();
+        encoded[graph_len_offset..graph_len_offset + zero_len.len()].copy_from_slice(&zero_len);
+
+        let mut reader = IVFHNSWFlatIndexReader::open(Cursor::new(encoded)).unwrap();
+        let failure = reader.search(&[0.0, 0.0], 2, 1, 4).unwrap_err();
+
+        assert_eq!(failure.kind(), io::ErrorKind::InvalidData);
+        assert!(failure.to_string().contains("missing HNSW graph"));
+    }
+
+    /// `decode_graph` must answer a graph image whose fourth word is `max_level + 1` with an
+    /// `InvalidData` error whose message mentions "level".
+    #[test]
+    fn test_ivfhnswflat_decoder_rejects_level_above_hnsw_max_before_allocation() {
+        let params = HnswBuildParams {
+            m: 2,
+            ef_construction: 8,
+            max_level: 3,
+        };
+
+        // Four little-endian u32 words; the last one is one above the configured maximum.
+        let mut graph_bytes = Vec::new();
+        for word in [1, 0, 0, params.max_level as u32 + 1] {
+            append_u32(&mut graph_bytes, word);
+        }
+
+        let decoded = super::decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, MetricType::L2, params);
+        let failure = decoded.unwrap_err();
+
+        assert_eq!(failure.kind(), io::ErrorKind::InvalidData);
+        assert!(failure.to_string().contains("level"));
+    }
+
+    /// `decode_graph` must answer a graph image whose fifth word is `2 * m + 1` with an
+    /// `InvalidData` error whose message mentions "degree".
+    #[test]
+    fn test_ivfhnswflat_decoder_rejects_degree_above_hnsw_bound_before_allocation() {
+        let params = HnswBuildParams {
+            m: 2,
+            ef_construction: 8,
+            max_level: 3,
+        };
+
+        // Five little-endian u32 words; the last one is one above `2 * m`.
+        let mut graph_bytes = Vec::new();
+        for word in [1, 0, 0, 0, (params.m * 2) as u32 + 1] {
+            append_u32(&mut graph_bytes, word);
+        }
+
+        let decoded = super::decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, MetricType::L2, params);
+        let failure = decoded.unwrap_err();
+
+        assert_eq!(failure.kind(), io::ErrorKind::InvalidData);
+        assert!(failure.to_string().contains("degree"));
+    }
+
+    /// Serializing an index that was trained and filled but never had `build_graphs` called
+    /// must fail with `InvalidInput` and a message that mentions "missing HNSW graph".
+    #[test]
+    fn test_ivfhnswflat_writer_requires_built_graphs() {
+        let (d, nlist) = (2, 1);
+        let data = vec![0.0, 0.0, 1.0, 0.0];
+        let ids = vec![10, 11];
+
+        // Deliberately stop after `add`: no graphs are built.
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, 2);
+        index.add(&data, &ids, 2);
+
+        let mut encoded = Vec::new();
+        let mut sink = PosWriter::new(&mut encoded);
+        let failure = write_ivfhnswflat_index(&index, &mut sink).unwrap_err();
+
+        assert_eq!(failure.kind(), io::ErrorKind::InvalidInput);
+        assert!(failure.to_string().contains("missing HNSW graph"));
+    }
+
+    /// In-memory `SeekRead` test double that records how many times `pread` is invoked.
+    ///
+    /// The counter lives behind an `Arc` so a test can keep its own handle and read the count
+    /// after the cursor has been moved into a reader.
+    struct CountingPreadCursor {
+        /// Bytes served by `read_exact` and `pread`.
+        data: Vec<u8>,
+        /// Offset of the next `read_exact`; set by `seek` and advanced by successful reads.
+        pos: usize,
+        /// Bumped once at the start of every `pread` call, whether or not the read succeeds.
+        pread_count: Arc<AtomicUsize>,
+    }
+
+    impl CountingPreadCursor {
+        /// Wraps `data` with the read position at 0, reporting `pread` calls to `pread_count`.
+        fn new(data: Vec<u8>, pread_count: Arc<AtomicUsize>) -> Self {
+            Self {
+                data,
+                pos: 0,
+                pread_count,
+            }
+        }
+
+        /// Converts a caller-supplied `u64` offset to `usize`.
+        ///
+        /// A plain `as` cast would silently truncate on targets where `usize` is narrower than
+        /// `u64` and could make a read succeed at the wrong position; such offsets are reported
+        /// as `UnexpectedEof` instead, since they necessarily lie past the in-memory data.
+        fn offset(pos: u64) -> io::Result<usize> {
+            usize::try_from(pos).map_err(|_| {
+                io::Error::new(io::ErrorKind::UnexpectedEof, "cursor position overflow")
+            })
+        }
+
+        /// Borrows the `len` bytes that start at `start`.
+        ///
+        /// Fails with `UnexpectedEof` when `start + len` overflows or runs past the data.
+        fn window(&self, start: usize, len: usize) -> io::Result<&[u8]> {
+            let end = start.checked_add(len).ok_or_else(|| {
+                io::Error::new(io::ErrorKind::UnexpectedEof, "cursor position overflow")
+            })?;
+            self.data.get(start..end).ok_or_else(|| {
+                io::Error::new(io::ErrorKind::UnexpectedEof, "failed to fill whole buffer")
+            })
+        }
+    }
+
+    impl SeekRead for CountingPreadCursor {
+        /// Moves the sequential read position to `pos`; the position is not range-checked here.
+        fn seek(&mut self, pos: u64) -> io::Result<()> {
+            self.pos = Self::offset(pos)?;
+            Ok(())
+        }
+
+        /// Fills `buf` from the current position and advances past the bytes read.
+        ///
+        /// On failure neither `buf` nor the position is modified.
+        fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+            let bytes = self.window(self.pos, buf.len())?;
+            buf.copy_from_slice(bytes);
+            // Cannot overflow: `window` already verified `pos + buf.len()`.
+            self.pos += buf.len();
+            Ok(())
+        }
+
+        /// Fills `buf` from absolute offset `pos` without touching the sequential position.
+        ///
+        /// The call is counted before any validation, so failed attempts are counted too.
+        fn pread(&mut self, pos: u64, buf: &mut [u8]) -> io::Result<()> {
+            self.pread_count.fetch_add(1, Ordering::SeqCst);
+            let bytes = self.window(Self::offset(pos)?, buf.len())?;
+            buf.copy_from_slice(bytes);
+            Ok(())
+        }
+    }
+
+    /// Appends `value` to `buf` as four little-endian bytes.
+    fn append_u32(buf: &mut Vec<u8>, value: u32) {
+        let le_bytes = value.to_le_bytes();
+        buf.extend(le_bytes);
+    }
+}
diff --git a/core/src/lib.rs b/core/src/lib.rs
index e34cd95..8f2f033 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -21,11 +21,15 @@
pub mod blas;
pub mod distance;
pub mod fastscan;
+pub mod hnsw;
pub mod io;
pub mod ivfflat;
pub mod ivfflat_io;
+pub mod ivfhnswflat;
+pub mod ivfhnswflat_io;
pub mod ivfpq;
pub mod kmeans;
pub mod opq;
pub mod pq;
pub mod shuffler;
+pub mod topk;
diff --git a/core/src/topk.rs b/core/src/topk.rs
new file mode 100644
index 0000000..112f76d
--- /dev/null
+++ b/core/src/topk.rs
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+
+/// Keeps the `k` candidates with the smallest distance seen so far, at most one per id.
+///
+/// Entries are stored unordered in `data`; `positions` maps every stored id to its slot so a
+/// repeated id updates its existing entry instead of taking a second slot. NaN distances rank
+/// behind every real distance, so a NaN entry is always the first one to be displaced.
+pub(crate) struct TopKHeap {
+    /// Maximum number of entries retained.
+    k: usize,
+    /// Retained `(distance, id)` pairs, in no particular order.
+    data: Vec<(f32, i64)>,
+    /// Slot in `data` for each retained id.
+    positions: HashMap<i64, usize>,
+}
+
+impl TopKHeap {
+    /// Creates an empty collection that retains at most `k` entries.
+    pub(crate) fn new(k: usize) -> Self {
+        Self {
+            k,
+            data: Vec::with_capacity(k),
+            positions: HashMap::with_capacity(k),
+        }
+    }
+
+    /// Ascending order on distances in which every NaN (of either sign) comes last.
+    ///
+    /// For two non-NaN values this is exactly `f32::total_cmp`.
+    fn order(a: f32, b: f32) -> std::cmp::Ordering {
+        a.is_nan().cmp(&b.is_nan()).then_with(|| a.total_cmp(&b))
+    }
+
+    /// Returns `true` when `candidate` should replace `current`.
+    ///
+    /// A real distance always beats NaN; otherwise the candidate must be strictly smaller.
+    fn improves(candidate: f32, current: f32) -> bool {
+        candidate < current || (current.is_nan() && !candidate.is_nan())
+    }
+
+    /// Offers the candidate `(dist, id)`.
+    ///
+    /// * With `k == 0` nothing is ever stored.
+    /// * If `id` is already stored, its distance is lowered to `dist` when that is an
+    ///   improvement; the entry count does not change.
+    /// * If there is room, the candidate is stored unconditionally.
+    /// * Otherwise the current worst entry is replaced when `dist` improves on it.
+    pub(crate) fn push(&mut self, dist: f32, id: i64) {
+        if self.k == 0 {
+            return;
+        }
+        if let Some(&idx) = self.positions.get(&id) {
+            if Self::improves(dist, self.data[idx].0) {
+                self.data[idx].0 = dist;
+            }
+            return;
+        }
+        if self.data.len() < self.k {
+            self.positions.insert(id, self.data.len());
+            self.data.push((dist, id));
+            return;
+        }
+        // Full: locate the worst entry with the same ordering that decides replacement, so a
+        // stored NaN can never pin the collection and block better candidates.
+        let worst = self
+            .data
+            .iter()
+            .enumerate()
+            .max_by(|(_, a), (_, b)| Self::order(a.0, b.0))
+            .map(|(idx, _)| idx);
+        if let Some(worst_idx) = worst {
+            if Self::improves(dist, self.data[worst_idx].0) {
+                let old_id = self.data[worst_idx].1;
+                self.positions.remove(&old_id);
+                self.data[worst_idx] = (dist, id);
+                self.positions.insert(id, worst_idx);
+            }
+        }
+    }
+
+    /// Consumes the collection and returns its entries sorted by ascending distance, with any
+    /// NaN distances at the end.
+    pub(crate) fn into_sorted(mut self) -> Vec<(f32, i64)> {
+        self.data.sort_by(|a, b| Self::order(a.0, b.0));
+        self.data
+    }
+
+    /// Number of entries currently retained (never more than `k`).
+    pub(crate) fn len(&self) -> usize {
+        self.data.len()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::TopKHeap;
+
+    /// A repeated id lowers its stored distance in place, and once both slots are taken a
+    /// closer candidate evicts the entry with the largest distance.
+    #[test]
+    fn test_topk_heap_updates_duplicate_and_replaces_worst() {
+        let mut heap = TopKHeap::new(2);
+
+        // id 7 appears twice; id 9 arrives when the heap is already full.
+        for (dist, id) in [(10.0, 7), (5.0, 8), (1.0, 7), (3.0, 9)] {
+            heap.push(dist, id);
+        }
+
+        let expected: Vec<(f32, i64)> = vec![(1.0, 7), (3.0, 9)];
+        assert_eq!(heap.into_sorted(), expected);
+    }
+}