This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git


The following commit(s) were added to refs/heads/main by this push:
     new 8e170e7  Add IVF_HNSW_FLAT index and disk reader (#19)
8e170e7 is described below

commit 8e170e772625be0dae2eba7418589928f7cc8641
Author: Jingsong Lee <[email protected]>
AuthorDate: Tue Jun 9 14:06:04 2026 +0800

    Add IVF_HNSW_FLAT index and disk reader (#19)
---
 core/benches/recall_bench.rs |   70 +-
 core/src/distance.rs         |   29 +
 core/src/hnsw.rs             |  767 ++++++++++++++++++++++
 core/src/ivfflat.rs          |   10 +-
 core/src/ivfhnswflat.rs      |  355 ++++++++++
 core/src/ivfhnswflat_io.rs   | 1460 ++++++++++++++++++++++++++++++++++++++++++
 core/src/lib.rs              |    4 +
 core/src/topk.rs             |   90 +++
 8 files changed, 2773 insertions(+), 12 deletions(-)

diff --git a/core/benches/recall_bench.rs b/core/benches/recall_bench.rs
index 03415be..1755402 100644
--- a/core/benches/recall_bench.rs
+++ b/core/benches/recall_bench.rs
@@ -1,5 +1,7 @@
 use paimon_vindex_core::distance::{fvec_distance, MetricType};
+use paimon_vindex_core::hnsw::HnswBuildParams;
 use paimon_vindex_core::ivfflat::IVFFlatIndex;
+use paimon_vindex_core::ivfhnswflat::IVFHNSWFlatIndex;
 use paimon_vindex_core::ivfpq::IVFPQIndex;
 use std::collections::HashSet;
 use std::time::Instant;
@@ -14,6 +16,8 @@ fn main() {
         nlist: 64,
         pq_m: 8,
         nprobes: &[1, 4, 8, 16, 32, 64],
+        hnsw_build_ef: 80,
+        hnsw_search_efs: &[80],
     });
 
     println!();
@@ -27,6 +31,8 @@ fn main() {
         nlist: 8,
         pq_m: 8,
         nprobes: &[1, 2, 4, 8],
+        hnsw_build_ef: 200,
+        hnsw_search_efs: &[80, 160, 320],
     });
 }
 
@@ -39,6 +45,8 @@ struct Scenario<'a> {
     nlist: usize,
     pq_m: usize,
     nprobes: &'a [usize],
