This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git


The following commit(s) were added to refs/heads/main by this push:
     new ac63164  Optimize HNSW index build performance (#24)
ac63164 is described below

commit ac63164cb5c2071fbae415c7c499e2ad30ac8f5c
Author: Jingsong Lee <[email protected]>
AuthorDate: Tue Jun 9 21:58:02 2026 +0800

    Optimize HNSW index build performance (#24)
---
 core/benches/recall_bench.rs | 118 +++++---
 core/src/distance.rs         | 109 ++++++-
 core/src/hnsw.rs             | 697 +++++++++++++++++++++++++++++++++++++------
 core/src/hnsw_search.rs      |  10 +-
 core/src/ivfhnswflat.rs      |  33 +-
 core/src/ivfhnswsq.rs        |  29 +-
 6 files changed, 839 insertions(+), 157 deletions(-)

diff --git a/core/benches/recall_bench.rs b/core/benches/recall_bench.rs
index aef434f..0f2887f 100644
--- a/core/benches/recall_bench.rs
+++ b/core/benches/recall_bench.rs
@@ -21,7 +21,6 @@ fn main() {
         nlist: 64,
         pq_m: 8,
         nprobes: &[1, 4, 8, 16, 32, 64],
-        hnsw_build_ef: 80,
         hnsw_search_efs: &[80],
     });
 
@@ -36,7 +35,6 @@ fn main() {
         nlist: 8,
         pq_m: 8,
         nprobes: &[1, 2, 4, 8],
-        hnsw_build_ef: 200,
         hnsw_search_efs: &[80, 160, 320],
     });
 }
@@ -50,7 +48,6 @@ struct Scenario<'a> {
     nlist: usize,
     pq_m: usize,
     nprobes: &'a [usize],
-    hnsw_build_ef: usize,
     hnsw_search_efs: &'a [usize],
 }
 