+    hnsw_build_ef: usize,
+    hnsw_search_efs: &'a [usize],
 }
 
 fn run_scenario(s: Scenario<'_>) {
@@ -75,9 +83,28 @@ fn run_scenario(s: Scenario<'_>) {
     ivfflat.add(&data, &ids, s.n);
     println!("build IVF-FLAT: {:.2}s", start.elapsed().as_secs_f64());
 
+    let start = Instant::now();
+    let mut ivfhnswflat = IVFHNSWFlatIndex::new(
+        s.d,
+        s.nlist,
+        MetricType::L2,
+        HnswBuildParams {
+            m: 16,
+            ef_construction: s.hnsw_build_ef,
+            max_level: 7,
+        },
+    );
+    ivfhnswflat.train(&data, s.n);
+    ivfhnswflat.add(&data, &ids, s.n);
+    ivfhnswflat.build_graphs().unwrap();
+    println!("build IVF-HNSW-FLAT: {:.2}s", start.elapsed().as_secs_f64());
+
     println!();
-    println!("index      nprobe  recall@{}  query_ms  us/query", s.k);
-    println!("---------  ------  ---------  --------  --------");
+    println!(
+        "index      nprobe  ef      recall@{}  query_ms  us/query",
+        s.k
+    );
+    println!("---------  ------  ------  ---------  --------  --------");
 
     for &nprobe in s.nprobes {
         let mut distances = vec![0.0f32; s.nq * s.k];
@@ -88,6 +115,7 @@ fn run_scenario(s: Scenario<'_>) {
         print_row(
             "IVF-PQ",
             nprobe,
+            None,
             recall_at_k(&labels, &ground_truth, s.nq, s.k),
             elapsed,
             s.nq,
@@ -101,19 +129,53 @@ fn run_scenario(s: Scenario<'_>) {
         print_row(
             "IVF-FLAT",
             nprobe,
+            None,
             recall_at_k(&labels, &ground_truth, s.nq, s.k),
             elapsed,
             s.nq,
         );
+
+        for &ef_search in s.hnsw_search_efs {
+            let mut distances = vec![0.0f32; s.nq * s.k];
+            let mut labels = vec![0i64; s.nq * s.k];
+            let start = Instant::now();
+            ivfhnswflat.search(
+                queries,
+                s.nq,
+                s.k,
+                nprobe,
+                ef_search,
+                &mut distances,
+                &mut labels,
+            );
+            let elapsed = start.elapsed();
+            print_row(
+                "IVF-HNSW",
+                nprobe,
+                Some(ef_search),
+                recall_at_k(&labels, &ground_truth, s.nq, s.k),
+                elapsed,
+                s.nq,
+            );
+        }
     }
 }
 
-fn print_row(index: &str, nprobe: usize, recall: f64, elapsed: 
std::time::Duration, nq: usize) {
+fn print_row(
+    index: &str,
+    nprobe: usize,
+    ef: Option<usize>,
+    recall: f64,
+    elapsed: std::time::Duration,
+    nq: usize,
+) {
     let ms = elapsed.as_secs_f64() * 1000.0;
+    let ef = ef.map(|v| v.to_string()).unwrap_or_else(|| "-".to_string());
     println!(
-        "{:<9}  {:>6}  {:>8.2}%  {:>8.2}  {:>8.1}",
+        "{:<9}  {:>6}  {:>6}  {:>8.2}%  {:>8.2}  {:>8.1}",
         index,
         nprobe,
+        ef,
         recall * 100.0,
         ms,
         ms * 1000.0 / nq as f64
diff --git a/core/src/distance.rs b/core/src/distance.rs
index 6c86d3c..4cbadd6 100644
--- a/core/src/distance.rs
+++ b/core/src/distance.rs
@@ -105,6 +105,35 @@ pub fn fvec_distance(query: &[f32], vector: &[f32], 
metric: MetricType) -> f32 {
     }
 }
 
+/// Returns an owned copy of the first `n` rows (`n * d` floats) of `data`.
+///
+/// When `metric` is [`MetricType::Cosine`] each copied row is normalized in
+/// place with `fvec_normalize` (e.g. `[3, 4]` becomes `[0.6, 0.8]`); every
+/// other metric gets the rows back unchanged.
+///
+/// # Panics
+///
+/// Panics if `data` holds fewer than `n * d` values.
+pub fn preprocess_vectors(data: &[f32], n: usize, d: usize, metric: MetricType) -> Vec<f32> {
+    let total = n * d;
+    let mut out = data[..total].to_vec();
+    if metric != MetricType::Cosine {
+        return out;
+    }
+    for row in 0..n {
+        let start = row * d;
+        fvec_normalize(&mut out[start..start + d]);
+    }
+    out
+}
+
+#[cfg(test)]
+mod preprocess_tests {
+    use super::*;
+
+    /// Only the cosine metric rescales rows; L2 hands back the raw values.
+    #[test]
+    fn test_preprocess_vectors_normalizes_cosine_only() {
+        // Two 2-d rows; with n = 1 only the first row ([3, 4], norm 5) is used.
+        let rows: [f32; 4] = [3.0, 4.0, 1.0, 2.0];
+
+        let l2 = preprocess_vectors(&rows, 1, 2, MetricType::L2);
+        assert_eq!(l2, vec![3.0, 4.0]);
+
+        let cosine = preprocess_vectors(&rows, 1, 2, MetricType::Cosine);
+        assert_eq!(cosine, vec![0.6, 0.8]);
+    }
+}
+
 /// Compute result[i] = a[i] + bf * b[i]. Used for precomputed table merging.
 /// Aligned with Faiss's fvec_madd.
 pub fn fvec_madd(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
diff --git a/core/src/hnsw.rs b/core/src/hnsw.rs
new file mode 100644
index 0000000..4e4683a
--- /dev/null
+++ b/core/src/hnsw.rs
@@ -0,0 +1,767 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{fvec_distance, MetricType};
+use std::cmp::Reverse;
+use std::collections::BinaryHeap;
+use std::io;
+
+/// Construction parameters for [`HnswGraph`].
+///
+/// Graph constructors pass these through [`HnswBuildParams::sanitized`]
+/// before use, so zero values never reach the insertion code.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct HnswBuildParams {
+    /// Degree cap on levels above 0; level 0 allows `2 * m` neighbors.
+    pub m: usize,
+    /// Beam width of the per-level neighbor search run while inserting a node.
+    pub ef_construction: usize,
+    /// Number of levels a node may be assigned; drawn levels lie in
+    /// `0..max_level`.
+    pub max_level: usize,
+}
+
+impl Default for HnswBuildParams {
+    /// `m = 20`, `ef_construction = 150`, `max_level = 7`.
+    fn default() -> Self {
+        Self {
+            m: 20,
+            ef_construction: 150,
+            max_level: 7,
+        }
+    }
+}
+
+impl HnswBuildParams {
+    /// Returns a copy with every field clamped to at least 1.
+    #[must_use]
+    pub fn sanitized(self) -> Self {
+        Self {
+            m: self.m.max(1),
+            ef_construction: self.ef_construction.max(1),
+            max_level: self.max_level.max(1),
+        }
+    }
+}
+
+/// In-memory HNSW graph over a fixed set of vectors, addressed by row index.
+#[derive(Debug, Clone)]
+pub struct HnswGraph {
+    /// Dimension of every stored vector.
+    d: usize,
+    /// Metric passed to `fvec_distance`; smaller values are treated as closer.
+    metric: MetricType,
+    /// Row-major copy of all vectors (`len() * d` floats).
+    vectors: Vec<f32>,
+    /// Top level of each node; the node appears on levels `0..=levels[node]`.
+    levels: Vec<usize>,
+    /// `neighbors[node][level]` is the adjacency list of `node` on `level`.
+    neighbors: Vec<Vec<Vec<usize>>>,
+    /// Node where searches start; it reaches `max_observed_level`.
+    entry_point: usize,
+    /// Highest level held by any node (0 for an empty graph).
+    max_observed_level: usize,
+    /// Build parameters, already sanitized.
+    params: HnswBuildParams,
+}
+
+impl HnswGraph {
+    /// Builds a graph over the first `n` rows of `vectors` (row-major, `d`
+    /// floats per row), inserting nodes one at a time in row order.
+    ///
+    /// The rows are copied into the graph and `params` is passed through
+    /// [`HnswBuildParams::sanitized`] first. Node ids are row indices.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`io::ErrorKind::InvalidInput`] if `n * d` overflows `usize` or
+    /// `vectors` holds fewer than `n * d` values.
+    pub fn build(
+        vectors: &[f32],
+        n: usize,
+        d: usize,
+        metric: MetricType,
+        params: HnswBuildParams,
+    ) -> io::Result<Self> {
+        let expected_len = n.checked_mul(d).ok_or_else(|| {
+            io::Error::new(io::ErrorKind::InvalidInput, "n * dimension overflows usize")
+        })?;
+        if vectors.len() < expected_len {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!(
+                    "vector data length {} is shorter than n*d {}",
+                    vectors.len(),
+                    expected_len
+                ),
+            ));
+        }
+
+        let params = params.sanitized();
+        let mut graph = HnswGraph {
+            d,
+            metric,
+            vectors: vectors[..expected_len].to_vec(),
+            levels: Vec::with_capacity(n),
+            neighbors: Vec::with_capacity(n),
+            entry_point: 0,
+            max_observed_level: 0,
+            params,
+        };
+
+        // Every row is stored up front, so insertion only has to wire up edges.
+        for node in 0..n {
+            graph.insert(node);
+        }
+        Ok(graph)
+    }
+
+    /// Reassembles a graph from its parts after checking they are consistent.
+    ///
+    /// Checks that `vectors.len() == n * d` and that there is one level and one
+    /// adjacency set per node. For a non-empty graph it also checks:
+    /// `entry_point < n`; `max_observed_level` equals the largest entry of
+    /// `levels` and is reached by the entry point; each node has
+    /// `levels[node] + 1` adjacency lists; every edge on level `l` points to an
+    /// existing node whose level is at least `l`. An empty graph ignores
+    /// `entry_point` and `max_observed_level`.
+    ///
+    /// # Errors
+    ///
+    /// [`io::ErrorKind::InvalidInput`] if `n * d` overflows `usize`;
+    /// [`io::ErrorKind::InvalidData`] for any failed consistency check.
+    // One argument per graph field being restored.
+    #[allow(clippy::too_many_arguments)]
+    pub(crate) fn from_parts(
+        vectors: Vec<f32>,
+        n: usize,
+        d: usize,
+        metric: MetricType,
+        levels: Vec<usize>,
+        neighbors: Vec<Vec<Vec<usize>>>,
+        entry_point: usize,
+        max_observed_level: usize,
+        params: HnswBuildParams,
+    ) -> io::Result<Self> {
+        let expected_len = n.checked_mul(d).ok_or_else(|| {
+            io::Error::new(io::ErrorKind::InvalidInput, "n * dimension overflows usize")
+        })?;
+        if vectors.len() != expected_len {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!(
+                    "graph vector length {} does not match n*d {}",
+                    vectors.len(),
+                    expected_len
+                ),
+            ));
+        }
+        if levels.len() != n || neighbors.len() != n {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "graph level metadata does not match vector count",
+            ));
+        }
+        if n == 0 {
+            return Ok(Self {
+                d,
+                metric,
+                vectors,
+                levels,
+                neighbors,
+                entry_point: 0,
+                max_observed_level: 0,
+                params: params.sanitized(),
+            });
+        }
+        if entry_point >= n {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("graph entry point {} out of range {}", entry_point, n),
+            ));
+        }
+        let observed = levels.iter().copied().max().unwrap_or(0);
+        if max_observed_level != observed {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!(
+                    "graph max level {} does not match observed {}",
+                    max_observed_level, observed
+                ),
+            ));
+        }
+        if levels[entry_point] < max_observed_level {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "graph entry point does not reach max observed level",
+            ));
+        }
+        for node in 0..n {
+            if neighbors[node].len() != levels[node] + 1 {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    format!("graph node {} has invalid level adjacency", node),
+                ));
+            }
+            for (level, level_neighbors) in neighbors[node].iter().enumerate() {
+                for &neighbor in level_neighbors {
+                    if neighbor >= n || levels[neighbor] < level {
+                        return Err(io::Error::new(
+                            io::ErrorKind::InvalidData,
+                            format!(
+                                "graph edge {} -> {} at level {} is invalid",
+                                node, neighbor, level
+                            ),
+                        ));
+                    }
+                }
+            }
+        }
+        Ok(Self {
+            d,
+            metric,
+            vectors,
+            levels,
+            neighbors,
+            entry_point,
+            max_observed_level,
+            params: params.sanitized(),
+        })
+    }
+
+    /// Approximate k-nearest-neighbor search.
+    ///
+    /// The search descends greedily from the entry point through every level
+    /// above 0. It then runs a beam search of width `max(ef, k)` on level 0.
+    ///
+    /// Returns up to `k` `(node id, distance)` pairs sorted by increasing
+    /// distance. The result is empty for an empty graph or `k == 0`.
+    ///
+    /// NOTE(review): `query` is not length-checked here; it presumably must
+    /// hold `d` values to be meaningful for `fvec_distance`.
+    pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<(usize, f32)> {
+        if self.levels.is_empty() || k == 0 {
+            return Vec::new();
+        }
+
+        // Greedy descent through the upper levels to find a good level-0 entry.
+        let mut ep = self.entry_point;
+        let mut ep_dist = self.distance_to_query(query, ep);
+        for level in (1..=self.max_observed_level).rev() {
+            let (next, dist) = self.greedy_search_query(query, ep, ep_dist, level);
+            ep = next;
+            ep_dist = dist;
+        }
+
+        // Fresh visit table per query, so mark 1 can never be stale.
+        let mut visited = vec![0usize; self.levels.len()];
+        let candidates = self.search_layer_query(query, ep, ef.max(k), 0, &mut visited, 1);
+        candidates
+            .into_iter()
+            .take(k)
+            .map(|n| (n.id, n.dist))
+            .collect()
+    }
+
+    /// Number of nodes in the graph.
+    pub fn len(&self) -> usize {
+        self.levels.len()
+    }
+
+    /// `true` when the graph holds no nodes.
+    pub fn is_empty(&self) -> bool {
+        self.levels.is_empty()
+    }
+
+    /// Largest adjacency-list length over all nodes and levels (0 when empty).
+    pub fn max_degree(&self) -> usize {
+        self.neighbors
+            .iter()
+            .flat_map(|levels| levels.iter().map(Vec::len))
+            .max()
+            .unwrap_or(0)
+    }
+
+    /// Row-major stored vectors.
+    pub(crate) fn vectors(&self) -> &[f32] {
+        &self.vectors
+    }
+
+    /// Top level of each node.
+    pub(crate) fn levels(&self) -> &[usize] {
+        &self.levels
+    }
+
+    /// Per-node, per-level adjacency lists.
+    pub(crate) fn neighbors(&self) -> &[Vec<Vec<usize>>] {
+        &self.neighbors
+    }
+
+    /// Node where searches start.
+    pub(crate) fn entry_point(&self) -> usize {
+        self.entry_point
+    }
+
+    /// Highest level held by any node.
+    pub(crate) fn max_observed_level(&self) -> usize {
+        self.max_observed_level
+    }
+
+    /// Inserts `node`, whose vector is already present in `self.vectors`.
+    ///
+    /// The insertion proceeds in these steps:
+    /// 1. Draw the node's level.
+    /// 2. Descend greedily through the levels above that level.
+    /// 3. On each shared level, run an `ef_construction` beam search.
+    /// 4. Link the node to the neighbors chosen by `select_neighbors`.
+    /// 5. Continue from its closest linked neighbor.
+    ///
+    /// A node drawn above the current top level becomes the new entry point.
+    fn insert(&mut self, node: usize) {
+        let level = random_level(node, self.params.m, self.params.max_level);
+        self.levels.push(level);
+        self.neighbors.push(vec![Vec::new(); level + 1]);
+
+        if node == 0 {
+            self.entry_point = 0;
+            self.max_observed_level = level;
+            return;
+        }
+
+        let mut ep = self.entry_point;
+        let mut ep_dist = self.distance_between(node, ep);
+
+        for layer in ((level + 1)..=self.max_observed_level).rev() {
+            let (next, dist) = self.greedy_search_node(node, ep, ep_dist, layer);
+            ep = next;
+            ep_dist = dist;
+        }
+
+        let mut visited = vec![0usize; self.levels.len()];
+        let mut visit_mark = 1usize;
+        for layer in (0..=level.min(self.max_observed_level)).rev() {
+            let candidates = self.search_layer_node(
+                node,
+                ep,
+                self.params.ef_construction,
+                layer,
+                &mut visited,
+                visit_mark,
+            );
+            // Fresh stamp per layer so nodes seen on the layer above are not
+            // mistaken for already visited on this one.
+            visit_mark = advance_visit_mark(&mut visited, visit_mark);
+            let selected = self.select_neighbors(candidates, self.max_neighbors(layer));
+            for neighbor in selected {
+                self.connect(node, neighbor.id, layer);
+            }
+            // Continue from this node's closest linked neighbor. Each distance is
+            // computed once; `min_by` keeps the first of several equal minima.
+            if let Some((best, _)) = self.neighbors[node][layer]
+                .iter()
+                .map(|&id| (id, self.distance_between(node, id)))
+                .min_by(|a, b| a.1.total_cmp(&b.1))
+            {
+                ep = best;
+            }
+        }
+
+        if level > self.max_observed_level {
+            self.entry_point = node;
+            self.max_observed_level = level;
+        }
+    }
+
+    /// Links `a -> b` on `level` and, when `b` exists on `level` and lacks `a`,
+    /// also `b -> a`.
+    ///
+    /// Lists that grew are trimmed back to `max_neighbors(level)` by distance.
+    /// `b`'s list is trimmed only when the back edge was added; `a`'s is always
+    /// trimmed. Trimming may drop the edge just added.
+    fn connect(&mut self, a: usize, b: usize, level: usize) {
+        if !self.neighbors[a][level].contains(&b) {
+            self.neighbors[a][level].push(b);
+        }
+        if level < self.neighbors[b].len() && !self.neighbors[b][level].contains(&a) {
+            self.neighbors[b][level].push(a);
+            let pruned = self.pruned_neighbors(b, level, self.max_neighbors(level));
+            self.neighbors[b][level] = pruned;
+        }
+        let pruned = self.pruned_neighbors(a, level, self.max_neighbors(level));
+        self.neighbors[a][level] = pruned;
+    }
+
+    /// Returns `node`'s adjacency on `level` ranked by distance, keeping the
+    /// `max_neighbors` closest.
+    ///
+    /// This is a plain distance ranking with no diversity heuristic.
+    fn pruned_neighbors(&self, node: usize, level: usize, max_neighbors: usize) -> Vec<usize> {
+        let mut ranked: Vec<ScoredNode> = self.neighbors[node][level]
+            .iter()
+            .map(|&id| ScoredNode {
+                id,
+                dist: self.distance_between(node, id),
+            })
+            .collect();
+        ranked.sort_by(|a, b| a.dist.total_cmp(&b.dist));
+        ranked
+            .into_iter()
+            .take(max_neighbors)
+            .map(|node| node.id)
+            .collect()
+    }
+
+    /// Diversity-aware neighbor choice.
+    ///
+    /// Candidates are visited in increasing distance. A candidate is kept
+    /// unless some already-kept neighbor is closer to it than its own `dist`.
+    /// Rejected candidates then backfill any remaining slots, still in distance
+    /// order. At most `max_neighbors` nodes are returned.
+    fn select_neighbors(
+        &self,
+        mut candidates: Vec<ScoredNode>,
+        max_neighbors: usize,
+    ) -> Vec<ScoredNode> {
+        candidates.sort_by(|a, b| a.dist.total_cmp(&b.dist));
+        let mut selected: Vec<ScoredNode> = Vec::with_capacity(max_neighbors);
+        let mut backfill: Vec<ScoredNode> = Vec::new();
+        for candidate in candidates {
+            if selected.len() >= max_neighbors {
+                break;
+            }
+            let closer_to_selected = selected.iter().any(|neighbor| {
+                self.distance_between(candidate.id, neighbor.id) < candidate.dist
+            });
+            if !closer_to_selected {
+                selected.push(candidate);
+            } else {
+                backfill.push(candidate);
+            }
+        }
+        for candidate in backfill {
+            if selected.len() >= max_neighbors {
+                break;
+            }
+            if !selected.iter().any(|neighbor| neighbor.id == candidate.id) {
+                selected.push(candidate);
+            }
+        }
+        selected
+    }
+
+    /// Greedy descent on `level` toward an external query vector.
+    fn greedy_search_query(
+        &self,
+        query: &[f32],
+        current: usize,
+        current_dist: f32,
+        level: usize,
+    ) -> (usize, f32) {
+        self.greedy_search(current, current_dist, level, |id| {
+            self.distance_to_query(query, id)
+        })
+    }
+
+    /// Greedy descent on `level` toward the stored vector of `node`.
+    fn greedy_search_node(
+        &self,
+        node: usize,
+        current: usize,
+        current_dist: f32,
+        level: usize,
+    ) -> (usize, f32) {
+        self.greedy_search(current, current_dist, level, |id| {
+            self.distance_between(node, id)
+        })
+    }
+
+    /// Hill-climbs on `level`.
+    ///
+    /// Each step scans the current node's neighbors and moves to the one with
+    /// the strictly smallest `distance`. It stops when no neighbor improves
+    /// and returns the final node with its distance.
+    fn greedy_search(
+        &self,
+        mut current: usize,
+        mut current_dist: f32,
+        level: usize,
+        mut distance: impl FnMut(usize) -> f32,
+    ) -> (usize, f32) {
+        loop {
+            let mut best = current;
+            let mut best_dist = current_dist;
+            for &neighbor in self.neighbors_at(current, level) {
+                let dist = distance(neighbor);
+                if dist < best_dist {
+                    best = neighbor;
+                    best_dist = dist;
+                }
+            }
+            if best == current {
+                return (current, current_dist);
+            }
+            current = best;
+            current_dist = best_dist;
+        }
+    }
+
+    /// Beam search on `level` toward an external query vector.
+    fn search_layer_query(
+        &self,
+        query: &[f32],
+        entry: usize,
+        ef: usize,
+        level: usize,
+        visited: &mut [usize],
+        visit_mark: usize,
+    ) -> Vec<ScoredNode> {
+        self.search_layer(entry, ef, level, visited, visit_mark, |id| {
+            self.distance_to_query(query, id)
+        })
+    }
+
+    /// Beam search on `level` toward the stored vector of `node`.
+    fn search_layer_node(
+        &self,
+        node: usize,
+        entry: usize,
+        ef: usize,
+        level: usize,
+        visited: &mut [usize],
+        visit_mark: usize,
+    ) -> Vec<ScoredNode> {
+        self.search_layer(entry, ef, level, visited, visit_mark, |id| {
+            self.distance_between(node, id)
+        })
+    }
+
+    /// Best-first beam search on `level` starting at `entry`.
+    ///
+    /// A node counts as visited when `visited[node] == visit_mark`, and every
+    /// node touched is stamped with that mark. Returns at most `max(ef, 1)`
+    /// nodes sorted by increasing distance.
+    fn search_layer(
+        &self,
+        entry: usize,
+        ef: usize,
+        level: usize,
+        visited: &mut [usize],
+        visit_mark: usize,
+        mut distance: impl FnMut(usize) -> f32,
+    ) -> Vec<ScoredNode> {
+        let entry_dist = distance(entry);
+        visited[entry] = visit_mark;
+
+        // Min-heap of nodes still to expand (closest first).
+        let mut candidates = BinaryHeap::new();
+        candidates.push(Reverse(HeapNode {
+            id: entry,
+            dist: entry_dist,
+        }));
+
+        // Max-heap of the best results so far (worst on top, for eviction).
+        let mut results = BinaryHeap::new();
+        results.push(HeapNode {
+            id: entry,
+            dist: entry_dist,
+        });
+
+        while let Some(Reverse(current)) = candidates.pop() {
+            let worst = results
+                .peek()
+                .map(|node| node.dist)
+                .unwrap_or(f32::INFINITY);
+            // The closest unexpanded node is already worse than a full result
+            // set, so nothing left can improve it.
+            if current.dist > worst && results.len() >= ef {
+                break;
+            }
+
+            for &neighbor in self.neighbors_at(current.id, level) {
+                if visited[neighbor] == visit_mark {
+                    continue;
+                }
+                visited[neighbor] = visit_mark;
+                let dist = distance(neighbor);
+                let worst = results
+                    .peek()
+                    .map(|node| node.dist)
+                    .unwrap_or(f32::INFINITY);
+                if results.len() < ef || dist < worst {
+                    candidates.push(Reverse(HeapNode { id: neighbor, dist }));
+                    results.push(HeapNode { id: neighbor, dist });
+                    if results.len() > ef {
+                        results.pop();
+                    }
+                }
+            }
+        }
+
+        let mut result: Vec<ScoredNode> = results
+            .into_iter()
+            .map(|node| ScoredNode {
+                id: node.id,
+                dist: node.dist,
+            })
+            .collect();
+        result.sort_by(|a, b| a.dist.total_cmp(&b.dist));
+        result
+    }
+
+    /// Degree cap for `level`: `2 * m` on level 0, `m` above it.
+    fn max_neighbors(&self, level: usize) -> usize {
+        if level == 0 {
+            self.params.m * 2
+        } else {
+            self.params.m
+        }
+    }
+
+    /// Adjacency of `node` on `level`, or an empty slice when the node does
+    /// not reach that level.
+    fn neighbors_at(&self, node: usize, level: usize) -> &[usize] {
+        self.neighbors
+            .get(node)
+            .and_then(|levels| levels.get(level))
+            .map(Vec::as_slice)
+            .unwrap_or(&[])
+    }
+
+    /// Distance between two stored vectors under `self.metric`.
+    fn distance_between(&self, a: usize, b: usize) -> f32 {
+        let va = &self.vectors[a * self.d..(a + 1) * self.d];
+        let vb = &self.vectors[b * self.d..(b + 1) * self.d];
+        fvec_distance(va, vb, self.metric)
+    }
+
+    /// Distance from `query` to stored vector `id` under `self.metric`.
+    fn distance_to_query(&self, query: &[f32], id: usize) -> f32 {
+        let vector = &self.vectors[id * self.d..(id + 1) * self.d];
+        fvec_distance(query, vector, self.metric)
+    }
+}
+
+/// A node id paired with its distance to the current query or base node.
+#[derive(Debug, Clone, Copy)]
+struct ScoredNode {
+    /// Row index of the node in the graph.
+    id: usize,
+    /// Distance under the graph's metric; smaller is closer.
+    dist: f32,
+}
+
+/// Heap entry for the beam search, ordered by distance with the id as a tie-breaker.
+///
+/// Equality is derived from [`Ord`] so the `Eq`/`Ord` contracts agree: the
+/// derived `PartialEq` compared `f32` with `==` (NaN unequal to itself,
+/// `-0.0 == 0.0`) and included the id, while `cmp` used `total_cmp` on the
+/// distance alone.
+#[derive(Debug, Clone, Copy)]
+struct HeapNode {
+    id: usize,
+    dist: f32,
+}
+
+impl PartialEq for HeapNode {
+    fn eq(&self, other: &Self) -> bool {
+        self.cmp(other) == std::cmp::Ordering::Equal
+    }
+}
+
+impl Eq for HeapNode {}
+
+impl PartialOrd for HeapNode {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for HeapNode {
+    /// Total order on distance (`f32::total_cmp`), then on id, so heap order
+    /// among equal distances is deterministic.
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        self.dist
+            .total_cmp(&other.dist)
+            .then_with(|| self.id.cmp(&other.id))
+    }
+}
+
+/// Deterministically assigns an HNSW level to `node`.
+///
+/// Node 0, and every node when `max_level <= 1`, gets level 0. Otherwise a
+/// SplitMix64 stream seeded from the node index is drawn. Each consecutive
+/// draw below `u64::MAX / max(m, 2)` promotes the node one level (roughly a
+/// `1 / m` chance per level), capped at `max_level - 1`.
+fn random_level(node: usize, m: usize, max_level: usize) -> usize {
+    // Keeping node 0 at level 0 makes the first insertion deterministic; later
+    // higher-level nodes take over as entry point when they appear, and tiny
+    // lists naturally stay flat.
+    if node == 0 || max_level <= 1 {
+        return 0;
+    }
+    let threshold = (u64::MAX / m.max(2) as u64).max(1);
+    let seed = splitmix64(node as u64 + 0x9E37_79B9_7F4A_7C15);
+    std::iter::successors(Some(seed), |&state| Some(splitmix64(state)))
+        .take(max_level - 1)
+        .take_while(|&state| state < threshold)
+        .count()
+}
+
+/// Returns the visit stamp that follows `visit_mark`.
+///
+/// On `usize` overflow the stamp restarts at 1 and `visited` is zeroed, so a
+/// stale stamp can never match the new one.
+fn advance_visit_mark(visited: &mut [usize], visit_mark: usize) -> usize {
+    match visit_mark.checked_add(1) {
+        Some(next) => next,
+        None => {
+            visited.fill(0);
+            1
+        }
+    }
+}
+
+/// One SplitMix64 step.
+///
+/// Advances `state` by the golden-ratio increment and returns the mixed
+/// 64-bit output.
+fn splitmix64(state: u64) -> u64 {
+    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
+    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
+    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;
+    let mut z = state.wrapping_add(GOLDEN_GAMMA);
+    z = (z ^ (z >> 30)).wrapping_mul(MIX_A);
+    z = (z ^ (z >> 27)).wrapping_mul(MIX_B);
+    z ^ (z >> 31)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::distance::MetricType;
+
+    #[test]
+    fn test_hnsw_recalls_query_vector_on_single_partition() {
+        let d = 4;
+        let n = 128;
+        let data: Vec<f32> = (0..n)
+            .flat_map(|i| [i as f32 * 0.01, 1.0, 2.0, 3.0])
+            .collect();
+        let params = HnswBuildParams {
+            m: 8,
+            ef_construction: 32,
+            max_level: 6,
+        };
+
+        let graph = HnswGraph::build(&data, n, d, MetricType::L2, params).unwrap();
+        let query_id = 17;
+        let results = graph.search(&data[query_id * d..(query_id + 1) * d], 5, 32);
+
+        assert_eq!(results[0].0, query_id);
+        assert_eq!(results[0].1, 0.0);
+    }
+
+    #[test]
+    fn test_hnsw_empty_graph_returns_no_results() {
+        let graph =
+            HnswGraph::build(&[], 0, 4, MetricType::L2, HnswBuildParams::default()).unwrap();
+
+        assert!(graph.search(&[0.0, 0.0, 0.0, 0.0], 10, 20).is_empty());
+        assert!(graph.is_empty());
+        assert_eq!(graph.len(), 0);
+        assert_eq!(graph.max_degree(), 0);
+    }
+
+    #[test]
+    fn test_hnsw_build_rejects_short_vector_input() {
+        let err = HnswGraph::build(
+            &[0.0, 1.0, 2.0],
+            2,
+            2,
+            MetricType::L2,
+            HnswBuildParams::default(),
+        )
+        .unwrap_err();
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+        assert!(err.to_string().contains("shorter than n*d"));
+    }
+
+    #[test]
+    fn test_hnsw_respects_neighbor_degree_bound() {
+        let d = 8;
+        let n = 512;
+        let data = generate_clustered_data(n, d, 16);
+        let params = HnswBuildParams {
+            m: 12,
+            ef_construction: 100,
+            max_level: 6,
+        };
+
+        let graph = HnswGraph::build(&data, n, d, MetricType::L2, params).unwrap();
+
+        assert_eq!(graph.len(), n);
+        assert!(graph.max_degree() <= params.m * 2);
+    }
+
+    #[test]
+    fn test_hnsw_single_level_config_stays_flat() {
+        let data: Vec<f32> = (0..32).map(|i| i as f32).collect();
+        let params = HnswBuildParams {
+            m: 4,
+            ef_construction: 8,
+            max_level: 1,
+        };
+
+        let graph = HnswGraph::build(&data, 32, 1, MetricType::L2, params).unwrap();
+
+        assert_eq!(graph.max_observed_level(), 0);
+        assert!(graph.levels().iter().all(|&level| level == 0));
+    }
+
+    #[test]
+    fn test_hnsw_large_partition_recall_tracks_exact_search() {
+        let d = 16;
+        let n = 4096;
+        let nq = 32;
+        let k = 10;
+        let data = generate_clustered_data(n, d, 32);
+        let params = HnswBuildParams {
+            m: 16,
+            ef_construction: 200,
+            max_level: 7,
+        };
+
+        let graph = HnswGraph::build(&data, n, d, MetricType::L2, params).unwrap();
+        let mut hits = 0usize;
+        for qi in 0..nq {
+            let query = &data[qi * d..(qi + 1) * d];
+            let expected = exact_topk(&data, n, d, query, k);
+            let actual = graph.search(query, k, 200);
+            hits += actual
+                .iter()
+                .filter(|(id, _)| expected.contains(id))
+                .count();
+        }
+
+        let recall = hits as f32 / (nq * k) as f32;
+        assert!(recall >= 0.95, "recall={}", recall);
+    }
+
+    #[test]
+    fn test_hnsw_neighbor_selection_backfills_after_diversification() {
+        let d = 1;
+        let data = vec![0.0, 1.0, 2.0, 3.0];
+        let graph = HnswGraph::build(
+            &data,
+            4,
+            d,
+            MetricType::L2,
+            HnswBuildParams {
+                m: 2,
+                ef_construction: 4,
+                max_level: 1,
+            },
+        )
+        .unwrap();
+        let candidates = vec![
+            ScoredNode { id: 1, dist: 1.0 },
+            ScoredNode { id: 2, dist: 2.0 },
+            ScoredNode { id: 3, dist: 3.0 },
+        ];
+
+        let selected = graph.select_neighbors(candidates, 3);
+
+        // The closest candidate is always kept first; the ones rejected by the
+        // diversity check must be backfilled so every slot is used.
+        assert_eq!(selected[0].id, 1);
+        let mut ids: Vec<usize> = selected.iter().map(|node| node.id).collect();
+        ids.sort_unstable();
+        assert_eq!(ids, vec![1, 2, 3]);
+    }
+
+    #[test]
+    fn test_hnsw_greedy_search_chooses_best_improving_neighbor() {
+        let graph = HnswGraph::from_parts(
+            vec![0.0, 5.0, 2.0],
+            3,
+            1,
+            MetricType::L2,
+            vec![0, 0, 0],
+            vec![vec![vec![1, 2]], vec![vec![]], vec![vec![]]],
+            0,
+            0,
+            HnswBuildParams::default(),
+        )
+        .unwrap();
+
+        let (next, dist) = graph.greedy_search_query(&[2.0], 0, 4.0, 0);
+
+        assert_eq!(next, 2);
+        assert_eq!(dist, 0.0);
+    }
+
+    #[test]
+    fn test_hnsw_from_parts_rejects_out_of_range_edge() {
+        let err = HnswGraph::from_parts(
+            vec![0.0, 1.0],
+            2,
+            1,
+            MetricType::L2,
+            vec![0, 0],
+            vec![vec![vec![2]], vec![vec![]]],
+            0,
+            0,
+            HnswBuildParams::default(),
+        )
+        .unwrap_err();
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
+    }
+
+    /// Brute-force top-k ids by L2 distance (NaN-safe, tolerates `k > n`).
+    fn exact_topk(data: &[f32], n: usize, d: usize, query: &[f32], k: usize) -> Vec<usize> {
+        let mut distances: Vec<(f32, usize)> = (0..n)
+            .map(|i| {
+                let vector = &data[i * d..(i + 1) * d];
+                (fvec_distance(query, vector, MetricType::L2), i)
+            })
+            .collect();
+        distances.sort_by(|a, b| a.0.total_cmp(&b.0));
+        distances.iter().take(k).map(|&(_, id)| id).collect()
+    }
+
+    /// `n` rows in `num_clusters` well-separated clusters, with a small
+    /// per-row offset so no two rows are identical.
+    fn generate_clustered_data(n: usize, d: usize, num_clusters: usize) -> Vec<f32> {
+        let mut data = vec![0.0f32; n * d];
+        for i in 0..n {
+            let cluster = i % num_clusters;
+            for j in 0..d {
+                data[i * d + j] = cluster as f32 * 20.0 + j as f32 * 0.01 + i as f32 * 0.0001;
+            }
+        }
+        data
+    }
+}
diff --git a/core/src/ivfflat.rs b/core/src/ivfflat.rs
index 185bdf1..601bb3d 100644
--- a/core/src/ivfflat.rs
+++ b/core/src/ivfflat.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::distance::{fvec_distance, fvec_normalize, MetricType};
+use crate::distance::{fvec_distance, preprocess_vectors, MetricType};
 use crate::ivfpq::RowIdFilter;
 use crate::kmeans::{self, KMeansConfig};
 
@@ -134,13 +134,7 @@ impl IVFFlatIndex {
     }
 
     pub(crate) fn preprocess_vectors(&self, data: &[f32], n: usize) -> 
Vec<f32> {
-        let mut processed = data[..n * self.d].to_vec();
-        if self.metric == MetricType::Cosine {
-            for i in 0..n {
-                fvec_normalize(&mut processed[i * self.d..(i + 1) * self.d]);
-            }
-        }
-        processed
+        preprocess_vectors(data, n, self.d, self.metric)
     }
 }
 
diff --git a/core/src/ivfhnswflat.rs b/core/src/ivfhnswflat.rs
new file mode 100644
index 0000000..e2f7ccf
--- /dev/null
+++ b/core/src/ivfhnswflat.rs
@@ -0,0 +1,355 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{fvec_distance, MetricType};
+use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::ivfflat::IVFFlatIndex;
+use crate::ivfpq::RowIdFilter;
+use crate::kmeans;
+use crate::topk::TopKHeap;
+use std::io;
+
+/// IVF index whose inverted lists can each carry an HNSW graph over their vectors.
+pub struct IVFHNSWFlatIndex {
+    /// Coarse quantizer plus per-list row ids and vectors. Public to match the other
+    /// core index structs; mutating it directly can leave `graphs` stale, so call
+    /// `build_graphs` again before relying on HNSW search.
+    pub flat: IVFFlatIndex,
+    /// One slot per inverted list; `None` until built, and for empty lists.
+    pub graphs: Vec<Option<HnswGraph>>,
+    /// Parameters handed to `HnswGraph::build` by `build_graphs`.
+    pub hnsw_params: HnswBuildParams,
+}
+
+impl IVFHNSWFlatIndex {
+    /// Creates an empty index with `nlist` inverted lists of dimension `d`.
+    ///
+    /// No HNSW graphs exist until `build_graphs` is called.
+    pub fn new(d: usize, nlist: usize, metric: MetricType, hnsw_params: HnswBuildParams) -> Self {
+        IVFHNSWFlatIndex {
+            flat: IVFFlatIndex::new(d, nlist, metric),
+            graphs: vec![None; nlist],
+            hnsw_params,
+        }
+    }
+
+    /// Trains the underlying IVF-Flat quantizer with `n` vectors from `data`.
+    pub fn train(&mut self, data: &[f32], n: usize) {
+        self.flat.train(data, n);
+    }
+
+    /// Adds `n` vectors with row ids `ids` and discards every graph, because the
+    /// inverted lists they were built from have changed.
+    pub fn add(&mut self, data: &[f32], ids: &[i64], n: usize) {
+        self.flat.add(data, ids, n);
+        self.graphs.fill(None);
+    }
+
+    /// Builds one HNSW graph per non-empty inverted list using `hnsw_params`;
+    /// empty lists get `None`.
+    ///
+    /// # Errors
+    /// Propagates the first error returned by `HnswGraph::build`; lists processed
+    /// before the failure keep their newly built graphs.
+    pub fn build_graphs(&mut self) -> io::Result<()> {
+        for list_id in 0..self.flat.nlist {
+            let count = self.flat.ids[list_id].len();
+            self.graphs[list_id] = if count == 0 {
+                None
+            } else {
+                Some(HnswGraph::build(
+                    &self.flat.vectors[list_id],
+                    count,
+                    self.flat.d,
+                    self.flat.metric,
+                    self.hnsw_params,
+                )?)
+            };
+        }
+        Ok(())
+    }
+
+    /// Unfiltered search; equivalent to `search_with_filter` with `filter = None`.
+    #[allow(clippy::too_many_arguments)]
+    pub fn search(
+        &self,
+        queries: &[f32],
+        nq: usize,
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+        result_distances: &mut [f32],
+        result_labels: &mut [i64],
+    ) {
+        self.search_with_filter(
+            queries,
+            nq,
+            k,
+            nprobe,
+            ef_search,
+            None,
+            result_distances,
+            result_labels,
+        );
+    }
+
+    /// Searches `nq` queries, writing `k` results per query (query-major) into
+    /// `result_distances` / `result_labels`.
+    ///
+    /// Each query probes its `nprobe` nearest lists. A list with a graph whose vector
+    /// count still matches the list is searched with HNSW using `max(ef_search, k)`
+    /// candidates, post-filtered by `filter`. Any other list (no graph, or a stale
+    /// graph) is scanned exhaustively. When a filter is present and the probed lists
+    /// hold at most `max(ef_search, k)` accepted rows, every probed list is scanned
+    /// exactly instead. If graph search leaves fewer than `k` filtered results, the
+    /// probed lists are rescanned exactly to backfill. Unfilled slots get
+    /// `f32::MAX` / `-1`.
+    ///
+    /// # Panics
+    /// Panics if an output slice holds fewer than `nq * k` entries.
+    #[allow(clippy::too_many_arguments)]
+    pub fn search_with_filter(
+        &self,
+        queries: &[f32],
+        nq: usize,
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+        filter: Option<&dyn RowIdFilter>,
+        result_distances: &mut [f32],
+        result_labels: &mut [i64],
+    ) {
+        let d = self.flat.d;
+        // HNSW must surface at least k candidates per list to be able to fill the heap.
+        let ef = ef_search.max(k);
+        let processed_queries = self.flat.preprocess_vectors(queries, nq);
+        let (all_probe_indices, _) = kmeans::find_topk_batch(
+            &processed_queries,
+            nq,
+            &self.flat.quantizer_centroids,
+            self.flat.nlist,
+            d,
+            nprobe,
+        );
+
+        for qi in 0..nq {
+            let query = &processed_queries[qi * d..(qi + 1) * d];
+            let probe_indices = &all_probe_indices[qi];
+            let mut heap = TopKHeap::new(k);
+            // With few enough filter matches, an exact scan is both cheaper and more
+            // accurate than post-filtering graph results.
+            let force_flat_scan = filter
+                .map(|f| self.count_filtered(probe_indices, f) <= ef)
+                .unwrap_or(false);
+
+            for &list_id in probe_indices {
+                let graph = match self.usable_graph(list_id) {
+                    Some(graph) if !force_flat_scan => graph,
+                    _ => {
+                        self.scan_flat_list(query, list_id, filter, &mut heap);
+                        continue;
+                    }
+                };
+                for (local_id, dist) in graph.search(query, ef, ef) {
+                    let row_id = self.flat.ids[list_id][local_id];
+                    if let Some(f) = filter {
+                        if !f.contains(row_id) {
+                            continue;
+                        }
+                    }
+                    heap.push(dist, row_id);
+                }
+            }
+            // Post-filtering graph hits can leave the heap short; rescan exactly. A
+            // forced flat scan is already exact, so repeating it would only redo work.
+            if filter.is_some() && !force_flat_scan && heap.len() < k {
+                for &list_id in probe_indices {
+                    self.scan_flat_list(query, list_id, filter, &mut heap);
+                }
+            }
+
+            let sorted = heap.into_sorted();
+            let out_base = qi * k;
+            for (i, &(dist, id)) in sorted.iter().enumerate() {
+                result_distances[out_base + i] = dist;
+                result_labels[out_base + i] = id;
+            }
+            for i in sorted.len()..k {
+                result_distances[out_base + i] = f32::MAX;
+                result_labels[out_base + i] = -1;
+            }
+        }
+    }
+
+    /// Returns the graph for `list_id` only if its vector count still matches the
+    /// list's current vectors. A missing graph, or one left stale by direct mutation
+    /// of `flat`, yields `None` so the caller falls back to an exact scan instead of
+    /// missing rows or indexing past `ids`.
+    fn usable_graph(&self, list_id: usize) -> Option<&HnswGraph> {
+        self.graphs[list_id]
+            .as_ref()
+            .filter(|graph| graph.vectors().len() == self.flat.vectors[list_id].len())
+    }
+
+    /// Number of rows in the given lists that `filter` accepts.
+    fn count_filtered(&self, probe_indices: &[usize], filter: &dyn RowIdFilter) -> usize {
+        probe_indices
+            .iter()
+            .map(|&list_id| {
+                self.flat.ids[list_id]
+                    .iter()
+                    .filter(|&&id| filter.contains(id))
+                    .count()
+            })
+            .sum()
+    }
+
+    /// Exhaustively scores every row of `list_id` accepted by `filter` (all rows when
+    /// `filter` is `None`) and pushes it into `heap`.
+    fn scan_flat_list(
+        &self,
+        query: &[f32],
+        list_id: usize,
+        filter: Option<&dyn RowIdFilter>,
+        heap: &mut TopKHeap,
+    ) {
+        for (local_id, &row_id) in self.flat.ids[list_id].iter().enumerate() {
+            if let Some(f) = filter {
+                if !f.contains(row_id) {
+                    continue;
+                }
+            }
+            let vector =
+                &self.flat.vectors[list_id][local_id * self.flat.d..(local_id + 1) * self.flat.d];
+            heap.push(fvec_distance(query, vector, self.flat.metric), row_id);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::distance::MetricType;
+    use crate::hnsw::HnswBuildParams;
+    use crate::ivfpq::RowIdFilter;
+    use std::collections::HashSet;
+
+    /// `n` rows of dimension 4; row `i` sits in cluster `i % nlist`, clusters 100 apart.
+    fn clustered_rows(n: usize, nlist: usize) -> Vec<f32> {
+        (0..n)
+            .flat_map(|i| {
+                let cluster = (i % nlist) as f32 * 100.0;
+                [cluster + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect()
+    }
+
+    /// `n` two-dimensional rows `(i, 0.0)`, evenly spaced along the x axis.
+    fn line_rows(n: usize) -> Vec<f32> {
+        (0..n).flat_map(|i| [i as f32, 0.0]).collect()
+    }
+
+    /// L2 index over all `n` rows of `data` with row ids `0..n`.
+    fn make_index(
+        data: &[f32],
+        d: usize,
+        nlist: usize,
+        n: usize,
+        with_graphs: bool,
+    ) -> IVFHNSWFlatIndex {
+        let ids: Vec<i64> = (0..n as i64).collect();
+        let mut index =
+            IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(data, n);
+        index.add(data, &ids, n);
+        if with_graphs {
+            index.build_graphs().unwrap();
+        }
+        index
+    }
+
+    /// Runs one query via `search` (unfiltered) or `search_with_filter` and returns
+    /// `(labels, distances)`.
+    fn query_one(
+        index: &IVFHNSWFlatIndex,
+        query: &[f32],
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+        filter: Option<&dyn RowIdFilter>,
+    ) -> (Vec<i64>, Vec<f32>) {
+        let mut labels = vec![0i64; k];
+        let mut distances = vec![0.0f32; k];
+        match filter {
+            Some(f) => index.search_with_filter(
+                query,
+                1,
+                k,
+                nprobe,
+                ef_search,
+                Some(f),
+                &mut distances,
+                &mut labels,
+            ),
+            None => index.search(query, 1, k, nprobe, ef_search, &mut distances, &mut labels),
+        }
+        (labels, distances)
+    }
+
+    #[test]
+    fn test_ivfhnswflat_recalls_query_vector() {
+        let (d, nlist, n) = (4, 4, 128);
+        let data = clustered_rows(n, nlist);
+        let index = make_index(&data, d, nlist, n, true);
+
+        let query_id = 23;
+        let query = &data[query_id * d..(query_id + 1) * d];
+        let (labels, distances) = query_one(&index, query, 5, nlist, 32, None);
+
+        assert_eq!(labels[0], query_id as i64);
+        assert_eq!(distances[0], 0.0);
+    }
+
+    #[test]
+    fn test_ivfhnswflat_without_built_graphs_falls_back_to_flat_scan() {
+        let (d, nlist, n) = (4, 4, 128);
+        let data = clustered_rows(n, nlist);
+        let index = make_index(&data, d, nlist, n, false);
+
+        let query_id = 23;
+        let query = &data[query_id * d..(query_id + 1) * d];
+        let (labels, distances) = query_one(&index, query, 5, nlist, 32, None);
+
+        assert_eq!(labels[0], query_id as i64);
+        assert_eq!(distances[0], 0.0);
+    }
+
+    #[test]
+    fn test_ivfhnswflat_selective_filter_uses_exact_results() {
+        let n = 64;
+        let data = line_rows(n);
+        let index = make_index(&data, 2, 1, n, true);
+        let filter: HashSet<i64> = [63].into_iter().collect();
+
+        let (labels, distances) = query_one(&index, &[0.0, 0.0], 1, 1, 4, Some(&filter));
+
+        assert_eq!(labels[0], 63);
+        assert_eq!(distances[0], 63.0 * 63.0);
+    }
+
+    #[test]
+    fn test_ivfhnswflat_filter_backfills_when_graph_returns_too_few_matches() {
+        let n = 128;
+        let data = line_rows(n);
+        let index = make_index(&data, 2, 1, n, true);
+        let filter: HashSet<i64> = (0..n as i64).filter(|id| id % 2 == 0).collect();
+
+        let (labels, _) = query_one(&index, &[127.0, 0.0], 10, 1, 1, Some(&filter));
+
+        assert_eq!(
+            labels,
+            vec![126, 124, 122, 120, 118, 116, 114, 112, 110, 108]
+        );
+        assert!(labels.iter().all(|id| id % 2 == 0));
+    }
+
+    #[test]
+    fn test_topk_heap_keeps_closest_duplicate_id() {
+        let mut heap = TopKHeap::new(2);
+
+        heap.push(10.0, 7);
+        heap.push(5.0, 8);
+        heap.push(1.0, 7);
+
+        assert_eq!(heap.into_sorted(), vec![(1.0, 7), (5.0, 8)]);
+    }
+}
diff --git a/core/src/ivfhnswflat_io.rs b/core/src/ivfhnswflat_io.rs
new file mode 100644
index 0000000..d504d43
--- /dev/null
+++ b/core/src/ivfhnswflat_io.rs
@@ -0,0 +1,1460 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{fvec_distance, preprocess_vectors, MetricType};
+use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::io::{SeekRead, SeekWrite};
+use crate::ivfhnswflat::IVFHNSWFlatIndex;
+use crate::ivfpq::RowIdFilter;
+use crate::kmeans;
+use crate::topk::TopKHeap;
+use roaring::RoaringTreemap;
+use std::io;
+
+/// Magic number opening every IVF_HNSW_FLAT file; its bytes spell "IHFL" when the
+/// value is read most-significant byte first.
+pub const IVF_HNSW_FLAT_MAGIC: u32 = 0x4948_464C;
+/// Format version written by the writer and required by the reader.
+pub const IVF_HNSW_FLAT_VERSION: u32 = 1;
+/// Size in bytes of the fixed header; the quantizer centroids begin at this offset.
+pub const IVF_HNSW_FLAT_HEADER_SIZE: usize = 64;
+
+/// Serializes `index` to `out` in the IVF_HNSW_FLAT format.
+///
+/// Layout:
+/// - a 64-byte header: magic, version, d, nlist, metric code, total vector count,
+///   sanitized HNSW m / ef_construction / max_level, and 24 reserved bytes;
+/// - the quantizer centroids;
+/// - an offset table of 24 bytes per list: payload offset, vector count, graph byte
+///   length, and 8 reserved bytes;
+/// - for each non-empty list, its row ids, its vectors and its encoded graph.
+///
+/// # Errors
+/// Fails with `InvalidInput` if a non-empty list has no HNSW graph. Also fails if a
+/// dimension, count, offset or size overflows the on-disk integer types.
+/// `validate_index_shape` and I/O errors from `out` are propagated.
+pub fn write_ivfhnswflat_index(
+    index: &IVFHNSWFlatIndex,
+    out: &mut dyn SeekWrite,
+) -> io::Result<()> {
+    validate_index_shape(index)?;
+    let d = index.flat.d;
+    let nlist = index.flat.nlist;
+    let total_vectors = index.flat.ids.iter().try_fold(0i64, |sum, ids| {
+        let count = usize_to_i64(ids.len(), "total vector count")?;
+        sum.checked_add(count).ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "total vector count exceeds i64 length limit",
+            )
+        })
+    })?;
+    let graph_bytes: Vec<Vec<u8>> = (0..nlist)
+        .map(|list_id| {
+            let count = index.flat.ids[list_id].len();
+            if count == 0 {
+                return Ok(Vec::new());
+            }
+            // The reader rejects a non-empty list whose graph length is 0, so refuse to
+            // write a file that could never be opened.
+            let graph = index.graphs[list_id].as_ref().ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!(
+                        "list {} has {} vectors but no HNSW graph; call build_graphs before writing",
+                        list_id, count
+                    ),
+                )
+            })?;
+            encode_graph(Some(graph))
+        })
+        .collect::<io::Result<_>>()?;
+
+    // Header: field order must stay in sync with IVFHNSWFlatIndexReader::open.
+    write_u32_le(out, IVF_HNSW_FLAT_MAGIC)?;
+    write_u32_le(out, IVF_HNSW_FLAT_VERSION)?;
+    write_i32_le(out, usize_to_i32(d, "dimension")?)?;
+    write_i32_le(out, usize_to_i32(nlist, "nlist")?)?;
+    write_u32_le(out, index.flat.metric as u32)?;
+    write_i64_le(out, total_vectors)?;
+    let params = index.hnsw_params.sanitized();
+    write_i32_le(out, usize_to_i32(params.m, "hnsw m")?)?;
+    write_i32_le(
+        out,
+        usize_to_i32(params.ef_construction, "hnsw ef_construction")?,
+    )?;
+    write_i32_le(out, usize_to_i32(params.max_level, "hnsw max_level")?)?;
+    out.write_all(&[0u8; 24])?;
+
+    write_f32_slice(out, &index.flat.quantizer_centroids)?;
+
+    // Offset table: 8 (offset) + 4 (count) + 4 (graph bytes) + 8 (reserved) per list.
+    let offset_table_size = nlist.checked_mul(24).ok_or_else(|| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "IVF-HNSW-FLAT offset table size overflow",
+        )
+    })?;
+    let data_start = out
+        .pos()
+        .checked_add(offset_table_size as u64)
+        .ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "IVF-HNSW-FLAT data start offset overflow",
+            )
+        })?;
+    let mut list_offsets = vec![0i64; nlist];
+    let mut list_counts = vec![0i32; nlist];
+    let mut list_graph_bytes_lens = vec![0i32; nlist];
+    let mut current_offset = data_start;
+
+    for list_id in 0..nlist {
+        list_offsets[list_id] = u64_to_i64(current_offset, "list offset")?;
+        let count = index.flat.ids[list_id].len();
+        list_counts[list_id] = usize_to_i32(count, "list count")?;
+        list_graph_bytes_lens[list_id] = usize_to_i32(graph_bytes[list_id].len(), "graph bytes")?;
+        if count > 0 {
+            let payload_len = list_payload_len(count, d, graph_bytes[list_id].len())?;
+            current_offset = current_offset
+                .checked_add(payload_len as u64)
+                .ok_or_else(|| {
+                    io::Error::new(io::ErrorKind::InvalidInput, "IVF-HNSW-FLAT offset overflow")
+                })?;
+        }
+    }
+
+    for list_id in 0..nlist {
+        write_i64_le(out, list_offsets[list_id])?;
+        write_i32_le(out, list_counts[list_id])?;
+        write_i32_le(out, list_graph_bytes_lens[list_id])?;
+        write_i64_le(out, 0)?;
+    }
+
+    // Payloads, in list order, exactly where the offset table says they start.
+    for list_id in 0..nlist {
+        if index.flat.ids[list_id].is_empty() {
+            continue;
+        }
+        for &id in &index.flat.ids[list_id] {
+            write_i64_le(out, id)?;
+        }
+        write_f32_slice(out, &index.flat.vectors[list_id])?;
+        out.write_all(&graph_bytes[list_id])?;
+    }
+
+    Ok(())
+}
+
+/// Reader for IVF_HNSW_FLAT files that loads data lazily.
+///
+/// `open` parses only the fixed header. The centroids and per-list offset table are
+/// read on first use by `ensure_loaded`. List payloads are read on demand per search.
+pub struct IVFHNSWFlatIndexReader<R: SeekRead> {
+    /// Underlying byte source.
+    reader: R,
+    /// Vector dimension from the header.
+    pub d: usize,
+    /// Number of inverted lists from the header.
+    pub nlist: usize,
+    /// Distance metric decoded from the header's metric code.
+    pub metric: MetricType,
+    /// Total vector count recorded in the header.
+    pub total_vectors: i64,
+    /// HNSW parameters recorded in the header, sanitized on read.
+    pub hnsw_params: HnswBuildParams,
+    /// `nlist * d` coarse centroids; empty until `ensure_loaded`.
+    pub quantizer_centroids: Vec<f32>,
+    /// Byte offset of each list's payload; empty until `ensure_loaded`.
+    pub list_offsets: Vec<i64>,
+    /// Vector count of each list; empty until `ensure_loaded`.
+    pub list_counts: Vec<i32>,
+    /// Encoded graph length in bytes of each list; empty until `ensure_loaded`.
+    pub list_graph_bytes_lens: Vec<i32>,
+    /// Set once the centroids and offset table have been read.
+    loaded: bool,
+}
+
+impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
+    /// Opens `reader` and parses the fixed 64-byte header.
+    ///
+    /// The centroids and offset table are read later by `ensure_loaded`.
+    ///
+    /// # Errors
+    /// Fails with `InvalidData` for a wrong magic, an unsupported version or an
+    /// unknown metric code. Errors from header-field validation and from `reader` are
+    /// propagated.
+    pub fn open(mut reader: R) -> io::Result<Self> {
+        reader.seek(0)?;
+
+        let magic = read_u32_le(&mut reader)?;
+        if magic != IVF_HNSW_FLAT_MAGIC {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("Invalid IVF_HNSW_FLAT magic: 0x{:08X}", magic),
+            ));
+        }
+        let version = read_u32_le(&mut reader)?;
+        if version != IVF_HNSW_FLAT_VERSION {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("Unsupported IVF_HNSW_FLAT version: {}", version),
+            ));
+        }
+
+        // Field order must mirror write_ivfhnswflat_index exactly.
+        let d = validate_positive_i32(read_i32_le(&mut reader)?, "d")? as usize;
+        let nlist = validate_positive_i32(read_i32_le(&mut reader)?, "nlist")? as usize;
+        let metric_code = read_u32_le(&mut reader)?;
+        let metric = MetricType::from_code(metric_code).ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("Unknown metric type: {}", metric_code),
+            )
+        })?;
+        let total_vectors = read_i64_le(&mut reader)?;
+        let hnsw_params = HnswBuildParams {
+            m: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw m")? as usize,
+            ef_construction: validate_positive_i32(
+                read_i32_le(&mut reader)?,
+                "hnsw ef_construction",
+            )? as usize,
+            max_level: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw max_level")? as usize,
+        }
+        .sanitized();
+        // 24 reserved bytes pad the header to IVF_HNSW_FLAT_HEADER_SIZE.
+        let mut reserved = [0u8; 24];
+        reader.read_exact(&mut reserved)?;
+
+        Ok(Self {
+            reader,
+            d,
+            nlist,
+            metric,
+            total_vectors,
+            hnsw_params,
+            quantizer_centroids: Vec::new(),
+            list_offsets: Vec::new(),
+            list_counts: Vec::new(),
+            list_graph_bytes_lens: Vec::new(),
+            loaded: false,
+        })
+    }
+
+    /// Reads the quantizer centroids and the per-list offset table on the first call.
+    /// Later calls return immediately.
+    ///
+    /// # Errors
+    /// Fails with `InvalidData` if a list count or graph length is negative. I/O and
+    /// size-overflow errors are propagated.
+    pub fn ensure_loaded(&mut self) -> io::Result<()> {
+        if self.loaded {
+            return Ok(());
+        }
+
+        self.reader.seek(IVF_HNSW_FLAT_HEADER_SIZE as u64)?;
+        self.quantizer_centroids =
+            read_f32_vec(&mut self.reader, checked_section_size(self.nlist, self.d)?)?;
+        self.list_offsets = vec![0; self.nlist];
+        self.list_counts = vec![0; self.nlist];
+        self.list_graph_bytes_lens = vec![0; self.nlist];
+        // Each table entry is 24 bytes: offset (i64), count (i32), graph length (i32),
+        // reserved (i64).
+        for list_id in 0..self.nlist {
+            self.list_offsets[list_id] = read_i64_le(&mut self.reader)?;
+            let count = read_i32_le(&mut self.reader)?;
+            if count < 0 {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    format!("negative list count {} at list {}", count, list_id),
+                ));
+            }
+            self.list_counts[list_id] = count;
+            let graph_bytes_len = read_i32_le(&mut self.reader)?;
+            if graph_bytes_len < 0 {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    format!(
+                        "negative graph_bytes_len {} at list {}",
+                        graph_bytes_len, list_id
+                    ),
+                ));
+            }
+            self.list_graph_bytes_lens[list_id] = graph_bytes_len;
+            let _reserved = read_i64_le(&mut self.reader)?;
+        }
+
+        self.loaded = true;
+        Ok(())
+    }
+
+    /// Returns the row ids, vectors and decoded graph of `list_id`.
+    ///
+    /// An empty list yields empty vectors and `None`.
+    pub fn read_inverted_list(
+        &mut self,
+        list_id: usize,
+    ) -> io::Result<(Vec<i64>, Vec<f32>, Option<HnswGraph>)> {
+        let Some(list) = self.read_graph_list(list_id)? else {
+            return Ok((Vec::new(), Vec::new(), None));
+        };
+        let vectors = list.graph.vectors().to_vec();
+        Ok((list.ids, vectors, Some(list.graph)))
+    }
+
+    /// Reads and decodes one list payload (ids, vectors, graph).
+    ///
+    /// Returns `None` for an empty list. Fails with `InvalidInput` for an out-of-range
+    /// `list_id` and with `InvalidData` when a non-empty list has no graph.
+    fn read_graph_list(&mut self, list_id: usize) -> io::Result<Option<GraphList>> {
+        self.ensure_loaded()?;
+        if list_id >= self.nlist {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!("list_id {} out of range (nlist={})", list_id, self.nlist),
+            ));
+        }
+        let count = self.list_counts[list_id] as usize;
+        if count == 0 {
+            return Ok(None);
+        }
+
+        let offset = checked_list_offset(self.list_offsets[list_id], list_id)?;
+        let graph_bytes_len = self.list_graph_bytes_lens[list_id] as usize;
+        if graph_bytes_len == 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("list {} is missing HNSW graph", list_id),
+            ));
+        }
+        let payload_len = list_payload_len(count, self.d, graph_bytes_len)?;
+        let mut payload = vec![0u8; payload_len];
+        self.reader.pread(offset, &mut payload)?;
+
+        // Payload layout: count i64 ids, then count * d f32 values, then graph bytes.
+        let ids_bytes_len = checked_list_bytes(count, 8)?;
+        let vector_bytes_len = checked_list_bytes(
+            count,
+            self.d.checked_mul(4).ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "IVF-HNSW-FLAT bytes per vector overflow",
+                )
+            })?,
+        )?;
+        let ids = payload[..ids_bytes_len]
+            .chunks_exact(8)
+            .map(|c| i64::from_le_bytes(c.try_into().unwrap()))
+            .collect();
+        let vectors = bytes_to_f32_vec(&payload[ids_bytes_len..ids_bytes_len + vector_bytes_len])?;
+        // The graph takes ownership of the vectors; callers reach them through
+        // `graph.vectors()`, so no copy is kept here.
+        let graph = decode_graph(
+            &payload[ids_bytes_len + vector_bytes_len..],
+            vectors,
+            count,
+            self.d,
+            self.metric,
+            self.hnsw_params,
+        )?;
+        let graph = graph.ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("list {} is missing HNSW graph", list_id),
+            )
+        })?;
+        Ok(Some(GraphList { ids, graph }))
+    }
+
+    /// Unfiltered single-query search; see `search_with_filter`.
+    pub fn search(
+        &mut self,
+        query: &[f32],
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+    ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        self.search_with_filter(query, k, nprobe, ef_search, None)
+    }
+
+    /// Single-query search returning exactly `k` labels and distances.
+    ///
+    /// Unfilled slots hold `-1` / `f32::MAX`.
+    ///
+    /// The `nprobe` nearest lists are searched with HNSW using `max(ef_search, k)`
+    /// candidates, post-filtered by `filter`. When the probed lists hold at most that
+    /// many accepted rows, they are scanned exactly instead. If filtered graph search
+    /// leaves fewer than `k` results, the probed lists are rescanned exactly.
+    ///
+    /// # Errors
+    /// Fails with `InvalidInput` if `query.len() != d`, `k == 0` or `nprobe == 0`.
+    /// I/O and decoding errors are propagated.
+    pub fn search_with_filter(
+        &mut self,
+        query: &[f32],
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+        filter: Option<&dyn RowIdFilter>,
+    ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        self.ensure_loaded()?;
+        if query.len() != self.d {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!(
+                    "query length {} does not match index dimension {}",
+                    query.len(),
+                    self.d
+                ),
+            ));
+        }
+        if k == 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "k must be greater than 0",
+            ));
+        }
+        if nprobe == 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "nprobe must be greater than 0",
+            ));
+        }
+
+        // HNSW must surface at least k candidates per list to be able to fill the heap.
+        let ef = ef_search.max(k);
+        let q = preprocess_vectors(query, 1, self.d, self.metric);
+        let (probe_indices, _) =
+            kmeans::find_topk(&q, &self.quantizer_centroids, self.nlist, self.d, nprobe);
+        let mut loaded_lists = Vec::with_capacity(probe_indices.len());
+        for list_id in probe_indices {
+            if let Some(list) = self.read_graph_list(list_id)? {
+                loaded_lists.push(list);
+            }
+        }
+        let mut heap = TopKHeap::new(k);
+        // With few enough filter matches, an exact scan beats post-filtering graph hits.
+        let force_flat_scan = filter
+            .map(|f| count_filtered(&loaded_lists, f) <= ef)
+            .unwrap_or(false);
+
+        for list in &loaded_lists {
+            if force_flat_scan {
+                scan_flat_list(
+                    &q,
+                    &list.ids,
+                    list.graph.vectors(),
+                    self.d,
+                    self.metric,
+                    filter,
+                    &mut heap,
+                );
+            } else {
+                let local_results = list.graph.search(&q, ef, ef);
+                for (local_id, dist) in local_results {
+                    let row_id = list.ids[local_id];
+                    if filter.map(|f| f.contains(row_id)).unwrap_or(true) {
+                        heap.push(dist, row_id);
+                    }
+                }
+            }
+        }
+        // Post-filtering can leave the heap short; rescan exactly. A forced flat scan is
+        // already exact, so repeating it would only duplicate work.
+        if filter.is_some() && !force_flat_scan && heap.len() < k {
+            for list in &loaded_lists {
+                scan_flat_list(
+                    &q,
+                    &list.ids,
+                    list.graph.vectors(),
+                    self.d,
+                    self.metric,
+                    filter,
+                    &mut heap,
+                );
+            }
+        }
+
+        let sorted = heap.into_sorted();
+        let mut labels: Vec<i64> = sorted.iter().map(|&(_, id)| id).collect();
+        let mut distances: Vec<f32> = sorted.iter().map(|&(dist, _)| dist).collect();
+        labels.resize(k, -1);
+        distances.resize(k, f32::MAX);
+        Ok((labels, distances))
+    }
+
+    /// Single-query search restricted to the row ids in `roaring_filter_bytes`.
+    ///
+    /// The bytes are decoded by `decode_roaring_filter`, then passed to
+    /// `search_with_filter`.
+    pub fn search_with_roaring_filter(
+        &mut self,
+        query: &[f32],
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+        roaring_filter_bytes: &[u8],
+    ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        let filter = decode_roaring_filter(roaring_filter_bytes)?;
+        self.search_with_filter(query, k, nprobe, ef_search, Some(&filter))
+    }
+}
+
+/// Batch search over `nq` row-major queries with no row filter.
+///
+/// Returns `nq * k` labels and distances in query-major order; delegates to
+/// `search_batch_ivfhnswflat_reader_filter`.
+pub fn search_batch_ivfhnswflat_reader<R: SeekRead>(
+    reader: &mut IVFHNSWFlatIndexReader<R>,
+    queries: &[f32],
+    nq: usize,
+    k: usize,
+    nprobe: usize,
+    ef_search: usize,
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+    let no_filter: Option<&dyn RowIdFilter> = None;
+    search_batch_ivfhnswflat_reader_filter(reader, queries, nq, k, nprobe, ef_search, no_filter)
+}
+
+/// Batch search over `nq` row-major queries, optionally restricted to rows accepted
+/// by `filter`.
+///
+/// Queries are grouped by probed list, so each list is read and decoded from `reader`
+/// only once. For each query, its lists are searched with HNSW using
+/// `max(ef_search, k)` candidates, post-filtered by `filter`. When a filter is present
+/// and the query's probed lists hold at most that many accepted rows, they are scanned
+/// exactly instead. If filtered graph search leaves a query with fewer than `k`
+/// results, all of that query's probed lists are rescanned exactly, matching the
+/// single-query `search_with_filter`.
+///
+/// Returns `nq * k` labels and distances in query-major order; unfilled slots hold
+/// `-1` / `f32::MAX`.
+///
+/// # Errors
+/// Fails on invalid inputs (see `validate_search_inputs`) and with `InvalidData` if a
+/// filtered-count sum overflows `usize`. I/O and decoding errors are propagated.
+pub fn search_batch_ivfhnswflat_reader_filter<R: SeekRead>(
+    reader: &mut IVFHNSWFlatIndexReader<R>,
+    queries: &[f32],
+    nq: usize,
+    k: usize,
+    nprobe: usize,
+    ef_search: usize,
+    filter: Option<&dyn RowIdFilter>,
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+    reader.ensure_loaded()?;
+    validate_search_inputs(queries, nq, reader.d, k, nprobe)?;
+
+    let d = reader.d;
+    // HNSW must surface at least k candidates per list to be able to fill a heap.
+    let ef = ef_search.max(k);
+    let processed = preprocess_vectors(queries, nq, d, reader.metric);
+    let (all_probe_indices, _) = kmeans::find_topk_batch(
+        &processed,
+        nq,
+        &reader.quantizer_centroids,
+        reader.nlist,
+        d,
+        nprobe,
+    );
+
+    // Invert the probe sets so each list is loaded once for all queries probing it.
+    let mut list_to_queries = vec![Vec::new(); reader.nlist];
+    let mut unique_lists = Vec::new();
+    for (qi, probe_indices) in all_probe_indices.iter().enumerate() {
+        for &list_id in probe_indices {
+            if list_to_queries[list_id].is_empty() {
+                unique_lists.push(list_id);
+            }
+            list_to_queries[list_id].push(qi);
+        }
+    }
+
+    let mut heaps: Vec<TopKHeap> = (0..nq).map(|_| TopKHeap::new(k)).collect();
+    let mut query_filtered_counts = vec![0usize; nq];
+    let mut loaded_lists = Vec::with_capacity(unique_lists.len());
+    for list_id in unique_lists {
+        let Some(list) = reader.read_graph_list(list_id)? else {
+            continue;
+        };
+        if let Some(f) = filter {
+            let filtered = list.ids.iter().filter(|&&id| f.contains(id)).count();
+            for &qi in &list_to_queries[list_id] {
+                query_filtered_counts[qi] = query_filtered_counts[qi]
+                    .checked_add(filtered)
+                    .ok_or_else(|| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidData,
+                            "filtered vector count overflows usize",
+                        )
+                    })?;
+            }
+        }
+        loaded_lists.push(LoadedBatchList {
+            query_ids: std::mem::take(&mut list_to_queries[list_id]),
+            ids: list.ids,
+            graph: list.graph,
+        });
+    }
+
+    // Per query: with few enough filter matches, an exact scan beats post-filtering.
+    let force_flat_scan: Vec<bool> = query_filtered_counts
+        .iter()
+        .map(|&count| filter.is_some() && count <= ef)
+        .collect();
+
+    for list in &loaded_lists {
+        for &qi in &list.query_ids {
+            let query = &processed[qi * d..(qi + 1) * d];
+            if force_flat_scan[qi] {
+                scan_flat_list(
+                    query,
+                    &list.ids,
+                    list.graph.vectors(),
+                    d,
+                    reader.metric,
+                    filter,
+                    &mut heaps[qi],
+                );
+            } else {
+                let local_results = list.graph.search(query, ef, ef);
+                for (local_id, dist) in local_results {
+                    let row_id = list.ids[local_id];
+                    if filter.map(|f| f.contains(row_id)).unwrap_or(true) {
+                        heaps[qi].push(dist, row_id);
+                    }
+                }
+            }
+        }
+    }
+    if filter.is_some() {
+        // Decide once per query, before rescanning. Re-checking the heap size per list
+        // would stop as soon as one list filled the heap, skipping the query's other
+        // lists, which may hold closer accepted rows. Forced exact scans are complete.
+        let needs_backfill: Vec<bool> = heaps
+            .iter()
+            .zip(&force_flat_scan)
+            .map(|(heap, &forced)| !forced && heap.len() < k)
+            .collect();
+        for list in &loaded_lists {
+            for &qi in &list.query_ids {
+                if !needs_backfill[qi] {
+                    continue;
+                }
+                let query = &processed[qi * d..(qi + 1) * d];
+                scan_flat_list(
+                    query,
+                    &list.ids,
+                    list.graph.vectors(),
+                    d,
+                    reader.metric,
+                    filter,
+                    &mut heaps[qi],
+                );
+            }
+        }
+    }
+
+    let mut result_ids = vec![-1i64; nq * k];
+    let mut result_dists = vec![f32::MAX; nq * k];
+    for (qi, heap) in heaps.into_iter().enumerate() {
+        let base = qi * k;
+        for (i, (dist, id)) in heap.into_sorted().into_iter().enumerate() {
+            result_ids[base + i] = id;
+            result_dists[base + i] = dist;
+        }
+    }
+
+    Ok((result_ids, result_dists))
+}
+
+/// Batch search restricted to the row ids encoded in `roaring_filter_bytes`.
+///
+/// The bytes are decoded by `decode_roaring_filter`, then passed to
+/// `search_batch_ivfhnswflat_reader_filter`.
+pub fn search_batch_ivfhnswflat_reader_roaring_filter<R: SeekRead>(
+    reader: &mut IVFHNSWFlatIndexReader<R>,
+    queries: &[f32],
+    nq: usize,
+    k: usize,
+    nprobe: usize,
+    ef_search: usize,
+    roaring_filter_bytes: &[u8],
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+    let decoded = decode_roaring_filter(roaring_filter_bytes)?;
+    let row_filter: &dyn RowIdFilter = &decoded;
+    search_batch_ivfhnswflat_reader_filter(
+        reader,
+        queries,
+        nq,
+        k,
+        nprobe,
+        ef_search,
+        Some(row_filter),
+    )
+}
+
+/// One decoded inverted list: row ids plus the HNSW graph holding its vectors.
+struct GraphList {
+    /// Row id of each vector, indexed by graph-local id.
+    ids: Vec<i64>,
+    /// Graph over the list's vectors, which are exposed via `graph.vectors()`.
+    graph: HnswGraph,
+}
+
+/// A decoded inverted list during batch search, tagged with the queries probing it.
+struct LoadedBatchList {
+    /// Indices of the batch queries whose probe set includes this list.
+    query_ids: Vec<usize>,
+    /// Row id of each vector, indexed by graph-local id.
+    ids: Vec<i64>,
+    /// Graph over the list's vectors.
+    graph: HnswGraph,
+}
+
+/// Total number of rows across `lists` that `filter` accepts.
+fn count_filtered(lists: &[GraphList], filter: &dyn RowIdFilter) -> usize {
+    let mut total = 0;
+    for list in lists {
+        total += list.ids.iter().filter(|&&id| filter.contains(id)).count();
+    }
+    total
+}
+
+/// Exhaustively scores every vector of one list against `query` and pushes
+/// each `(distance, row_id)` into `heap`.
+///
+/// `vectors` holds the list's vectors back to back, `d` floats each, in the
+/// same order as `ids`. Rows rejected by `filter` are skipped; with no
+/// filter every row is scored.
+fn scan_flat_list(
+    query: &[f32],
+    ids: &[i64],
+    vectors: &[f32],
+    d: usize,
+    metric: MetricType,
+    filter: Option<&dyn RowIdFilter>,
+    heap: &mut TopKHeap,
+) {
+    let accepted = |row_id: i64| match filter {
+        Some(f) => f.contains(row_id),
+        None => true,
+    };
+    for (local_id, &row_id) in ids.iter().enumerate() {
+        if !accepted(row_id) {
+            continue;
+        }
+        let start = local_id * d;
+        let distance = fvec_distance(query, &vectors[start..start + d], metric);
+        heap.push(distance, row_id);
+    }
+}
+
+/// Validates batch-search arguments.
+///
+/// `queries` must hold exactly `nq` vectors of `d` floats each.
+///
+/// # Errors
+/// `InvalidInput`, checked in this order: `nq == 0`, `nq * d` overflowing
+/// `usize`, `queries.len() != nq * d`, `k == 0`, `nprobe == 0`.
+fn validate_search_inputs(
+    queries: &[f32],
+    nq: usize,
+    d: usize,
+    k: usize,
+    nprobe: usize,
+) -> io::Result<()> {
+    fn invalid(message: String) -> io::Error {
+        io::Error::new(io::ErrorKind::InvalidInput, message)
+    }
+
+    if nq == 0 {
+        return Err(invalid("nq must be greater than 0".to_owned()));
+    }
+    let Some(expected_query_len) = nq.checked_mul(d) else {
+        return Err(invalid("nq * dimension overflows usize".to_owned()));
+    };
+    if queries.len() != expected_query_len {
+        return Err(invalid(format!(
+            "queries length {} does not match nq * dimension {}",
+            queries.len(),
+            expected_query_len
+        )));
+    }
+    match (k, nprobe) {
+        (0, _) => Err(invalid("k must be greater than 0".to_owned())),
+        (_, 0) => Err(invalid("nprobe must be greater than 0".to_owned())),
+        _ => Ok(()),
+    }
+}
+
+/// Serializes an optional HNSW graph into a per-list graph section.
+///
+/// `None` produces an empty section. Otherwise every field is written as a
+/// little-endian `u32` in this order: node count, entry point, max observed
+/// level, each node's level, then for every node and each of its levels the
+/// neighbor count followed by the neighbor ids.
+///
+/// # Errors
+/// `InvalidInput` if any value does not fit in a `u32`.
+fn encode_graph(graph: Option<&HnswGraph>) -> io::Result<Vec<u8>> {
+    let mut section = Vec::new();
+    if let Some(graph) = graph {
+        let header = [graph.len(), graph.entry_point(), graph.max_observed_level()];
+        for value in header {
+            write_u32_vec(&mut section, value)?;
+        }
+        for &level in graph.levels() {
+            write_u32_vec(&mut section, level)?;
+        }
+        for node_levels in graph.neighbors() {
+            for level_neighbors in node_levels {
+                write_u32_vec(&mut section, level_neighbors.len())?;
+                for &neighbor in level_neighbors {
+                    write_u32_vec(&mut section, neighbor)?;
+                }
+            }
+        }
+    }
+    Ok(section)
+}
+
+/// Decodes one list's HNSW graph section as produced by `encode_graph`.
+///
+/// `vectors` must hold the list's `count` vectors of dimension `d`; they are
+/// moved into the returned graph. An empty `bytes` slice means the list has
+/// no graph and decodes to `Ok(None)`.
+///
+/// Every node level and neighbor count is checked against `hnsw_params`
+/// before the corresponding `Vec` is allocated, so a corrupt section cannot
+/// request unbounded memory. Levels must be `< max_level`, level-0 degrees
+/// at most `2 * m`, and higher-level degrees at most `m`.
+///
+/// # Errors
+/// `InvalidData` if the section is truncated, has trailing bytes, declares a
+/// node count different from `count`, or violates the bounds above. Errors
+/// from `HnswGraph::from_parts` are propagated unchanged.
+// NOTE(review): `entry_point` and neighbor ids are not range-checked here;
+// presumably `HnswGraph::from_parts` validates them — confirm.
+fn decode_graph(
+    bytes: &[u8],
+    vectors: Vec<f32>,
+    count: usize,
+    d: usize,
+    metric: MetricType,
+    hnsw_params: HnswBuildParams,
+) -> io::Result<Option<HnswGraph>> {
+    if bytes.is_empty() {
+        return Ok(None);
+    }
+    let mut pos = 0usize;
+    let graph_count = read_u32_vec(bytes, &mut pos)? as usize;
+    if graph_count != count {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!(
+                "graph count {} does not match list count {}",
+                graph_count, count
+            ),
+        ));
+    }
+    let entry_point = read_u32_vec(bytes, &mut pos)? as usize;
+    let max_observed_level = read_u32_vec(bytes, &mut pos)? as usize;
+    // Highest level a node may occupy; only used in the error message below.
+    // `saturating_sub` avoids underflow (a panic in debug builds) when
+    // `max_level` is 0, in which case every level is rejected.
+    let highest_allowed_level = hnsw_params.max_level.saturating_sub(1);
+    let mut levels = Vec::with_capacity(count);
+    for node in 0..count {
+        let level = read_u32_vec(bytes, &mut pos)? as usize;
+        if level >= hnsw_params.max_level {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!(
+                    "graph node {} level {} exceeds max_level {}",
+                    node, level, highest_allowed_level
+                ),
+            ));
+        }
+        levels.push(level);
+    }
+    let mut neighbors = Vec::with_capacity(count);
+    for (node, &level) in levels.iter().enumerate() {
+        // `level` was bounded above, so this allocation is bounded too.
+        let mut node_levels = Vec::with_capacity(level + 1);
+        for graph_level in 0..=level {
+            let degree = read_u32_vec(bytes, &mut pos)? as usize;
+            // Layer 0 allows twice as many links as the upper layers.
+            let max_degree = if graph_level == 0 {
+                hnsw_params.m.saturating_mul(2)
+            } else {
+                hnsw_params.m
+            };
+            if degree > max_degree {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    format!(
+                        "graph node {} degree {} at level {} exceeds max degree {}",
+                        node, degree, graph_level, max_degree
+                    ),
+                ));
+            }
+            let mut level_neighbors = Vec::with_capacity(degree);
+            for _ in 0..degree {
+                level_neighbors.push(read_u32_vec(bytes, &mut pos)? as usize);
+            }
+            node_levels.push(level_neighbors);
+        }
+        neighbors.push(node_levels);
+    }
+    if pos != bytes.len() {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "trailing bytes in HNSW graph section",
+        ));
+    }
+    Ok(Some(HnswGraph::from_parts(
+        vectors,
+        count,
+        d,
+        metric,
+        levels,
+        neighbors,
+        entry_point,
+        max_observed_level,
+        hnsw_params,
+    )?))
+}
+
+/// Appends `value` to `buf` as a little-endian `u32`.
+///
+/// # Errors
+/// `InvalidInput` if `value` exceeds `u32::MAX`; `buf` is not modified then.
+fn write_u32_vec(buf: &mut Vec<u8>, value: usize) -> io::Result<()> {
+    match value {
+        v if v <= u32::MAX as usize => {
+            buf.extend((v as u32).to_le_bytes());
+            Ok(())
+        }
+        v => Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("value {} exceeds u32 limit", v),
+        )),
+    }
+}
+
+/// Reads the little-endian `u32` stored at `*pos` in `bytes` and advances
+/// `*pos` by 4.
+///
+/// # Errors
+/// `InvalidData` if `*pos + 4` overflows or exceeds `bytes.len()`; `*pos` is
+/// left unchanged on error.
+fn read_u32_vec(bytes: &[u8], pos: &mut usize) -> io::Result<u32> {
+    let start = *pos;
+    let Some(end) = start.checked_add(4) else {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "graph section position overflow",
+        ));
+    };
+    let Some(word) = bytes.get(start..end) else {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "truncated HNSW graph section",
+        ));
+    };
+    let mut raw = [0u8; 4];
+    raw.copy_from_slice(word);
+    *pos = end;
+    Ok(u32::from_le_bytes(raw))
+}
+
+/// Writes `v` to `out` as its 4 little-endian bytes.
+fn write_u32_le(out: &mut dyn SeekWrite, v: u32) -> io::Result<()> {
+    let encoded: [u8; 4] = v.to_le_bytes();
+    out.write_all(&encoded)
+}
+
+/// Writes `v` to `out` as its 4 little-endian two's-complement bytes.
+fn write_i32_le(out: &mut dyn SeekWrite, v: i32) -> io::Result<()> {
+    let encoded = v.to_le_bytes();
+    out.write_all(&encoded[..])
+}
+
+/// Writes `v` to `out` as its 8 little-endian two's-complement bytes.
+fn write_i64_le(out: &mut dyn SeekWrite, v: i64) -> io::Result<()> {
+    let encoded: [u8; 8] = v.to_le_bytes();
+    out.write_all(&encoded[..])
+}
+
+/// Writes every value of `data` to `out` as a little-endian `f32`.
+///
+/// The encoded bytes are buffered first so that `write_all` is called once.
+fn write_f32_slice(out: &mut dyn SeekWrite, data: &[f32]) -> io::Result<()> {
+    let mut encoded: Vec<u8> = Vec::with_capacity(data.len() * 4);
+    for value in data {
+        encoded.extend_from_slice(&value.to_le_bytes());
+    }
+    out.write_all(&encoded)
+}
+
+/// Reads 4 bytes from `reader` and decodes them as a little-endian `u32`.
+fn read_u32_le(reader: &mut dyn SeekRead) -> io::Result<u32> {
+    let mut raw = [0u8; 4];
+    reader.read_exact(&mut raw).map(|()| u32::from_le_bytes(raw))
+}
+
+/// Reads 4 bytes from `reader` and decodes them as a little-endian `i32`.
+fn read_i32_le(reader: &mut dyn SeekRead) -> io::Result<i32> {
+    let mut word = [0u8; 4];
+    reader.read_exact(&mut word).map(|()| i32::from_le_bytes(word))
+}
+
+/// Reads 8 bytes from `reader` and decodes them as a little-endian `i64`.
+fn read_i64_le(reader: &mut dyn SeekRead) -> io::Result<i64> {
+    let mut raw = [0u8; 8];
+    reader.read_exact(&mut raw).map(|()| i64::from_le_bytes(raw))
+}
+
+/// Reads `count` little-endian `f32` values from `reader`.
+///
+/// # Errors
+/// `InvalidData` if `count * 4` overflows `usize`; errors from
+/// `reader.read_exact` are propagated.
+fn read_f32_vec(reader: &mut dyn SeekRead, count: usize) -> io::Result<Vec<f32>> {
+    let Some(byte_len) = count.checked_mul(4) else {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "f32 section byte length overflow",
+        ));
+    };
+    let mut raw = vec![0u8; byte_len];
+    reader.read_exact(&mut raw)?;
+    bytes_to_f32_vec(&raw)
+}
+
+/// Decodes `bytes` as a packed array of little-endian `f32` values.
+///
+/// # Errors
+/// `InvalidData` if `bytes.len()` is not a multiple of 4.
+fn bytes_to_f32_vec(bytes: &[u8]) -> io::Result<Vec<f32>> {
+    let words = bytes.chunks_exact(4);
+    if !words.remainder().is_empty() {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "f32 byte section is not 4-byte aligned",
+        ));
+    }
+    let decoded: Vec<f32> = words
+        .map(|word| {
+            let mut raw = [0u8; 4];
+            raw.copy_from_slice(word);
+            f32::from_le_bytes(raw)
+        })
+        .collect();
+    Ok(decoded)
+}
+
+/// Returns `val` unchanged when it is strictly positive.
+///
+/// # Errors
+/// `InvalidData` naming the header `field` when `val <= 0`.
+fn validate_positive_i32(val: i32, field: &str) -> io::Result<i32> {
+    if val > 0 {
+        Ok(val)
+    } else {
+        Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!("invalid header field {}: {} (must be positive)", field, val),
+        ))
+    }
+}
+
+/// Checks that an in-memory IVF-HNSW-FLAT index is internally consistent.
+///
+/// The following must hold:
+/// - `d > 0` and `nlist > 0`.
+/// - There is one id list, one vector list and one graph slot per IVF list.
+/// - There are `nlist * d` centroid floats (checked via
+///   `checked_section_size`).
+/// - Every vector list holds `ids.len() * d` floats.
+/// - Each present graph matches its list's node count and vectors.
+/// - Only empty lists may omit their graph.
+///
+/// # Errors
+/// `InvalidInput` for any violation. The one exception is an oversized
+/// centroid section, which surfaces as `checked_section_size`'s
+/// `InvalidData` error.
+fn validate_index_shape(index: &IVFHNSWFlatIndex) -> io::Result<()> {
+    fn invalid(message: String) -> io::Error {
+        io::Error::new(io::ErrorKind::InvalidInput, message)
+    }
+
+    let flat = &index.flat;
+    let (d, nlist) = (flat.d, flat.nlist);
+
+    if d == 0 {
+        return Err(invalid("dimension must be greater than 0".to_owned()));
+    }
+    if nlist == 0 {
+        return Err(invalid("nlist must be greater than 0".to_owned()));
+    }
+    let per_list_lengths = [flat.ids.len(), flat.vectors.len(), index.graphs.len()];
+    if per_list_lengths.iter().any(|&len| len != nlist) {
+        return Err(invalid(
+            "IVF-HNSW-FLAT list storage does not match nlist".to_owned(),
+        ));
+    }
+    let centroid_len = checked_section_size(nlist, d)?;
+    let centroid_actual = flat.quantizer_centroids.len();
+    if centroid_actual != centroid_len {
+        return Err(invalid(format!(
+            "centroid length {} does not match nlist*d {}",
+            centroid_actual, centroid_len
+        )));
+    }
+    for list_id in 0..nlist {
+        let count = flat.ids[list_id].len();
+        let stored = &flat.vectors[list_id];
+        let expected_vector_len = count
+            .checked_mul(d)
+            .ok_or_else(|| invalid("IVF-HNSW-FLAT vector length overflow".to_owned()))?;
+        if stored.len() != expected_vector_len {
+            return Err(invalid(format!(
+                "list {} vector length {} does not match ids*d {}",
+                list_id,
+                stored.len(),
+                expected_vector_len
+            )));
+        }
+        match &index.graphs[list_id] {
+            None if count > 0 => {
+                return Err(invalid(format!("list {} is missing HNSW graph", list_id)));
+            }
+            Some(graph) if graph.len() != count || graph.vectors() != stored.as_slice() => {
+                return Err(invalid(format!(
+                    "list {} graph does not match vector storage",
+                    list_id
+                )));
+            }
+            None | Some(_) => {}
+        }
+    }
+    Ok(())
+}
+
+fn usize_to_i32(value: usize, field: &str) -> io::Result<i32> {
+    if value > i32::MAX as usize {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("{} exceeds i32 length limit: {}", field, value),
+        ));
+    }
+    Ok(value as i32)
+}
+
+fn usize_to_i64(value: usize, field: &str) -> io::Result<i64> {
+    if value > i64::MAX as usize {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("{} exceeds i64 length limit: {}", field, value),
+        ));
+    }
+    Ok(value as i64)
+}
+
+fn u64_to_i64(value: u64, field: &str) -> io::Result<i64> {
+    if value > i64::MAX as u64 {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("{} exceeds i64 offset limit: {}", field, value),
+        ));
+    }
+    Ok(value as i64)
+}
+
+/// Largest element count (2^30) accepted by `checked_section_size`.
+const MAX_SECTION_ELEMENTS: usize = 1 << 30;
+
+/// Computes the element count `a * b` of a section.
+///
+/// # Errors
+/// `InvalidData` if the product overflows `usize` or exceeds
+/// `MAX_SECTION_ELEMENTS`.
+fn checked_section_size(a: usize, b: usize) -> io::Result<usize> {
+    match a.checked_mul(b) {
+        None => Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "section size overflow in IVF-HNSW-FLAT header",
+        )),
+        Some(size) if size > MAX_SECTION_ELEMENTS => Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!(
+                "section size {} exceeds maximum {}",
+                size, MAX_SECTION_ELEMENTS
+            ),
+        )),
+        Some(size) => Ok(size),
+    }
+}
+
+/// Converts a stored list offset to an unsigned file position.
+///
+/// # Errors
+/// `InvalidData` naming `list_id` if `offset` is negative.
+fn checked_list_offset(offset: i64, list_id: usize) -> io::Result<u64> {
+    if offset >= 0 {
+        Ok(offset as u64)
+    } else {
+        Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!("negative list offset {} at list {}", offset, list_id),
+        ))
+    }
+}
+
+/// Multiplies an entry count by a per-entry byte size.
+///
+/// # Errors
+/// `InvalidData` if the product overflows `usize`.
+fn checked_list_bytes(count: usize, bytes_per_entry: usize) -> io::Result<usize> {
+    match count.checked_mul(bytes_per_entry) {
+        Some(total) => Ok(total),
+        None => Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "IVF-HNSW-FLAT list byte size overflow",
+        )),
+    }
+}
+
+/// Computes the byte length of one list's payload.
+///
+/// The payload consists of `count` 8-byte ids, then `count` vectors of `d`
+/// 4-byte floats, then `graph_bytes_len` graph bytes.
+///
+/// # Errors
+/// `InvalidData` if any intermediate product or the final sum overflows
+/// `usize`.
+fn list_payload_len(count: usize, d: usize, graph_bytes_len: usize) -> io::Result<usize> {
+    let overflow = |what: &'static str| io::Error::new(io::ErrorKind::InvalidData, what);
+    let id_bytes = count
+        .checked_mul(8)
+        .ok_or_else(|| overflow("IVF-HNSW-FLAT list byte size overflow"))?;
+    let bytes_per_vector = d
+        .checked_mul(4)
+        .ok_or_else(|| overflow("IVF-HNSW-FLAT bytes per vector overflow"))?;
+    let vector_bytes = count
+        .checked_mul(bytes_per_vector)
+        .ok_or_else(|| overflow("IVF-HNSW-FLAT list byte size overflow"))?;
+    [vector_bytes, graph_bytes_len]
+        .iter()
+        .try_fold(id_bytes, |total, &part| total.checked_add(part))
+        .ok_or_else(|| overflow("IVF-HNSW-FLAT list payload overflow"))
+}
+
+/// Deserializes a `RoaringTreemap` row-id filter.
+///
+/// # Errors
+/// `InvalidInput` carrying the deserializer's message when `bytes` is not a
+/// valid serialized treemap.
+fn decode_roaring_filter(bytes: &[u8]) -> io::Result<RoaringTreemap> {
+    match RoaringTreemap::deserialize_from(bytes) {
+        Ok(treemap) => Ok(treemap),
+        Err(e) => Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("invalid RoaringTreemap filter: {}", e),
+        )),
+    }
+}
+
+// Unit tests for the IVF-HNSW-FLAT writer, disk reader, batch search and
+// graph-section decoder defined in this file.
+#[cfg(test)]
+mod tests {
+    use crate::distance::MetricType;
+    use crate::hnsw::HnswBuildParams;
+    use crate::io::{PosWriter, SeekRead};
+    use crate::ivfhnswflat::IVFHNSWFlatIndex;
+    use crate::ivfhnswflat_io::{
+        search_batch_ivfhnswflat_reader, 
search_batch_ivfhnswflat_reader_roaring_filter,
+        write_ivfhnswflat_index, IVFHNSWFlatIndexReader, 
IVF_HNSW_FLAT_HEADER_SIZE,
+    };
+    use roaring::RoaringTreemap;
+    use std::io;
+    use std::io::Cursor;
+    use std::sync::atomic::{AtomicUsize, Ordering};
+    use std::sync::Arc;
+
+    // Serializes an L2 index of 128 vectors spread over 4 clusters and reopens
+    // it from the bytes. The reader's top-5 labels and distances must equal
+    // the in-memory index's for the same query and parameters.
+    #[test]
+    fn test_ivfhnswflat_write_read_search_roundtrip() {
+        let d = 4;
+        let nlist = 4;
+        let n = 128;
+        let data: Vec<f32> = (0..n)
+            .flat_map(|i| {
+                let cluster = (i % nlist) as f32 * 100.0;
+                [cluster + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect();
+        let ids: Vec<i64> = (1000..1000 + n as i64).collect();
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, 
HnswBuildParams::default());
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+        index.build_graphs().unwrap();
+
+        // Reference result from the in-memory index.
+        let query = &data[7 * d..8 * d];
+        let mut expected_distances = vec![0.0; 5];
+        let mut expected_labels = vec![0; 5];
+        index.search(
+            query,
+            1,
+            5,
+            nlist,
+            32,
+            &mut expected_distances,
+            &mut expected_labels,
+        );
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+
+        let mut reader = 
IVFHNSWFlatIndexReader::open(Cursor::new(buf)).unwrap();
+        let (labels, distances) = reader.search(query, 5, nlist, 32).unwrap();
+
+        assert_eq!(labels, expected_labels);
+        assert_eq!(distances, expected_distances);
+    }
+
+    // Same round-trip check with the Cosine metric and a query that is not
+    // unit length ([9, 1, 0]).
+    #[test]
+    fn test_ivfhnswflat_write_read_search_roundtrip_cosine() {
+        let d = 3;
+        let nlist = 2;
+        let data = vec![1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 1.0, 0.0, 0.0, 0.9, 
0.1];
+        let ids = vec![10, 11, 12, 13];
+
+        let mut index =
+            IVFHNSWFlatIndex::new(d, nlist, MetricType::Cosine, 
HnswBuildParams::default());
+        index.train(&data, 4);
+        index.add(&data, &ids, 4);
+        index.build_graphs().unwrap();
+
+        let query = [9.0, 1.0, 0.0];
+        let mut expected_distances = vec![0.0; 2];
+        let mut expected_labels = vec![0; 2];
+        index.search(
+            &query,
+            1,
+            2,
+            nlist,
+            8,
+            &mut expected_distances,
+            &mut expected_labels,
+        );
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+
+        let mut reader = 
IVFHNSWFlatIndexReader::open(Cursor::new(buf)).unwrap();
+        let (labels, distances) = reader.search(&query, 2, nlist, 8).unwrap();
+
+        assert_eq!(labels, expected_labels);
+        assert_eq!(distances, expected_distances);
+    }
+
+    // Single-list index of points (i, 0). The filter keeps only even ids, and
+    // the last two numeric arguments are 1 (presumably nprobe and ef_search,
+    // in the same order as `search`). The reader must still return the 10
+    // even ids nearest to (127, 0), nearest first.
+    #[test]
+    fn test_ivfhnswflat_reader_filter_backfills_exact_results() {
+        use std::collections::HashSet;
+
+        let d = 2;
+        let nlist = 1;
+        let n = 128;
+        let mut data = Vec::with_capacity(n * d);
+        for i in 0..n {
+            data.push(i as f32);
+            data.push(0.0);
+        }
+        let ids: Vec<i64> = (0..n as i64).collect();
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, 
HnswBuildParams::default());
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+
+        let filter: HashSet<i64> = (0..n as i64).filter(|id| id % 2 == 
0).collect();
+        let mut reader = 
IVFHNSWFlatIndexReader::open(Cursor::new(buf)).unwrap();
+        let (labels, _) = reader
+            .search_with_filter(&[127.0, 0.0], 10, 1, 1, Some(&filter))
+            .unwrap();
+
+        assert_eq!(
+            labels,
+            vec![126, 124, 122, 120, 118, 116, 114, 112, 110, 108]
+        );
+    }
+
+    // A two-query batch search must return, per query, the same labels and
+    // distances as separate single-query searches with identical
+    // k / nprobe / ef_search.
+    #[test]
+    fn test_ivfhnswflat_batch_reader_matches_single_reader_search() {
+        let d = 4;
+        let nlist = 4;
+        let n = 128;
+        let data: Vec<f32> = (0..n)
+            .flat_map(|i| {
+                let cluster = (i % nlist) as f32 * 100.0;
+                [cluster + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect();
+        let ids: Vec<i64> = (1000..1000 + n as i64).collect();
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, 
HnswBuildParams::default());
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+
+        let queries = [&data[7 * d..8 * d], &data[63 * d..64 * d]].concat();
+        let k = 5;
+        let nprobe = 3;
+        let ef_search = 32;
+        let mut batch_reader = 
IVFHNSWFlatIndexReader::open(Cursor::new(buf.clone())).unwrap();
+        let (batch_labels, batch_distances) =
+            search_batch_ivfhnswflat_reader(&mut batch_reader, &queries, 2, k, 
nprobe, ef_search)
+                .unwrap();
+
+        // Each query gets a fresh reader so no state is shared with the batch run.
+        for qi in 0..2 {
+            let mut single_reader = 
IVFHNSWFlatIndexReader::open(Cursor::new(buf.clone())).unwrap();
+            let query = &queries[qi * d..(qi + 1) * d];
+            let (single_labels, single_distances) =
+                single_reader.search(query, k, nprobe, ef_search).unwrap();
+            assert_eq!(&batch_labels[qi * k..(qi + 1) * k], single_labels);
+            assert_eq!(&batch_distances[qi * k..(qi + 1) * k], 
single_distances);
+        }
+    }
+
+    // Only row id 12 passes the serialized RoaringTreemap filter. Each query
+    // with k = 2 therefore gets id 12 followed by the (-1, f32::MAX) padding
+    // slot.
+    #[test]
+    fn test_ivfhnswflat_batch_reader_search_with_roaring_filter_bytes() {
+        let d = 2;
+        let nlist = 1;
+        let data = vec![0.0, 0.0, 0.1, 0.0, 10.0, 10.0];
+        let ids = vec![10, 11, 12];
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, 
HnswBuildParams::default());
+        index.train(&data, 3);
+        index.add(&data, &ids, 3);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+
+        let mut allowed = RoaringTreemap::new();
+        allowed.insert(12);
+        let mut filter_bytes = Vec::new();
+        allowed.serialize_into(&mut filter_bytes).unwrap();
+
+        let mut reader = 
IVFHNSWFlatIndexReader::open(Cursor::new(buf)).unwrap();
+        let queries = vec![0.0, 0.0, 10.0, 10.0];
+        let (labels, distances) = 
search_batch_ivfhnswflat_reader_roaring_filter(
+            &mut reader,
+            &queries,
+            2,
+            2,
+            1,
+            4,
+            &filter_bytes,
+        )
+        .unwrap();
+
+        assert_eq!(labels, vec![12, -1, 12, -1]);
+        assert_eq!(distances, vec![200.0, f32::MAX, 0.0, f32::MAX]);
+    }
+
+    // The batch entry point must reject nq = 0, a query buffer whose length
+    // is not nq * d, k = 0, and nprobe = 0.
+    #[test]
+    fn test_ivfhnswflat_batch_reader_validates_inputs() {
+        let d = 2;
+        let nlist = 1;
+        let data = vec![0.0, 0.0, 1.0, 0.0];
+        let ids = vec![10, 11];
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, 
HnswBuildParams::default());
+        index.train(&data, 2);
+        index.add(&data, &ids, 2);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+
+        let mut reader = 
IVFHNSWFlatIndexReader::open(Cursor::new(buf.clone())).unwrap();
+        assert!(search_batch_ivfhnswflat_reader(&mut reader, &[], 0, 1, 1, 
4).is_err());
+
+        let mut reader = 
IVFHNSWFlatIndexReader::open(Cursor::new(buf.clone())).unwrap();
+        assert!(search_batch_ivfhnswflat_reader(&mut reader, &[0.0], 1, 1, 1, 
4).is_err());
+
+        let mut reader = 
IVFHNSWFlatIndexReader::open(Cursor::new(buf.clone())).unwrap();
+        assert!(search_batch_ivfhnswflat_reader(&mut reader, &[0.0, 0.0], 1, 
0, 1, 4).is_err());
+
+        let mut reader = 
IVFHNSWFlatIndexReader::open(Cursor::new(buf)).unwrap();
+        assert!(search_batch_ivfhnswflat_reader(&mut reader, &[0.0, 0.0], 1, 
1, 0, 4).is_err());
+    }
+
+    // Opening the index and running a filtered search that probes a single
+    // list must issue exactly one `pread` in total, i.e. the probed list is
+    // read only once.
+    #[test]
+    fn test_ivfhnswflat_reader_filter_reads_probed_list_once() {
+        use std::collections::HashSet;
+
+        let d = 2;
+        let nlist = 1;
+        let data = vec![0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0];
+        let ids = vec![10, 11, 12, 13];
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, 
HnswBuildParams::default());
+        index.train(&data, 4);
+        index.add(&data, &ids, 4);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+
+        let pread_count = Arc::new(AtomicUsize::new(0));
+        let cursor = CountingPreadCursor::new(buf, Arc::clone(&pread_count));
+        let filter: HashSet<i64> = [10, 12].into_iter().collect();
+        let mut reader = IVFHNSWFlatIndexReader::open(cursor).unwrap();
+
+        reader
+            .search_with_filter(&[0.0, 0.0], 2, 1, 1, Some(&filter))
+            .unwrap();
+
+        assert_eq!(pread_count.load(Ordering::SeqCst), 1);
+    }
+
+    // Dropping the final byte of the serialized index must make the search
+    // fail with UnexpectedEof.
+    #[test]
+    fn test_ivfhnswflat_reader_rejects_truncated_graph_section() {
+        let d = 2;
+        let nlist = 1;
+        let data = vec![0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0];
+        let ids = vec![10, 11, 12, 13];
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, 
HnswBuildParams::default());
+        index.train(&data, 4);
+        index.add(&data, &ids, 4);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+        buf.pop();
+
+        let mut reader = 
IVFHNSWFlatIndexReader::open(Cursor::new(buf)).unwrap();
+        let err = reader.search(&[0.0, 0.0], 2, 1, 4).unwrap_err();
+
+        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
+    }
+
+    // Zeroes the 4-byte field at `graph_len_offset`. By its name this is the
+    // first list's graph-section length, and the offset arithmetic mirrors
+    // the writer's layout. The search must then fail with InvalidData
+    // "missing HNSW graph".
+    #[test]
+    fn test_ivfhnswflat_reader_rejects_missing_graph_section() {
+        let d = 2;
+        let nlist = 1;
+        let data = vec![0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0];
+        let ids = vec![10, 11, 12, 13];
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, 
HnswBuildParams::default());
+        index.train(&data, 4);
+        index.add(&data, &ids, 4);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswflat_index(&index, &mut writer).unwrap();
+        let graph_len_offset = IVF_HNSW_FLAT_HEADER_SIZE + nlist * d * 4 + 8 + 
4;
+        buf[graph_len_offset..graph_len_offset + 
4].copy_from_slice(&0i32.to_le_bytes());
+
+        let mut reader = 
IVFHNSWFlatIndexReader::open(Cursor::new(buf)).unwrap();
+        let err = reader.search(&[0.0, 0.0], 2, 1, 4).unwrap_err();
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
+        assert!(err.to_string().contains("missing HNSW graph"));
+    }
+
+    // A hand-built graph section whose single node has level max_level + 1
+    // must be rejected by `decode_graph` with InvalidData.
+    #[test]
+    fn 
test_ivfhnswflat_decoder_rejects_level_above_hnsw_max_before_allocation() {
+        let params = HnswBuildParams {
+            m: 2,
+            ef_construction: 8,
+            max_level: 3,
+        };
+        // Section: count, entry point, max observed level, node 0's level.
+        let mut graph_bytes = Vec::new();
+        append_u32(&mut graph_bytes, 1);
+        append_u32(&mut graph_bytes, 0);
+        append_u32(&mut graph_bytes, 0);
+        append_u32(&mut graph_bytes, params.max_level as u32 + 1);
+
+        let err = super::decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, 
MetricType::L2, params)
+            .unwrap_err();
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
+        assert!(err.to_string().contains("level"));
+    }
+
+    // A level-0 neighbor count of 2 * m + 1 must be rejected by
+    // `decode_graph` with InvalidData.
+    #[test]
+    fn 
test_ivfhnswflat_decoder_rejects_degree_above_hnsw_bound_before_allocation() {
+        let params = HnswBuildParams {
+            m: 2,
+            ef_construction: 8,
+            max_level: 3,
+        };
+        // Section: count, entry point, max observed level, node 0's level,
+        // then node 0's level-0 degree.
+        let mut graph_bytes = Vec::new();
+        append_u32(&mut graph_bytes, 1);
+        append_u32(&mut graph_bytes, 0);
+        append_u32(&mut graph_bytes, 0);
+        append_u32(&mut graph_bytes, 0);
+        append_u32(&mut graph_bytes, (params.m * 2) as u32 + 1);
+
+        let err = super::decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, 
MetricType::L2, params)
+            .unwrap_err();
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
+        assert!(err.to_string().contains("degree"));
+    }
+
+    // Writing an index whose graphs were never built must fail with
+    // InvalidInput "missing HNSW graph".
+    #[test]
+    fn test_ivfhnswflat_writer_requires_built_graphs() {
+        let d = 2;
+        let nlist = 1;
+        let data = vec![0.0, 0.0, 1.0, 0.0];
+        let ids = vec![10, 11];
+
+        let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, 
HnswBuildParams::default());
+        index.train(&data, 2);
+        index.add(&data, &ids, 2);
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        let err = write_ivfhnswflat_index(&index, &mut writer).unwrap_err();
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+        assert!(err.to_string().contains("missing HNSW graph"));
+    }
+
+    // In-memory `SeekRead` over a byte buffer that counts `pread` calls in a
+    // shared counter; `seek` and `read_exact` are not counted.
+    struct CountingPreadCursor {
+        data: Vec<u8>,
+        pos: usize,
+        pread_count: Arc<AtomicUsize>,
+    }
+
+    impl CountingPreadCursor {
+        fn new(data: Vec<u8>, pread_count: Arc<AtomicUsize>) -> Self {
+            Self {
+                data,
+                pos: 0,
+                pread_count,
+            }
+        }
+    }
+
+    impl SeekRead for CountingPreadCursor {
+        fn seek(&mut self, pos: u64) -> io::Result<()> {
+            self.pos = pos as usize;
+            Ok(())
+        }
+
+        // Sequential read from the current position; advances `pos`.
+        fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+            let end = self.pos.checked_add(buf.len()).ok_or_else(|| {
+                io::Error::new(io::ErrorKind::UnexpectedEof, "cursor position 
overflow")
+            })?;
+            if end > self.data.len() {
+                return Err(io::Error::new(
+                    io::ErrorKind::UnexpectedEof,
+                    "failed to fill whole buffer",
+                ));
+            }
+            buf.copy_from_slice(&self.data[self.pos..end]);
+            self.pos = end;
+            Ok(())
+        }
+
+        // Positional read: counted, and leaves the cursor position untouched.
+        fn pread(&mut self, pos: u64, buf: &mut [u8]) -> io::Result<()> {
+            self.pread_count.fetch_add(1, Ordering::SeqCst);
+            let pos = pos as usize;
+            let end = pos.checked_add(buf.len()).ok_or_else(|| {
+                io::Error::new(io::ErrorKind::UnexpectedEof, "cursor position 
overflow")
+            })?;
+            if end > self.data.len() {
+                return Err(io::Error::new(
+                    io::ErrorKind::UnexpectedEof,
+                    "failed to fill whole buffer",
+                ));
+            }
+            buf.copy_from_slice(&self.data[pos..end]);
+            Ok(())
+        }
+    }
+
+    // Appends `value` as a little-endian u32, matching the graph-section encoding.
+    fn append_u32(buf: &mut Vec<u8>, value: u32) {
+        buf.extend_from_slice(&value.to_le_bytes());
+    }
+}
diff --git a/core/src/lib.rs b/core/src/lib.rs
index e34cd95..8f2f033 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -21,11 +21,15 @@
 pub mod blas;
 pub mod distance;
 pub mod fastscan;