@@ -89,39 +86,14 @@ fn run_scenario(s: Scenario<'_>) {
     println!("build IVF-FLAT: {:.2}s", start.elapsed().as_secs_f64());
 
     let start = Instant::now();
-    let mut ivfhnswflat = IVFHNSWFlatIndex::new(
-        s.d,
-        s.nlist,
-        MetricType::L2,
-        HnswBuildParams {
-            m: 16,
-            ef_construction: s.hnsw_build_ef,
-            max_level: 7,
-        },
-    );
-    ivfhnswflat.train(&data, s.n);
-    ivfhnswflat.add(&data, &ids, s.n);
-    ivfhnswflat.build_graphs().unwrap();
-    println!("build IVF-HNSW-FLAT: {:.2}s", start.elapsed().as_secs_f64());
-
-    let start = Instant::now();
-    let mut ivfhnswsq = IVFHNSWSQIndex::new(
-        s.d,
-        s.nlist,
-        MetricType::L2,
-        HnswBuildParams {
-            m: 16,
-            ef_construction: s.hnsw_build_ef,
-            max_level: 7,
-        },
-    );
-    ivfhnswsq.train(&data, s.n);
-    ivfhnswsq.add(&data, &ids, s.n);
-    ivfhnswsq.build_graphs().unwrap();
-    println!("build IVF-HNSW-SQ: {:.2}s", start.elapsed().as_secs_f64());
-    print_sizes(&ivfpq, &ivfflat, &ivfhnswflat, &ivfhnswsq);
+    let mut ivfsq = IVFHNSWSQIndex::new(s.d, s.nlist, MetricType::L2, HnswBuildParams::default());
+    ivfsq.train(&data, s.n);
+    ivfsq.add(&data, &ids, s.n);
+    println!("build IVF-SQ scan: {:.2}s", start.elapsed().as_secs_f64());
+    print_base_sizes(&ivfpq, &ivfflat, &ivfsq);
 
     println!();
+    println!("baseline exact scans over stored representations");
     println!(
         "index      nprobe  ef      recall@{}  query_ms  us/query",
         s.k
@@ -157,6 +129,50 @@ fn run_scenario(s: Scenario<'_>) {
             s.nq,
         );
 
+        let mut distances = vec![0.0f32; s.nq * s.k];
+        let mut labels = vec![0i64; s.nq * s.k];
+        let start = Instant::now();
+        ivfsq.search(queries, s.nq, s.k, nprobe, s.k, &mut distances, &mut labels);
+        let elapsed = start.elapsed();
+        print_row(
+            "IVF-SQ",
+            nprobe,
+            None,
+            recall_at_k(&labels, &ground_truth, s.nq, s.k),
+            elapsed,
+            s.nq,
+        );
+    }
+
+    let hnsw_params = HnswBuildParams::default();
+    println!();
+    println!(
+        "hnsw params: m={}, ef_construction={}",
+        hnsw_params.m, hnsw_params.ef_construction
+    );
+
+    let start = Instant::now();
+    let mut ivfhnswflat = IVFHNSWFlatIndex::new(s.d, s.nlist, MetricType::L2, hnsw_params);
+    ivfhnswflat.train(&data, s.n);
+    ivfhnswflat.add(&data, &ids, s.n);
+    ivfhnswflat.build_graphs().unwrap();
+    println!("build IVF-HNSW-FLAT: {:.2}s", start.elapsed().as_secs_f64());
+
+    let start = Instant::now();
+    let mut ivfhnswsq = IVFHNSWSQIndex::new(s.d, s.nlist, MetricType::L2, hnsw_params);
+    ivfhnswsq.train(&data, s.n);
+    ivfhnswsq.add(&data, &ids, s.n);
+    ivfhnswsq.build_graphs().unwrap();
+    println!("build IVF-HNSW-SQ: {:.2}s", start.elapsed().as_secs_f64());
+    print_hnsw_sizes(&ivfhnswflat, &ivfhnswsq);
+
+    println!(
+        "index      nprobe  ef      recall@{}  query_ms  us/query",
+        s.k
+    );
+    println!("---------  ------  ------  ---------  --------  --------");
+
+    for &nprobe in s.nprobes {
         for &ef_search in s.hnsw_search_efs {
             let mut distances = vec![0.0f32; s.nq * s.k];
             let mut labels = vec![0i64; s.nq * s.k];
@@ -205,30 +221,46 @@ fn run_scenario(s: Scenario<'_>) {
     }
 }
 
-fn print_sizes(
-    ivfpq: &IVFPQIndex,
-    ivfflat: &IVFFlatIndex,
-    ivfhnswflat: &IVFHNSWFlatIndex,
-    ivfhnswsq: &IVFHNSWSQIndex,
-) {
+fn print_base_sizes(ivfpq: &IVFPQIndex, ivfflat: &IVFFlatIndex, ivfsq: &IVFHNSWSQIndex) {
     let mut pq = Vec::new();
     write_index(ivfpq, &mut PosWriter::new(&mut pq)).unwrap();
     let mut flat = Vec::new();
     write_ivfflat_index(ivfflat, &mut PosWriter::new(&mut flat)).unwrap();
+
+    println!(
+        "baseline sizes: IVF-PQ={:.2} MiB, IVF-FLAT={:.2} MiB, IVF-SQ payload~{:.2} MiB",
+        bytes_to_mib(pq.len()),
+        bytes_to_mib(flat.len()),
+        bytes_to_mib(ivfsq_payload_bytes(ivfsq))
+    );
+}
+
+fn print_hnsw_sizes(ivfhnswflat: &IVFHNSWFlatIndex, ivfhnswsq: &IVFHNSWSQIndex) {
     let mut hnswflat = Vec::new();
     write_ivfhnswflat_index(ivfhnswflat, &mut PosWriter::new(&mut hnswflat)).unwrap();
     let mut hnswsq = Vec::new();
     write_ivfhnswsq_index(ivfhnswsq, &mut PosWriter::new(&mut hnswsq)).unwrap();
 
     println!(
-        "serialized sizes: IVF-PQ={:.2} MiB, IVF-FLAT={:.2} MiB, IVF-HNSW-FLAT={:.2} MiB, IVF-HNSW-SQ={:.2} MiB",
-        bytes_to_mib(pq.len()),
-        bytes_to_mib(flat.len()),
+        "serialized sizes: IVF-HNSW-FLAT={:.2} MiB, IVF-HNSW-SQ={:.2} MiB",
         bytes_to_mib(hnswflat.len()),
         bytes_to_mib(hnswsq.len())
     );
 }
 
+fn ivfsq_payload_bytes(index: &IVFHNSWSQIndex) -> usize {
+    let id_bytes: usize = index.ids.iter().map(|ids| ids.len() * 8).sum();
+    let code_bytes: usize = index.codes.iter().map(Vec::len).sum();
+    let centroid_bytes = index.quantizer_centroids.len() * std::mem::size_of::<f32>();
+    let global_sq_bytes = (index.sq.mins.len() + index.sq.maxs.len()) * std::mem::size_of::<f32>();
+    let list_sq_bytes: usize = index
+        .list_sqs
+        .iter()
+        .map(|sq| (sq.mins.len() + sq.maxs.len()) * std::mem::size_of::<f32>())
+        .sum();
+    id_bytes + code_bytes + centroid_bytes + global_sq_bytes + list_sq_bytes
+}
+
 fn bytes_to_mib(bytes: usize) -> f64 {
     bytes as f64 / 1024.0 / 1024.0
 }
diff --git a/core/src/distance.rs b/core/src/distance.rs
index 4cbadd6..3d2ec90 100644
--- a/core/src/distance.rs
+++ b/core/src/distance.rs
@@ -35,8 +35,44 @@ impl MetricType {
 }
 
 /// Squared L2 distance between two vectors.
+#[inline]
 pub fn fvec_l2sqr(a: &[f32], b: &[f32]) -> f32 {
-    debug_assert_eq!(a.len(), b.len());
+    assert_eq!(
+        a.len(),
+        b.len(),
+        "fvec_l2sqr inputs must have the same length"
+    );
+    fvec_l2sqr_simd(a, b)
+}
+
+#[cfg(target_arch = "x86_64")]
+#[inline]
+fn fvec_l2sqr_simd(a: &[f32], b: &[f32]) -> f32 {
+    if is_x86_feature_detected!("avx2") && a.len() >= 8 {
+        unsafe { fvec_l2sqr_avx2(a, b) }
+    } else {
+        fvec_l2sqr_scalar(a, b)
+    }
+}
+
+#[cfg(target_arch = "aarch64")]
+#[inline]
+fn fvec_l2sqr_simd(a: &[f32], b: &[f32]) -> f32 {
+    unsafe { fvec_l2sqr_neon(a, b) }
+}
+
+#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
+#[inline]
+fn fvec_l2sqr_simd(a: &[f32], b: &[f32]) -> f32 {
+    fvec_l2sqr_scalar(a, b)
+}
+
+#[cfg(any(
+    target_arch = "x86_64",
+    not(any(target_arch = "x86_64", target_arch = "aarch64"))
+))]
+#[inline]
+fn fvec_l2sqr_scalar(a: &[f32], b: &[f32]) -> f32 {
     let mut sum = 0.0f32;
     for i in 0..a.len() {
         let d = a[i] - b[i];
@@ -45,6 +81,69 @@ pub fn fvec_l2sqr(a: &[f32], b: &[f32]) -> f32 {
     sum
 }
 
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn fvec_l2sqr_avx2(a: &[f32], b: &[f32]) -> f32 {
+    use std::arch::x86_64::*;
+
+    let n = a.len();
+    let mut sum = _mm256_setzero_ps();
+    let mut i = 0;
+    while i + 8 <= n {
+        let va = unsafe { _mm256_loadu_ps(a.as_ptr().add(i)) };
+        let vb = unsafe { _mm256_loadu_ps(b.as_ptr().add(i)) };
+        let diff = _mm256_sub_ps(va, vb);
+        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+        i += 8;
+    }
+
+    let hi = _mm256_extractf128_ps::<1>(sum);
+    let lo = _mm256_castps256_ps128(sum);
+    let sum128 = _mm_add_ps(lo, hi);
+    let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
+    let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps::<1>(sum64, sum64));
+    let mut result = _mm_cvtss_f32(sum32);
+
+    while i < n {
+        let d = unsafe { *a.get_unchecked(i) - *b.get_unchecked(i) };
+        result += d * d;
+        i += 1;
+    }
+    result
+}
+
+#[cfg(target_arch = "aarch64")]
+#[target_feature(enable = "neon")]
+unsafe fn fvec_l2sqr_neon(a: &[f32], b: &[f32]) -> f32 {
+    use std::arch::aarch64::*;
+
+    let n = a.len();
+    let mut sum0 = vdupq_n_f32(0.0);
+    let mut sum1 = vdupq_n_f32(0.0);
+    let mut i = 0;
+    while i + 8 <= n {
+        let va0 = unsafe { vld1q_f32(a.as_ptr().add(i)) };
+        let vb0 = unsafe { vld1q_f32(b.as_ptr().add(i)) };
+        let diff0 = vsubq_f32(va0, vb0);
+        sum0 = vmlaq_f32(sum0, diff0, diff0);
+
+        let va1 = unsafe { vld1q_f32(a.as_ptr().add(i + 4)) };
+        let vb1 = unsafe { vld1q_f32(b.as_ptr().add(i + 4)) };
+        let diff1 = vsubq_f32(va1, vb1);
+        sum1 = vmlaq_f32(sum1, diff1, diff1);
+
+        i += 8;
+    }
+
+    let mut result = vaddvq_f32(vaddq_f32(sum0, sum1));
+    while i < n {
+        let d = unsafe { *a.get_unchecked(i) - *b.get_unchecked(i) };
+        result += d * d;
+        i += 1;
+    }
+    result
+}
+
 /// Squared L2 distance on sub-vectors.
 pub fn fvec_l2sqr_sub(a: &[f32], a_off: usize, b: &[f32], b_off: usize, len: usize) -> f32 {
     let mut sum = 0.0f32;
@@ -478,6 +577,14 @@ mod tests {
         assert!((fvec_l2sqr(&a, &b) - 27.0).abs() < 1e-6);
     }
 
+    #[test]
+    #[should_panic(expected = "fvec_l2sqr inputs must have the same length")]
+    fn test_l2sqr_rejects_mismatched_lengths() {
+        let a = [1.0, 2.0, 3.0];
+        let b = [4.0, 5.0];
+        let _ = fvec_l2sqr(&a, &b);
+    }
+
     #[test]
     fn test_inner_product() {
         let a = [1.0, 2.0, 3.0];
diff --git a/core/src/hnsw.rs b/core/src/hnsw.rs
index 4e4683a..8b939d9 100644
--- a/core/src/hnsw.rs
+++ b/core/src/hnsw.rs
@@ -16,9 +16,16 @@
 // under the License.
 
 use crate::distance::{fvec_distance, MetricType};
+use rayon::prelude::*;
 use std::cmp::Reverse;
 use std::collections::BinaryHeap;
 use std::io;
+use std::sync::RwLock;
+
+// Parallel insertion pays off once an IVF list is large enough to amortize
+// lock and per-worker visited-set setup. Smaller lists stay on the lean
+// sequential path to avoid nested Rayon overhead.
+const PARALLEL_BUILD_MIN_N: usize = 5_000;
 
 #[derive(Debug, Clone, Copy)]
 pub struct HnswBuildParams {
@@ -81,11 +88,39 @@ impl HnswGraph {
             ));
         }
 
+        Self::build_owned(vectors[..expected_len].to_vec(), n, d, metric, params)
+    }
+
+    pub(crate) fn build_owned(
+        vectors: Vec<f32>,
+        n: usize,
+        d: usize,
+        metric: MetricType,
+        params: HnswBuildParams,
+    ) -> io::Result<Self> {
+        let expected_len = n.checked_mul(d).ok_or_else(|| {
+            io::Error::new(io::ErrorKind::InvalidInput, "n * dimension overflows usize")
+        })?;
+        if vectors.len() != expected_len {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!(
+                    "vector data length {} does not match n*d {}",
+                    vectors.len(),
+                    expected_len
+                ),
+            ));
+        }
+
         let params = params.sanitized();