+pub mod hnsw;
 pub mod io;
 pub mod ivfflat;
 pub mod ivfflat_io;
+pub mod ivfhnswflat;
+pub mod ivfhnswflat_io;
 pub mod ivfpq;
 pub mod kmeans;
 pub mod opq;
 pub mod pq;
 pub mod shuffler;
+pub mod topk;
diff --git a/core/src/topk.rs b/core/src/topk.rs
new file mode 100644
index 0000000..112f76d
--- /dev/null
+++ b/core/src/topk.rs
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+
+/// Bounded collection of the `k` smallest-distance `(distance, id)` pairs, holding
+/// at most one entry per id.
+///
+/// Pushing an id that is already retained keeps the smaller of the two distances.
+/// Once `k` entries are held, a new id is admitted only if its distance is strictly
+/// smaller than the largest retained distance, whose entry it then replaces. Finding
+/// that entry is a linear scan, so a push into a full collection costs O(k).
+pub(crate) struct TopKHeap {
+    /// Maximum number of retained entries; `0` turns every push into a no-op.
+    k: usize,
+    /// Retained `(distance, id)` pairs, in no particular order until `into_sorted`.
+    data: Vec<(f32, i64)>,
+    /// Index of each retained id within `data`, used to detect duplicate ids.
+    positions: HashMap<i64, usize>,
+}
+
+impl TopKHeap {
+    /// Creates an empty collection that retains at most `k` entries.
+    pub(crate) fn new(k: usize) -> Self {
+        let data = Vec::with_capacity(k);
+        let positions = HashMap::with_capacity(k);
+        Self { k, data, positions }
+    }
+
+    /// Offers the candidate `(dist, id)`.
+    ///
+    /// Admission and updates use a strict `<`, so a candidate that only ties the
+    /// current worst distance is rejected, and a `NaN` distance never lowers a
+    /// retained distance nor enters a full collection.
+    pub(crate) fn push(&mut self, dist: f32, id: i64) {
+        if self.k == 0 {
+            return;
+        }
+
+        // Already retained: the entry may only move closer, never farther.
+        if let Some(&slot) = self.positions.get(&id) {
+            let retained = &mut self.data[slot].0;
+            if dist < *retained {
+                *retained = dist;
+            }
+            return;
+        }
+
+        // Below capacity: admit unconditionally.
+        if self.data.len() < self.k {
+            self.positions.insert(id, self.data.len());
+            self.data.push((dist, id));
+            return;
+        }
+
+        // Full: locate the entry with the largest distance. `max_by` returns the last
+        // of several equal maxima, which decides which tied entry gets evicted.
+        let worst_slot = (0..self.data.len())
+            .max_by(|&a, &b| self.data[a].0.total_cmp(&self.data[b].0));
+        if let Some(slot) = worst_slot {
+            let (worst_dist, evicted_id) = self.data[slot];
+            if dist < worst_dist {
+                self.positions.remove(&evicted_id);
+                self.positions.insert(id, slot);
+                self.data[slot] = (dist, id);
+            }
+        }
+    }
+
+    /// Consumes the collection and returns its entries in ascending distance order
+    /// (by `f32::total_cmp`; equal distances keep their current relative order).
+    pub(crate) fn into_sorted(self) -> Vec<(f32, i64)> {
+        let Self { mut data, .. } = self;
+        data.sort_by(|lhs, rhs| lhs.0.total_cmp(&rhs.0));
+        data
+    }
+
+    /// Number of entries currently retained (never more than `k`).
+    pub(crate) fn len(&self) -> usize {
+        self.data.len()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_topk_heap_updates_duplicate_and_replaces_worst() {
+        let mut heap = TopKHeap::new(2);
+
+        heap.push(10.0, 7);
+        heap.push(5.0, 8);
+        heap.push(1.0, 7);
+        heap.push(3.0, 9);
+
+        assert_eq!(heap.into_sorted(), vec![(1.0, 7), (3.0, 9)]);
+    }
+
+    #[test]
+    fn test_topk_heap_zero_capacity_keeps_nothing() {
+        let mut heap = TopKHeap::new(0);
+
+        heap.push(1.0, 1);
+
+        assert_eq!(heap.len(), 0);
+        assert!(heap.into_sorted().is_empty());
+    }
+
+    #[test]
+    fn test_topk_heap_duplicate_keeps_smaller_distance() {
+        let mut heap = TopKHeap::new(3);
+
+        heap.push(2.0, 4);
+        // A worse distance for an id that is already retained must not replace it.
+        heap.push(9.0, 4);
+
+        assert_eq!(heap.len(), 1);
+        assert_eq!(heap.into_sorted(), vec![(2.0, 4)]);
+    }
+
+    #[test]
+    fn test_topk_heap_rejects_candidate_not_better_than_worst() {
+        let mut heap = TopKHeap::new(2);
+
+        heap.push(1.0, 1);
+        heap.push(2.0, 2);
+        // Tying the worst retained distance is not strictly better, so it is rejected.
+        heap.push(2.0, 3);
+        heap.push(5.0, 4);
+
+        assert_eq!(heap.into_sorted(), vec![(1.0, 1), (2.0, 2)]);
+    }
+}


Reply via email to