+        if n >= PARALLEL_BUILD_MIN_N {
+            return Ok(Self::build_parallel(vectors, n, d, metric, params));
+        }
+
         let mut graph = HnswGraph {
             d,
             metric,
-            vectors: vectors[..n * d].to_vec(),
+            vectors,
             levels: Vec::with_capacity(n),
             neighbors: Vec::with_capacity(n),
             entry_point: 0,
@@ -93,12 +128,68 @@ impl HnswGraph {
             params,
         };
 
+        let mut workspace = HnswBuildWorkspace::new(n, params.ef_construction);
         for node in 0..n {
-            graph.insert(node);
+            graph.insert(node, &mut workspace);
         }
         Ok(graph)
     }
 
+    fn build_parallel(
+        vectors: Vec<f32>,
+        n: usize,
+        d: usize,
+        metric: MetricType,
+        params: HnswBuildParams,
+    ) -> Self {
+        let levels = parallel_build_levels(n, params);
+        let max_observed_level = levels.iter().copied().max().unwrap_or(0);
+        let nodes = levels
+            .iter()
+            .map(|&level| RwLock::new(ParallelBuildNode::new(level)))
+            .collect::<Vec<_>>();
+
+        {
+            let builder = ParallelHnswBuilder {
+                d,
+                metric,
+                vectors: &vectors,
+                levels: &levels,
+                nodes: &nodes,
+                params,
+                entry_point: 0,
+                max_observed_level,
+            };
+            (1..n).into_par_iter().for_each_init(
+                || HnswBuildWorkspace::new(n, params.ef_construction),
+                |workspace, node| builder.insert(node, workspace),
+            );
+        }
+
+        let neighbors = nodes
+            .into_iter()
+            .map(|node| {
+                node.into_inner()
+                    .expect("parallel HNSW builder lock poisoned")
+                    .levels
+                    .into_iter()
+                    .map(|level| level.into_iter().map(|neighbor| neighbor.id).collect())
+                    .collect()
+            })
+            .collect();
+
+        Self {
+            d,
+            metric,
+            vectors,
+            levels,
+            neighbors,
+            entry_point: 0,
+            max_observed_level,
+            params,
+        }
+    }
+
     #[allow(clippy::too_many_arguments)]
     pub(crate) fn from_parts(
         vectors: Vec<f32>,
@@ -198,9 +289,25 @@ impl HnswGraph {
     }
 
     pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<(usize, f32)> {
+        let mut visited = Vec::new();
+        let mut visit_mark = 1usize;
+        self.search_with_workspace(query, k, ef, &mut visited, &mut visit_mark)
+    }
+
+    pub(crate) fn search_with_workspace(
+        &self,
+        query: &[f32],
+        k: usize,
+        ef: usize,
+        visited: &mut Vec<usize>,
+        visit_mark: &mut usize,
+    ) -> Vec<(usize, f32)> {
         if self.levels.is_empty() || k == 0 {
             return Vec::new();
         }
+        if visited.len() < self.levels.len() {
+            visited.resize(self.levels.len(), 0);
+        }
 
         let mut ep = self.entry_point;
         let mut ep_dist = self.distance_to_query(query, ep);
@@ -210,8 +317,9 @@ impl HnswGraph {
             ep_dist = dist;
         }
 
-        let mut visited = vec![0usize; self.levels.len()];
-        let candidates = self.search_layer_query(query, ep, ef.max(k), 0, &mut visited, 1);
+        let current_mark = *visit_mark;
+        let candidates = self.search_layer_query(query, ep, ef.max(k), 0, visited, current_mark);
+        *visit_mark = advance_visit_mark(visited, current_mark);
         candidates
             .into_iter()
             .take(k)
@@ -255,7 +363,7 @@ impl HnswGraph {
         self.max_observed_level
     }
 
-    fn insert(&mut self, node: usize) {
+    fn insert(&mut self, node: usize, workspace: &mut HnswBuildWorkspace) {
         let level = random_level(node, self.params.m, self.params.max_level);
         self.levels.push(level);
         self.neighbors.push(vec![Vec::new(); level + 1]);
@@ -275,30 +383,15 @@ impl HnswGraph {
             ep_dist = dist;
         }
 
-        let mut visited = vec![0usize; self.levels.len()];
-        let mut visit_mark = 1usize;
         for layer in (0..=level.min(self.max_observed_level)).rev() {
-            let candidates = self.search_layer_node(
-                node,
-                ep,
-                self.params.ef_construction,
-                layer,
-                &mut visited,
-                visit_mark,
-            );
-            visit_mark = advance_visit_mark(&mut visited, visit_mark);
-            let selected = self.select_neighbors(candidates, self.max_neighbors(layer));
-            for neighbor in selected {
-                self.connect(node, neighbor.id, layer);
-            }
-            if let Some(best) = self.neighbors[node][layer]
-                .iter()
-                .copied()
-                .min_by(|&a, &b| {
-                    self.distance_between(node, a)
-                        .total_cmp(&self.distance_between(node, b))
-                })
-            {
+            self.search_layer_node_with_workspace(node, ep, layer, workspace);
+            let next_ep = workspace.output.first().map(|candidate| candidate.id);
+            let selected = workspace
+                .select_output_neighbors_by(self.max_neighbors(layer), |candidate, neighbor| {
+                    self.distance_between(candidate, neighbor)
+                });
+            self.connect_selected(node, selected, layer);
+            if let Some(best) = next_ep {
                 ep = best;
             }
         }
@@ -309,32 +402,57 @@ impl HnswGraph {
         }
     }
 
-    fn connect(&mut self, a: usize, b: usize, level: usize) {
-        if !self.neighbors[a][level].contains(&b) {
-            self.neighbors[a][level].push(b);
+    fn connect_selected(&mut self, node: usize, selected: &[ScoredNode], level: usize) {
+        let node_neighbors = &mut self.neighbors[node][level];
+        node_neighbors.clear();
+        node_neighbors.extend(selected.iter().map(|neighbor| neighbor.id));
+
+        for neighbor in selected {
+            let neighbor_id = neighbor.id;
+            if level < self.neighbors[neighbor_id].len()
+                && !self.neighbors[neighbor_id][level].contains(&node)
+            {
+                self.connect_reverse(node, neighbor_id, neighbor.dist, level);
+            }
         }
-        if level < self.neighbors[b].len() && !self.neighbors[b][level].contains(&a) {
-            self.neighbors[b][level].push(a);
-            let pruned = self.pruned_neighbors(b, level, self.max_neighbors(level));
-            self.neighbors[b][level] = pruned;
+    }
+
+    fn connect_reverse(&mut self, node: usize, neighbor: usize, distance: f32, level: usize) {
+        let max_neighbors = self.max_neighbors(level);
+        {
+            let neighbors = &self.neighbors[neighbor][level];
+            if neighbors.len() >= max_neighbors
+                && !neighbors
+                    .iter()
+                    .any(|&existing| distance < self.distance_between(neighbor, existing))
+            {
+                return;
+            }
+        }
+
+        self.neighbors[neighbor][level].push(node);
+        if self.neighbors[neighbor][level].len() > max_neighbors {
+            let pruned = self.pruned_neighbors(neighbor, level, max_neighbors);
+            self.neighbors[neighbor][level] = pruned;
         }
-        let pruned = self.pruned_neighbors(a, level, self.max_neighbors(level));
-        self.neighbors[a][level] = pruned;
     }
 
     fn pruned_neighbors(&self, node: usize, level: usize, max_neighbors: usize) -> Vec<usize> {
-        let mut ranked: Vec<ScoredNode> = self.neighbors[node][level]
+        let neighbors = &self.neighbors[node][level];
+        if neighbors.len() <= max_neighbors {
+            return neighbors.clone();
+        }
+
+        let ranked: Vec<ScoredNode> = neighbors
             .iter()
             .map(|&id| ScoredNode {
                 id,
                 dist: self.distance_between(node, id),
             })
             .collect();
-        ranked.sort_by(|a, b| a.dist.total_cmp(&b.dist));
-        ranked
+        self.select_neighbors(ranked, max_neighbors)
             .into_iter()
-            .take(max_neighbors)
-            .map(|node| node.id)
+            .map(|neighbor| neighbor.id)
             .collect()
     }
 
@@ -343,30 +461,19 @@ impl HnswGraph {
         mut candidates: Vec<ScoredNode>,
         max_neighbors: usize,
     ) -> Vec<ScoredNode> {
-        candidates.sort_by(|a, b| a.dist.total_cmp(&b.dist));
-        let mut selected: Vec<ScoredNode> = Vec::with_capacity(max_neighbors);
-        let mut backfill: Vec<ScoredNode> = Vec::new();
-        for candidate in candidates {
-            if selected.len() >= max_neighbors {
-                break;
-            }
-            let closer_to_selected = selected
-                .iter()
-                .any(|neighbor| self.distance_between(candidate.id, neighbor.id) < candidate.dist);
-            if !closer_to_selected {
-                selected.push(candidate);
-            } else {
-                backfill.push(candidate);
-            }
-        }
-        for candidate in backfill {
-            if selected.len() >= max_neighbors {
-                break;
-            }
-            if !selected.iter().any(|neighbor| neighbor.id == candidate.id) {
-                selected.push(candidate);
-            }
-        }
+        candidates.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
+        self.select_neighbors_sorted(&candidates, max_neighbors)
+    }
+
+    fn select_neighbors_sorted(
+        &self,
+        candidates: &[ScoredNode],
+        max_neighbors: usize,
+    ) -> Vec<ScoredNode> {
+        let mut selected = Vec::with_capacity(max_neighbors.min(candidates.len()));
+        select_neighbors_sorted_into(candidates, max_neighbors, &mut selected, |a, b| {
+            self.distance_between(a, b)
+        });
         selected
     }
 
@@ -434,39 +541,79 @@ impl HnswGraph {
         })
     }
 
-    fn search_layer_node(
+    fn search_layer_node_with_workspace(
         &self,
         node: usize,
         entry: usize,
+        level: usize,
+        workspace: &mut HnswBuildWorkspace,
+    ) {
+        let visit_mark = workspace.visit_mark;
+        self.search_layer_into(
+            entry,
+            self.params.ef_construction,
+            level,
+            &mut workspace.visited,
+            visit_mark,
+            &mut workspace.candidates,
+            &mut workspace.results,
+            &mut workspace.output,
+            |id| self.distance_between(node, id),
+        );
+        workspace.visit_mark = advance_visit_mark(&mut workspace.visited, visit_mark);
+    }
+
+    fn search_layer(
+        &self,
+        entry: usize,
         ef: usize,
         level: usize,
         visited: &mut [usize],
         visit_mark: usize,
+        distance: impl FnMut(usize) -> f32,
     ) -> Vec<ScoredNode> {
-        self.search_layer(entry, ef, level, visited, visit_mark, |id| {
-            self.distance_between(node, id)
-        })
+        let mut candidates = BinaryHeap::with_capacity(ef);
+        let mut results = BinaryHeap::with_capacity(ef);
+        let mut output = Vec::with_capacity(ef);
+        self.search_layer_into(
+            entry,
+            ef,
+            level,
+            visited,
+            visit_mark,
+            &mut candidates,
+            &mut results,
+            &mut output,
+            distance,
+        );
+        output
     }
 
-    fn search_layer(
+    #[allow(clippy::too_many_arguments)]
+    fn search_layer_into(
         &self,
         entry: usize,
         ef: usize,
         level: usize,
         visited: &mut [usize],
         visit_mark: usize,
+        candidates: &mut BinaryHeap<Reverse<HeapNode>>,
+        results: &mut BinaryHeap<HeapNode>,
+        output: &mut Vec<ScoredNode>,
         mut distance: impl FnMut(usize) -> f32,
-    ) -> Vec<ScoredNode> {
+    ) {
+        candidates.clear();
+        results.clear();
+        output.clear();
+
         let entry_dist = distance(entry);
         visited[entry] = visit_mark;
 
-        let mut candidates = BinaryHeap::new();
         candidates.push(Reverse(HeapNode {
             id: entry,
             dist: entry_dist,
         }));
 
-        let mut results = BinaryHeap::new();
         results.push(HeapNode {
             id: entry,
             dist: entry_dist,
@@ -501,15 +648,11 @@ impl HnswGraph {
             }
         }
 
-        let mut result: Vec<ScoredNode> = results
-            .into_iter()
-            .map(|node| ScoredNode {
-                id: node.id,
-                dist: node.dist,
-            })
-            .collect();
-        result.sort_by(|a, b| a.dist.total_cmp(&b.dist));
-        result
+        output.extend(results.drain().map(|node| ScoredNode {
+            id: node.id,
+            dist: node.dist,
+        }));
+        output.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
     }
 
     fn max_neighbors(&self, level: usize) -> usize {
@@ -566,6 +709,339 @@ impl Ord for HeapNode {
     }
 }
 
+struct HnswBuildWorkspace {
+    visited: Vec<usize>,
+    visit_mark: usize,
+    candidates: BinaryHeap<Reverse<HeapNode>>,
+    results: BinaryHeap<HeapNode>,
+    output: Vec<ScoredNode>,
+    selected: Vec<ScoredNode>,
+}
+
+impl HnswBuildWorkspace {
+    fn new(n: usize, ef_construction: usize) -> Self {
+        Self {
+            visited: vec![0; n],
+            visit_mark: 1,
+            candidates: BinaryHeap::with_capacity(ef_construction),
+            results: BinaryHeap::with_capacity(ef_construction),
+            output: Vec::with_capacity(ef_construction),
+            selected: Vec::new(),
+        }
+    }
+
+    fn select_output_neighbors_by(
+        &mut self,
+        max_neighbors: usize,
+        distance_between: impl FnMut(usize, usize) -> f32,
+    ) -> &[ScoredNode] {
+        select_neighbors_sorted_into(
+            &self.output,
+            max_neighbors,
+            &mut self.selected,
+            distance_between,
+        );
+        &self.selected
+    }
+}
+
+struct ParallelHnswBuilder<'a> {
+    d: usize,
+    metric: MetricType,
+    vectors: &'a [f32],
+    levels: &'a [usize],
+    nodes: &'a [RwLock<ParallelBuildNode>],
+    params: HnswBuildParams,
+    entry_point: usize,
+    max_observed_level: usize,
+}
+
+impl ParallelHnswBuilder<'_> {
+    fn insert(&self, node: usize, workspace: &mut HnswBuildWorkspace) {
+        let level = self.nodes[node]
+            .read()
+            .expect("parallel HNSW builder lock poisoned")
+            .level();
+        let mut ep = self.entry_point;
+        let mut ep_dist = self.distance_between(node, ep);
+
+        for layer in ((level + 1)..=self.max_observed_level).rev() {
+            let (next, dist) = self.greedy_search_node(node, ep, ep_dist, layer);
+            ep = next;
+            ep_dist = dist;
+        }
+
+        for layer in (0..=level.min(self.max_observed_level)).rev() {
+            self.search_layer_node(node, ep, layer, workspace);
+            let next_ep = workspace.output.first().map(|candidate| candidate.id);
+            let selected = workspace
+                .select_output_neighbors_by(self.max_neighbors(layer), |candidate, neighbor| {
+                    self.distance_between(candidate, neighbor)
+                });
+            self.connect_selected(node, selected, layer);
+            if let Some(best) = next_ep {
+                ep = best;
+            }
+        }
+    }
+
+    fn connect_selected(&self, node: usize, selected: &[ScoredNode], level: usize) {
+        {
+            let mut current = self.nodes[node]
+                .write()
+                .expect("parallel HNSW builder lock poisoned");
+            let current_neighbors = &mut current.levels[level];
+            current_neighbors.clear();
+            current_neighbors.extend_from_slice(selected);
+        }
+
+        for neighbor in selected {
+            let neighbor_id = neighbor.id;
+            if self.levels[neighbor_id] >= level {
+                self.connect_reverse(node, neighbor_id, neighbor.dist, level);
+            }
+        }
+    }
+
+    fn connect_reverse(&self, node: usize, neighbor: usize, distance: f32, level: usize) {
+        let max_neighbors = self.max_neighbors(level);
+        {
+            let neighbor_node = self.nodes[neighbor]
+                .read()
+                .expect("parallel HNSW builder lock poisoned");
+            let neighbors = &neighbor_node.levels[level];
+            if neighbors.iter().any(|existing| existing.id == node) {
+                return;
+            }
+            if neighbors.len() >= max_neighbors
+                && !neighbors.iter().any(|existing| distance < existing.dist)
+            {
+                return;
+            }
+        }
+
+        let mut neighbor_node = self.nodes[neighbor]
+            .write()
+            .expect("parallel HNSW builder lock poisoned");
+        let neighbors = &mut neighbor_node.levels[level];
+        if neighbors.iter().any(|existing| existing.id == node) {
+            return;
+        }
+        neighbors.push(ScoredNode {
+            id: node,
+            dist: distance,
+        });
+        if neighbors.len() > max_neighbors {
+            let candidates = std::mem::take(neighbors);
+            let pruned = self.select_neighbors(candidates, max_neighbors);
+            *neighbors = pruned;
+        }
+    }
+
+    fn greedy_search_node(
+        &self,
+        node: usize,
+        mut current: usize,
+        mut current_dist: f32,
+        level: usize,
+    ) -> (usize, f32) {
+        loop {
+            let mut best = current;
+            let mut best_dist = current_dist;
+            self.for_each_neighbor(current, level, |neighbor| {
+                let dist = self.distance_between(node, neighbor);
+                if dist < best_dist {
+                    best = neighbor;
+                    best_dist = dist;
+                }
+            });
+            if best == current {
+                return (current, current_dist);
+            }
+            current = best;
+            current_dist = best_dist;
+        }
+    }
+
+    /// Searches `level` for nodes close to `node`, starting from `entry`.
+    ///
+    /// `workspace.candidates`, `workspace.results` and `workspace.output` are
+    /// cleared on entry. On return `workspace.output` holds the collected
+    /// results sorted by ascending distance to `node`, and
+    /// `workspace.visit_mark` is replaced by the value returned from
+    /// `advance_visit_mark`.
+    fn search_layer_node(
+        &self,
+        node: usize,
+        entry: usize,
+        level: usize,
+        workspace: &mut HnswBuildWorkspace,
+    ) {
+        workspace.candidates.clear();
+        workspace.results.clear();
+        workspace.output.clear();
+
+        // A node counts as visited in this search when its `visited` slot
+        // equals this mark.
+        let visit_mark = workspace.visit_mark;
+        let entry_dist = self.distance_between(node, entry);
+        workspace.visited[entry] = visit_mark;
+        // Seed both heaps with the entry point.
+        workspace.candidates.push(Reverse(HeapNode {
+            id: entry,
+            dist: entry_dist,
+        }));
+        workspace.results.push(HeapNode {
+            id: entry,
+            dist: entry_dist,
+        });
+
+        while let Some(Reverse(current)) = workspace.candidates.pop() {
+            // Distance at the top of the results heap, or infinity when it is
+            // empty. Presumably this is the farthest kept result; `HeapNode`'s
+            // ordering is defined outside this block — TODO confirm.
+            let worst = workspace
+                .results
+                .peek()
+                .map(|node| node.dist)
+                .unwrap_or(f32::INFINITY);
+            // Stop once the popped candidate is farther than that distance and
+            // the result set already holds `ef_construction` entries.
+            if current.dist > worst && workspace.results.len() >= 
self.params.ef_construction {
+                break;
+            }
+
+            self.for_each_neighbor(current.id, level, |neighbor| {
+                // Skip neighbors already seen during this search.
+                if workspace.visited[neighbor] == visit_mark {
+                    return;
+                }
+                workspace.visited[neighbor] = visit_mark;
+                let dist = self.distance_between(node, neighbor);
+                let worst = workspace
+                    .results
+                    .peek()
+                    .map(|node| node.dist)
+                    .unwrap_or(f32::INFINITY);
+                // Keep the neighbor while there is room, or when it beats the
+                // distance currently at the top of the results heap.
+                if workspace.results.len() < self.params.ef_construction || 
dist < worst {
+                    workspace
+                        .candidates
+                        .push(Reverse(HeapNode { id: neighbor, dist }));
+                    workspace.results.push(HeapNode { id: neighbor, dist });
+                    // Trim back to `ef_construction` entries after an insert
+                    // pushed the result set over the limit.
+                    if workspace.results.len() > self.params.ef_construction {
+                        workspace.results.pop();
+                    }
+                }
+            });
+        }
+
+        // Move the results into `output` and order them nearest-first.
+        workspace
+            .output
+            .extend(workspace.results.drain().map(|node| ScoredNode {
+                id: node.id,
+                dist: node.dist,
+            }));
+        workspace
+            .output
+            .sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
+        workspace.visit_mark = advance_visit_mark(&mut workspace.visited, 
visit_mark);
+    }
+
+    /// Orders `candidates` by ascending distance and forwards them to
+    /// `select_neighbors_sorted`, which picks at most `max_neighbors` of them.
+    fn select_neighbors(
+        &self,
+        mut candidates: Vec<ScoredNode>,
+        max_neighbors: usize,
+    ) -> Vec<ScoredNode> {
+        candidates.sort_unstable_by(|lhs, rhs| lhs.dist.total_cmp(&rhs.dist));
+        let ordered = candidates.as_slice();
+        self.select_neighbors_sorted(ordered, max_neighbors)
+    }
+
+    /// Picks at most `max_neighbors` entries from the already distance-sorted
+    /// `candidates` by delegating to `select_neighbors_sorted_into`, using this
+    /// builder's `distance_between` for the pairwise distances.
+    fn select_neighbors_sorted(
+        &self,
+        candidates: &[ScoredNode],
+        max_neighbors: usize,
+    ) -> Vec<ScoredNode> {
+        let capacity = max_neighbors.min(candidates.len());
+        let mut picked = Vec::with_capacity(capacity);
+        let pairwise = |lhs: usize, rhs: usize| self.distance_between(lhs, rhs);
+        select_neighbors_sorted_into(candidates, max_neighbors, &mut picked, pairwise);
+        picked
+    }
+
+    /// Calls `f` with the id of every neighbor that `node` has on `level`.
+    ///
+    /// The node's read guard is held for the whole iteration. Nothing is
+    /// visited when the node has no list for `level`.
+    fn for_each_neighbor(&self, node: usize, level: usize, mut f: impl FnMut(usize)) {
+        let guard = self.nodes[node]
+            .read()
+            .expect("parallel HNSW builder lock poisoned");
+        guard
+            .levels
+            .get(level)
+            .into_iter()
+            .flatten()
+            .for_each(|link| f(link.id));
+    }
+
+    /// Neighbor cap for `level`: `2 * m` on layer 0, `m` on every other layer.
+    fn max_neighbors(&self, level: usize) -> usize {
+        match level {
+            0 => self.params.m * 2,
+            _ => self.params.m,
+        }
+    }
+
+    /// Distance between stored vectors `a` and `b` under `self.metric`.
+    ///
+    /// Vector `i` occupies `self.vectors[i * d..(i + 1) * d]`.
+    fn distance_between(&self, a: usize, b: usize) -> f32 {
+        let dim = self.d;
+        let row = |index: usize| &self.vectors[index * dim..(index + 1) * dim];
+        fvec_distance(row(a), row(b), self.metric)
+    }
+}
+
+/// Per-node neighbor lists, one list per layer the node belongs to.
+struct ParallelBuildNode {
+    /// `levels[l]` holds the node's scored neighbors on layer `l`.
+    levels: Vec<Vec<ScoredNode>>,
+}
+
+impl ParallelBuildNode {
+    /// Creates a node with `level + 1` empty neighbor lists (layers `0..=level`).
+    fn new(level: usize) -> Self {
+        let layer_count = level + 1;
+        let levels = std::iter::repeat_with(Vec::new)
+            .take(layer_count)
+            .collect();
+        Self { levels }
+    }
+
+    /// Index of the highest layer this node has a neighbor list for.
+    fn level(&self) -> usize {
+        let layer_count = self.levels.len();
+        layer_count - 1
+    }
+}
+
+/// Assigns a layer to each of the `n` nodes for the parallel build.
+///
+/// Every node first receives the level produced by `random_level`. The first
+/// node (index 0), when present, is then overridden to `params.max_level - 1`,
+/// saturating at 0 so that a `max_level` of 0 cannot underflow.
+fn parallel_build_levels(n: usize, params: HnswBuildParams) -> Vec<usize> {
+    let mut levels: Vec<_> = (0..n)
+        .map(|node| random_level(node, params.m, params.max_level))
+        .collect();
+    if let Some(first) = levels.first_mut() {
+        // LanceDB keeps the fixed entry point reachable from every configured
+        // layer during parallel build. Mirroring that avoids a serialized
+        // "promote newest max-level node" phase while preserving high-level
+        // search quality.
+        //
+        // `saturating_sub` guards the degenerate `max_level == 0` case, which
+        // would otherwise panic in debug builds or wrap to `usize::MAX` in
+        // release builds.
+        *first = params.max_level.saturating_sub(1);
+    }
+    levels
+}
+
+/// Fills `selected` with at most `max_neighbors` entries of `candidates`.
+///
+/// `candidates` is scanned in the order given (callers in this file pass it
+/// sorted by ascending distance). `selected` is cleared first. When the
+/// candidates already fit within `max_neighbors`, they are all copied as is.
+/// Otherwise two passes run:
+///
+/// 1. A candidate is kept unless some already-kept entry is closer to it
+///    (per `distance_between`) than the candidate's own `dist`.
+/// 2. Remaining slots are backfilled with not-yet-kept candidates, in order.
+fn select_neighbors_sorted_into(
+    candidates: &[ScoredNode],
+    max_neighbors: usize,
+    selected: &mut Vec<ScoredNode>,
+    mut distance_between: impl FnMut(usize, usize) -> f32,
+) {
+    selected.clear();
+    if candidates.len() <= max_neighbors {
+        selected.extend_from_slice(candidates);
+        return;
+    }
+
+    // `selected` is empty here, so the full budget is reserved up front.
+    selected.reserve(max_neighbors);
+
+    // Pass 1: diversity filter.
+    for candidate in candidates.iter().copied() {
+        if selected.len() >= max_neighbors {
+            break;
+        }
+        let mut dominated = false;
+        for kept in selected.iter() {
+            if distance_between(candidate.id, kept.id) < candidate.dist {
+                dominated = true;
+                break;
+            }
+        }
+        if !dominated {
+            selected.push(candidate);
+        }
+    }
+
+    // Pass 2: backfill with candidates the first pass rejected.
+    for candidate in candidates.iter().copied() {
+        if selected.len() >= max_neighbors {
+            break;
+        }
+        let already_kept = selected.iter().any(|kept| kept.id == candidate.id);
+        if !already_kept {
+            selected.push(candidate);
+        }
+    }
+}
+
 fn random_level(node: usize, m: usize, max_level: usize) -> usize {
     if node == 0 || max_level <= 1 {
         // Keep the first insertion deterministic. Later higher-level nodes 
replace
@@ -695,6 +1171,36 @@ mod tests {
         assert!(recall >= 0.95, "recall={}", recall);
     }
 
+    #[test]
+    fn test_hnsw_parallel_build_large_partition_recall_tracks_exact_search() {
+        // The partition is sized 512 rows above `PARALLEL_BUILD_MIN_N`,
+        // presumably so that the build takes the parallel path.
+        let (d, nq, k) = (16, 32, 10);
+        let n = PARALLEL_BUILD_MIN_N + 512;
+        let data = generate_clustered_data(n, d, 32);
+        let params = HnswBuildParams {
+            m: 16,
+            ef_construction: 200,
+            max_level: 7,
+        };
+        let graph = HnswGraph::build(&data, n, d, MetricType::L2, params).unwrap();
+
+        // Count how many graph results appear in the exact top-k, using the
+        // first `nq` data rows as queries.
+        let hits: usize = (0..nq)
+            .map(|qi| {
+                let query = &data[qi * d..(qi + 1) * d];
+                let expected = exact_topk(&data, n, d, query, k);
+                graph
+                    .search(query, k, 200)
+                    .iter()
+                    .filter(|(id, _)| expected.contains(id))
+                    .count()
+            })
+            .sum();
+
+        let recall = hits as f32 / (nq * k) as f32;
+        assert!(recall >= 0.95, "recall={}", recall);
+        assert!(graph.max_degree() <= params.m * 2);
+    }
+
     #[test]
     fn test_hnsw_neighbor_selection_backfills_after_diversification() {
         let d = 1;
@@ -722,6 +1228,31 @@ mod tests {
         assert_eq!(selected.len(), 3);
     }
 
+    #[test]
+    fn test_hnsw_pruning_keeps_diverse_neighbors() {
+        // Eight coordinates; with the `4, 2` arguments below they presumably
+        // form four 2-D points: (0,0), (1,0), (1.1,0) and (0,2).
+        let vectors = vec![0.0, 0.0, 1.0, 0.0, 1.1, 0.0, 0.0, 2.0];
+        // Presumably per-node, per-level links: only node 0 links out, to
+        // nodes 1, 2 and 3.
+        let adjacency = vec![
+            vec![vec![1, 2, 3]],
+            vec![vec![]],
+            vec![vec![]],
+            vec![vec![]],
+        ];
+        let built = HnswGraph::from_parts(
+            vectors,
+            4,
+            2,
+            MetricType::L2,
+            vec![0, 0, 0, 0],
+            adjacency,
+            0,
+            0,
+            HnswBuildParams::default(),
+        );
+        let graph = built.unwrap();
+
+        let selected = graph.pruned_neighbors(0, 0, 2);
+
+        // Node 2 is expected to be dropped in favor of node 3.
+        assert_eq!(selected, vec![1, 3]);
+    }
+
     #[test]
     fn test_hnsw_greedy_search_chooses_best_improving_neighbor() {
         let graph = HnswGraph::from_parts(
diff --git a/core/src/hnsw_search.rs b/core/src/hnsw_search.rs
index 74eab88..a975c1c 100644
--- a/core/src/hnsw_search.rs
+++ b/core/src/hnsw_search.rs
@@ -37,6 +37,8 @@ where
     F: FnMut(&HnswSearchList<'a, P>, &mut TopKHeap),
 {
     let mut heap = TopKHeap::new(k);
+    let mut visited = Vec::new();
+    let mut visit_mark = 1usize;
     let force_scan = filter
         .map(|f| count_filtered(lists, f) <= ef_search.max(k))
         .unwrap_or(false);
@@ -47,7 +49,13 @@ where
             continue;
         }
         if let Some(graph) = list.graph {
-            let local_results = graph.search(query, ef_search.max(k), 
ef_search.max(k));
+            let local_results = graph.search_with_workspace(
+                query,
+                ef_search.max(k),
+                ef_search.max(k),
+                &mut visited,
+                &mut visit_mark,
+            );
             for (local_id, dist) in local_results {
                 let row_id = list.ids[local_id];
                 if filter.map(|f| f.contains(row_id)).unwrap_or(true) {
diff --git a/core/src/ivfhnswflat.rs b/core/src/ivfhnswflat.rs
index 0d2db81..5ab508a 100644
--- a/core/src/ivfhnswflat.rs
+++ b/core/src/ivfhnswflat.rs
@@ -22,6 +22,7 @@ use crate::ivfflat::IVFFlatIndex;
 use crate::ivfpq::RowIdFilter;
 use crate::kmeans;
 use crate::topk::TopKHeap;
+use rayon::prelude::*;
 use std::io;
 
 pub struct IVFHNSWFlatIndex {
@@ -51,20 +52,24 @@ impl IVFHNSWFlatIndex {
     }
 
+    /// Rebuilds the HNSW graph for every list in `0..self.flat.nlist`,
+    /// processing the lists in parallel through rayon's `into_par_iter`.
+    ///
+    /// Lists with no ids map to `None`. If any `HnswGraph::build` call fails,
+    /// the error is returned through `?` before `self.graphs` is assigned.
     pub fn build_graphs(&mut self) -> io::Result<()> {
-        for list_id in 0..self.flat.nlist {
-            let count = self.flat.ids[list_id].len();
-            self.graphs[list_id] = if count == 0 {
-                None
-            } else {
-                Some(HnswGraph::build(
-                    &self.flat.vectors[list_id],
-                    count,
-                    self.flat.d,
-                    self.flat.metric,
-                    self.hnsw_params,
-                )?)
-            };
-        }
+        self.graphs = (0..self.flat.nlist)
+            .into_par_iter()
+            .map(|list_id| {
+                let count = self.flat.ids[list_id].len();
+                if count == 0 {
+                    Ok(None)
+                } else {
+                    HnswGraph::build(
+                        &self.flat.vectors[list_id],
+                        count,
+                        self.flat.d,
+                        self.flat.metric,
+                        self.hnsw_params,
+                    )
+                    .map(Some)
+                }
+            })
+            // Collecting into `io::Result<Vec<_>>` yields either every
+            // per-list graph or an error.
+            .collect::<io::Result<Vec<_>>>()?;
         Ok(())
     }
 
diff --git a/core/src/ivfhnswsq.rs b/core/src/ivfhnswsq.rs
index 4ce2a6e..a2cc8f2 100644
--- a/core/src/ivfhnswsq.rs
+++ b/core/src/ivfhnswsq.rs
@@ -22,6 +22,7 @@ use crate::ivfpq::RowIdFilter;
 use crate::kmeans::{self, KMeansConfig};
 use crate::sq::ScalarQuantizer;
 use crate::topk::TopKHeap;
+use rayon::prelude::*;
 use std::io;
 
 pub struct IVFHNSWSQIndex {
@@ -82,21 +83,19 @@ impl IVFHNSWSQIndex {
     }
 
+    /// Rebuilds the HNSW graph for every list in `0..self.nlist`, processing
+    /// the lists in parallel through rayon's `into_par_iter`.
+    ///
+    /// Lists with no ids map to `None`. For the others the vectors returned by
+    /// `decode_list_vectors` are handed to `HnswGraph::build_owned` by value.
+    /// If any build fails, the error is returned through `?` before
+    /// `self.graphs` is assigned.
     pub fn build_graphs(&mut self) -> io::Result<()> {
-        for list_id in 0..self.nlist {
-            let count = self.ids[list_id].len();
-            self.graphs[list_id] = if count == 0 {
-                None
-            } else {
-                let vectors = self.decode_list_vectors(list_id, count);
-                Some(HnswGraph::build(
-                    &vectors,
-                    count,
-                    self.d,
-                    self.metric,
-                    self.hnsw_params,
-                )?)
-            };
-        }
+        self.graphs = (0..self.nlist)
+            .into_par_iter()
+            .map(|list_id| {
+                let count = self.ids[list_id].len();
+                if count == 0 {
+                    Ok(None)
+                } else {
+                    let vectors = self.decode_list_vectors(list_id, count);
+                    HnswGraph::build_owned(vectors, count, self.d, 
self.metric, self.hnsw_params)
+                        .map(Some)
+                }
+            })
+            // Collecting into `io::Result<Vec<_>>` yields either every
+            // per-list graph or an error.
+            .collect::<io::Result<Vec<_>>>()?;
         Ok(())
     }
 


Reply via email to