This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new ac63164 Optimize HNSW index build performance (#24)
ac63164 is described below
commit ac63164cb5c2071fbae415c7c499e2ad30ac8f5c
Author: Jingsong Lee <[email protected]>
AuthorDate: Tue Jun 9 21:58:02 2026 +0800
Optimize HNSW index build performance (#24)
---
core/benches/recall_bench.rs | 118 +++++---
core/src/distance.rs | 109 ++++++-
core/src/hnsw.rs | 697 +++++++++++++++++++++++++++++++++++++------
core/src/hnsw_search.rs | 10 +-
core/src/ivfhnswflat.rs | 33 +-
core/src/ivfhnswsq.rs | 29 +-
6 files changed, 839 insertions(+), 157 deletions(-)
diff --git a/core/benches/recall_bench.rs b/core/benches/recall_bench.rs
index aef434f..0f2887f 100644
--- a/core/benches/recall_bench.rs
+++ b/core/benches/recall_bench.rs
@@ -21,7 +21,6 @@ fn main() {
nlist: 64,
pq_m: 8,
nprobes: &[1, 4, 8, 16, 32, 64],
- hnsw_build_ef: 80,
hnsw_search_efs: &[80],
});
@@ -36,7 +35,6 @@ fn main() {
nlist: 8,
pq_m: 8,
nprobes: &[1, 2, 4, 8],
- hnsw_build_ef: 200,
hnsw_search_efs: &[80, 160, 320],
});
}
@@ -50,7 +48,6 @@ struct Scenario<'a> {
nlist: usize,
pq_m: usize,
nprobes: &'a [usize],
- hnsw_build_ef: usize,
hnsw_search_efs: &'a [usize],
}
@@ -89,39 +86,14 @@ fn run_scenario(s: Scenario<'_>) {
println!("build IVF-FLAT: {:.2}s", start.elapsed().as_secs_f64());
let start = Instant::now();
- let mut ivfhnswflat = IVFHNSWFlatIndex::new(
- s.d,
- s.nlist,
- MetricType::L2,
- HnswBuildParams {
- m: 16,
- ef_construction: s.hnsw_build_ef,
- max_level: 7,
- },
- );
- ivfhnswflat.train(&data, s.n);
- ivfhnswflat.add(&data, &ids, s.n);
- ivfhnswflat.build_graphs().unwrap();
- println!("build IVF-HNSW-FLAT: {:.2}s", start.elapsed().as_secs_f64());
-
- let start = Instant::now();
- let mut ivfhnswsq = IVFHNSWSQIndex::new(
- s.d,
- s.nlist,
- MetricType::L2,
- HnswBuildParams {
- m: 16,
- ef_construction: s.hnsw_build_ef,
- max_level: 7,
- },
- );
- ivfhnswsq.train(&data, s.n);
- ivfhnswsq.add(&data, &ids, s.n);
- ivfhnswsq.build_graphs().unwrap();
- println!("build IVF-HNSW-SQ: {:.2}s", start.elapsed().as_secs_f64());
- print_sizes(&ivfpq, &ivfflat, &ivfhnswflat, &ivfhnswsq);
+ let mut ivfsq = IVFHNSWSQIndex::new(s.d, s.nlist, MetricType::L2,
HnswBuildParams::default());
+ ivfsq.train(&data, s.n);
+ ivfsq.add(&data, &ids, s.n);
+ println!("build IVF-SQ scan: {:.2}s", start.elapsed().as_secs_f64());
+ print_base_sizes(&ivfpq, &ivfflat, &ivfsq);
println!();
+ println!("baseline exact scans over stored representations");
println!(
"index nprobe ef recall@{} query_ms us/query",
s.k
@@ -157,6 +129,50 @@ fn run_scenario(s: Scenario<'_>) {
s.nq,
);
+ let mut distances = vec![0.0f32; s.nq * s.k];
+ let mut labels = vec![0i64; s.nq * s.k];
+ let start = Instant::now();
+ ivfsq.search(queries, s.nq, s.k, nprobe, s.k, &mut distances, &mut
labels);
+ let elapsed = start.elapsed();
+ print_row(
+ "IVF-SQ",
+ nprobe,
+ None,
+ recall_at_k(&labels, &ground_truth, s.nq, s.k),
+ elapsed,
+ s.nq,
+ );
+ }
+
+ let hnsw_params = HnswBuildParams::default();
+ println!();
+ println!(
+ "hnsw params: m={}, ef_construction={}",
+ hnsw_params.m, hnsw_params.ef_construction
+ );
+
+ let start = Instant::now();
+ let mut ivfhnswflat = IVFHNSWFlatIndex::new(s.d, s.nlist, MetricType::L2,
hnsw_params);
+ ivfhnswflat.train(&data, s.n);
+ ivfhnswflat.add(&data, &ids, s.n);
+ ivfhnswflat.build_graphs().unwrap();
+ println!("build IVF-HNSW-FLAT: {:.2}s", start.elapsed().as_secs_f64());
+
+ let start = Instant::now();
+ let mut ivfhnswsq = IVFHNSWSQIndex::new(s.d, s.nlist, MetricType::L2,
hnsw_params);
+ ivfhnswsq.train(&data, s.n);
+ ivfhnswsq.add(&data, &ids, s.n);
+ ivfhnswsq.build_graphs().unwrap();
+ println!("build IVF-HNSW-SQ: {:.2}s", start.elapsed().as_secs_f64());
+ print_hnsw_sizes(&ivfhnswflat, &ivfhnswsq);
+
+ println!(
+ "index nprobe ef recall@{} query_ms us/query",
+ s.k
+ );
+ println!("--------- ------ ------ --------- -------- --------");
+
+ for &nprobe in s.nprobes {
for &ef_search in s.hnsw_search_efs {
let mut distances = vec![0.0f32; s.nq * s.k];
let mut labels = vec![0i64; s.nq * s.k];
@@ -205,30 +221,46 @@ fn run_scenario(s: Scenario<'_>) {
}
}
-fn print_sizes(
- ivfpq: &IVFPQIndex,
- ivfflat: &IVFFlatIndex,
- ivfhnswflat: &IVFHNSWFlatIndex,
- ivfhnswsq: &IVFHNSWSQIndex,
-) {
+fn print_base_sizes(ivfpq: &IVFPQIndex, ivfflat: &IVFFlatIndex, ivfsq:
&IVFHNSWSQIndex) {
let mut pq = Vec::new();
write_index(ivfpq, &mut PosWriter::new(&mut pq)).unwrap();
let mut flat = Vec::new();
write_ivfflat_index(ivfflat, &mut PosWriter::new(&mut flat)).unwrap();
+
+ println!(
+ "baseline sizes: IVF-PQ={:.2} MiB, IVF-FLAT={:.2} MiB, IVF-SQ
payload~{:.2} MiB",
+ bytes_to_mib(pq.len()),
+ bytes_to_mib(flat.len()),
+ bytes_to_mib(ivfsq_payload_bytes(ivfsq))
+ );
+}
+
+fn print_hnsw_sizes(ivfhnswflat: &IVFHNSWFlatIndex, ivfhnswsq:
&IVFHNSWSQIndex) {
let mut hnswflat = Vec::new();
write_ivfhnswflat_index(ivfhnswflat, &mut PosWriter::new(&mut
hnswflat)).unwrap();
let mut hnswsq = Vec::new();
write_ivfhnswsq_index(ivfhnswsq, &mut PosWriter::new(&mut
hnswsq)).unwrap();
println!(
- "serialized sizes: IVF-PQ={:.2} MiB, IVF-FLAT={:.2} MiB,
IVF-HNSW-FLAT={:.2} MiB, IVF-HNSW-SQ={:.2} MiB",
- bytes_to_mib(pq.len()),
- bytes_to_mib(flat.len()),
+ "serialized sizes: IVF-HNSW-FLAT={:.2} MiB, IVF-HNSW-SQ={:.2} MiB",
bytes_to_mib(hnswflat.len()),
bytes_to_mib(hnswsq.len())
);
}
+fn ivfsq_payload_bytes(index: &IVFHNSWSQIndex) -> usize {
+ let id_bytes: usize = index.ids.iter().map(|ids| ids.len() * 8).sum();
+ let code_bytes: usize = index.codes.iter().map(Vec::len).sum();
+ let centroid_bytes = index.quantizer_centroids.len() *
std::mem::size_of::<f32>();
+ let global_sq_bytes = (index.sq.mins.len() + index.sq.maxs.len()) *
std::mem::size_of::<f32>();
+ let list_sq_bytes: usize = index
+ .list_sqs
+ .iter()
+ .map(|sq| (sq.mins.len() + sq.maxs.len()) * std::mem::size_of::<f32>())
+ .sum();
+ id_bytes + code_bytes + centroid_bytes + global_sq_bytes + list_sq_bytes
+}
+
fn bytes_to_mib(bytes: usize) -> f64 {
bytes as f64 / 1024.0 / 1024.0
}
diff --git a/core/src/distance.rs b/core/src/distance.rs
index 4cbadd6..3d2ec90 100644
--- a/core/src/distance.rs
+++ b/core/src/distance.rs
@@ -35,8 +35,44 @@ impl MetricType {
}
/// Squared L2 distance between two vectors.
+#[inline]
pub fn fvec_l2sqr(a: &[f32], b: &[f32]) -> f32 {
- debug_assert_eq!(a.len(), b.len());
+ assert_eq!(
+ a.len(),
+ b.len(),
+ "fvec_l2sqr inputs must have the same length"
+ );
+ fvec_l2sqr_simd(a, b)
+}
+
+#[cfg(target_arch = "x86_64")]
+#[inline]
+fn fvec_l2sqr_simd(a: &[f32], b: &[f32]) -> f32 {
+ if is_x86_feature_detected!("avx2") && a.len() >= 8 {
+ unsafe { fvec_l2sqr_avx2(a, b) }
+ } else {
+ fvec_l2sqr_scalar(a, b)
+ }
+}
+
+#[cfg(target_arch = "aarch64")]
+#[inline]
+fn fvec_l2sqr_simd(a: &[f32], b: &[f32]) -> f32 {
+ unsafe { fvec_l2sqr_neon(a, b) }
+}
+
+#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
+#[inline]
+fn fvec_l2sqr_simd(a: &[f32], b: &[f32]) -> f32 {
+ fvec_l2sqr_scalar(a, b)
+}
+
+#[cfg(any(
+ target_arch = "x86_64",
+ not(any(target_arch = "x86_64", target_arch = "aarch64"))
+))]
+#[inline]
+fn fvec_l2sqr_scalar(a: &[f32], b: &[f32]) -> f32 {
let mut sum = 0.0f32;
for i in 0..a.len() {
let d = a[i] - b[i];
@@ -45,6 +81,69 @@ pub fn fvec_l2sqr(a: &[f32], b: &[f32]) -> f32 {
sum
}
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn fvec_l2sqr_avx2(a: &[f32], b: &[f32]) -> f32 {
+ use std::arch::x86_64::*;
+
+ let n = a.len();
+ let mut sum = _mm256_setzero_ps();
+ let mut i = 0;
+ while i + 8 <= n {
+ let va = unsafe { _mm256_loadu_ps(a.as_ptr().add(i)) };
+ let vb = unsafe { _mm256_loadu_ps(b.as_ptr().add(i)) };
+ let diff = _mm256_sub_ps(va, vb);
+ sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
+ i += 8;
+ }
+
+ let hi = _mm256_extractf128_ps::<1>(sum);
+ let lo = _mm256_castps256_ps128(sum);
+ let sum128 = _mm_add_ps(lo, hi);
+ let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
+ let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps::<1>(sum64, sum64));
+ let mut result = _mm_cvtss_f32(sum32);
+
+ while i < n {
+ let d = unsafe { *a.get_unchecked(i) - *b.get_unchecked(i) };
+ result += d * d;
+ i += 1;
+ }
+ result
+}
+
+#[cfg(target_arch = "aarch64")]
+#[target_feature(enable = "neon")]
+unsafe fn fvec_l2sqr_neon(a: &[f32], b: &[f32]) -> f32 {
+ use std::arch::aarch64::*;
+
+ let n = a.len();
+ let mut sum0 = vdupq_n_f32(0.0);
+ let mut sum1 = vdupq_n_f32(0.0);
+ let mut i = 0;
+ while i + 8 <= n {
+ let va0 = unsafe { vld1q_f32(a.as_ptr().add(i)) };
+ let vb0 = unsafe { vld1q_f32(b.as_ptr().add(i)) };
+ let diff0 = vsubq_f32(va0, vb0);
+ sum0 = vmlaq_f32(sum0, diff0, diff0);
+
+ let va1 = unsafe { vld1q_f32(a.as_ptr().add(i + 4)) };
+ let vb1 = unsafe { vld1q_f32(b.as_ptr().add(i + 4)) };
+ let diff1 = vsubq_f32(va1, vb1);
+ sum1 = vmlaq_f32(sum1, diff1, diff1);
+
+ i += 8;
+ }
+
+ let mut result = vaddvq_f32(vaddq_f32(sum0, sum1));
+ while i < n {
+ let d = unsafe { *a.get_unchecked(i) - *b.get_unchecked(i) };
+ result += d * d;
+ i += 1;
+ }
+ result
+}
+
/// Squared L2 distance on sub-vectors.
pub fn fvec_l2sqr_sub(a: &[f32], a_off: usize, b: &[f32], b_off: usize, len:
usize) -> f32 {
let mut sum = 0.0f32;
@@ -478,6 +577,14 @@ mod tests {
assert!((fvec_l2sqr(&a, &b) - 27.0).abs() < 1e-6);
}
+ #[test]
+ #[should_panic(expected = "fvec_l2sqr inputs must have the same length")]
+ fn test_l2sqr_rejects_mismatched_lengths() {
+ let a = [1.0, 2.0, 3.0];
+ let b = [4.0, 5.0];
+ let _ = fvec_l2sqr(&a, &b);
+ }
+
#[test]
fn test_inner_product() {
let a = [1.0, 2.0, 3.0];
diff --git a/core/src/hnsw.rs b/core/src/hnsw.rs
index 4e4683a..8b939d9 100644
--- a/core/src/hnsw.rs
+++ b/core/src/hnsw.rs
@@ -16,9 +16,16 @@
// under the License.
use crate::distance::{fvec_distance, MetricType};
+use rayon::prelude::*;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::io;
+use std::sync::RwLock;
+
+// Parallel insertion pays off once an IVF list is large enough to amortize
+// lock and per-worker visited-set setup. Smaller lists stay on the lean
+// sequential path to avoid nested Rayon overhead.
+const PARALLEL_BUILD_MIN_N: usize = 5_000;
#[derive(Debug, Clone, Copy)]
pub struct HnswBuildParams {
@@ -81,11 +88,39 @@ impl HnswGraph {
));
}
+ Self::build_owned(vectors[..expected_len].to_vec(), n, d, metric,
params)
+ }
+
+ pub(crate) fn build_owned(
+ vectors: Vec<f32>,
+ n: usize,
+ d: usize,
+ metric: MetricType,
+ params: HnswBuildParams,
+ ) -> io::Result<Self> {
+ let expected_len = n.checked_mul(d).ok_or_else(|| {
+ io::Error::new(io::ErrorKind::InvalidInput, "n * dimension
overflows usize")
+ })?;
+ if vectors.len() != expected_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!(
+ "vector data length {} does not match n*d {}",
+ vectors.len(),
+ expected_len
+ ),
+ ));
+ }
+
let params = params.sanitized();
+ if n >= PARALLEL_BUILD_MIN_N {
+ return Ok(Self::build_parallel(vectors, n, d, metric, params));
+ }
+
let mut graph = HnswGraph {
d,
metric,
- vectors: vectors[..n * d].to_vec(),
+ vectors,
levels: Vec::with_capacity(n),
neighbors: Vec::with_capacity(n),
entry_point: 0,
@@ -93,12 +128,68 @@ impl HnswGraph {
params,
};
+ let mut workspace = HnswBuildWorkspace::new(n, params.ef_construction);
for node in 0..n {
- graph.insert(node);
+ graph.insert(node, &mut workspace);
}
Ok(graph)
}
+ fn build_parallel(
+ vectors: Vec<f32>,
+ n: usize,
+ d: usize,
+ metric: MetricType,
+ params: HnswBuildParams,
+ ) -> Self {
+ let levels = parallel_build_levels(n, params);
+ let max_observed_level = levels.iter().copied().max().unwrap_or(0);
+ let nodes = levels
+ .iter()
+ .map(|&level| RwLock::new(ParallelBuildNode::new(level)))
+ .collect::<Vec<_>>();
+
+ {
+ let builder = ParallelHnswBuilder {
+ d,
+ metric,
+ vectors: &vectors,
+ levels: &levels,
+ nodes: &nodes,
+ params,
+ entry_point: 0,
+ max_observed_level,
+ };
+ (1..n).into_par_iter().for_each_init(
+ || HnswBuildWorkspace::new(n, params.ef_construction),
+ |workspace, node| builder.insert(node, workspace),
+ );
+ }
+
+ let neighbors = nodes
+ .into_iter()
+ .map(|node| {
+ node.into_inner()
+ .expect("parallel HNSW builder lock poisoned")
+ .levels
+ .into_iter()
+ .map(|level| level.into_iter().map(|neighbor|
neighbor.id).collect())
+ .collect()
+ })
+ .collect();
+
+ Self {
+ d,
+ metric,
+ vectors,
+ levels,
+ neighbors,
+ entry_point: 0,
+ max_observed_level,
+ params,
+ }
+ }
+
#[allow(clippy::too_many_arguments)]
pub(crate) fn from_parts(
vectors: Vec<f32>,
@@ -198,9 +289,25 @@ impl HnswGraph {
}
pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<(usize,
f32)> {
+ let mut visited = Vec::new();
+ let mut visit_mark = 1usize;
+ self.search_with_workspace(query, k, ef, &mut visited, &mut visit_mark)
+ }
+
+ pub(crate) fn search_with_workspace(
+ &self,
+ query: &[f32],
+ k: usize,
+ ef: usize,
+ visited: &mut Vec<usize>,
+ visit_mark: &mut usize,
+ ) -> Vec<(usize, f32)> {
if self.levels.is_empty() || k == 0 {
return Vec::new();
}
+ if visited.len() < self.levels.len() {
+ visited.resize(self.levels.len(), 0);
+ }
let mut ep = self.entry_point;
let mut ep_dist = self.distance_to_query(query, ep);
@@ -210,8 +317,9 @@ impl HnswGraph {
ep_dist = dist;
}
- let mut visited = vec![0usize; self.levels.len()];
- let candidates = self.search_layer_query(query, ep, ef.max(k), 0, &mut
visited, 1);
+ let current_mark = *visit_mark;
+ let candidates = self.search_layer_query(query, ep, ef.max(k), 0,
visited, current_mark);
+ *visit_mark = advance_visit_mark(visited, current_mark);
candidates
.into_iter()
.take(k)
@@ -255,7 +363,7 @@ impl HnswGraph {
self.max_observed_level
}
- fn insert(&mut self, node: usize) {
+ fn insert(&mut self, node: usize, workspace: &mut HnswBuildWorkspace) {
let level = random_level(node, self.params.m, self.params.max_level);
self.levels.push(level);
self.neighbors.push(vec![Vec::new(); level + 1]);
@@ -275,30 +383,15 @@ impl HnswGraph {
ep_dist = dist;
}
- let mut visited = vec![0usize; self.levels.len()];
- let mut visit_mark = 1usize;
for layer in (0..=level.min(self.max_observed_level)).rev() {
- let candidates = self.search_layer_node(
- node,
- ep,
- self.params.ef_construction,
- layer,
- &mut visited,
- visit_mark,
- );
- visit_mark = advance_visit_mark(&mut visited, visit_mark);
- let selected = self.select_neighbors(candidates,
self.max_neighbors(layer));
- for neighbor in selected {
- self.connect(node, neighbor.id, layer);
- }
- if let Some(best) = self.neighbors[node][layer]
- .iter()
- .copied()
- .min_by(|&a, &b| {
- self.distance_between(node, a)
- .total_cmp(&self.distance_between(node, b))
- })
- {
+ self.search_layer_node_with_workspace(node, ep, layer, workspace);
+ let next_ep = workspace.output.first().map(|candidate|
candidate.id);
+ let selected = workspace
+ .select_output_neighbors_by(self.max_neighbors(layer),
|candidate, neighbor| {
+ self.distance_between(candidate, neighbor)
+ });
+ self.connect_selected(node, selected, layer);
+ if let Some(best) = next_ep {
ep = best;
}
}
@@ -309,32 +402,57 @@ impl HnswGraph {
}
}
- fn connect(&mut self, a: usize, b: usize, level: usize) {
- if !self.neighbors[a][level].contains(&b) {
- self.neighbors[a][level].push(b);
+ fn connect_selected(&mut self, node: usize, selected: &[ScoredNode],
level: usize) {
+ let node_neighbors = &mut self.neighbors[node][level];
+ node_neighbors.clear();
+ node_neighbors.extend(selected.iter().map(|neighbor| neighbor.id));
+
+ for neighbor in selected {
+ let neighbor_id = neighbor.id;
+ if level < self.neighbors[neighbor_id].len()
+ && !self.neighbors[neighbor_id][level].contains(&node)
+ {
+ self.connect_reverse(node, neighbor_id, neighbor.dist, level);
+ }
}
- if level < self.neighbors[b].len() &&
!self.neighbors[b][level].contains(&a) {
- self.neighbors[b][level].push(a);
- let pruned = self.pruned_neighbors(b, level,
self.max_neighbors(level));
- self.neighbors[b][level] = pruned;
+ }
+
+ fn connect_reverse(&mut self, node: usize, neighbor: usize, distance: f32,
level: usize) {
+ let max_neighbors = self.max_neighbors(level);
+ {
+ let neighbors = &self.neighbors[neighbor][level];
+ if neighbors.len() >= max_neighbors
+ && !neighbors
+ .iter()
+ .any(|&existing| distance <
self.distance_between(neighbor, existing))
+ {
+ return;
+ }
+ }
+
+ self.neighbors[neighbor][level].push(node);
+ if self.neighbors[neighbor][level].len() > max_neighbors {
+ let pruned = self.pruned_neighbors(neighbor, level, max_neighbors);
+ self.neighbors[neighbor][level] = pruned;
}
- let pruned = self.pruned_neighbors(a, level,
self.max_neighbors(level));
- self.neighbors[a][level] = pruned;
}
fn pruned_neighbors(&self, node: usize, level: usize, max_neighbors:
usize) -> Vec<usize> {
- let mut ranked: Vec<ScoredNode> = self.neighbors[node][level]
+ let neighbors = &self.neighbors[node][level];
+ if neighbors.len() <= max_neighbors {
+ return neighbors.clone();
+ }
+
+ let ranked: Vec<ScoredNode> = neighbors
.iter()
.map(|&id| ScoredNode {
id,
dist: self.distance_between(node, id),
})
.collect();
- ranked.sort_by(|a, b| a.dist.total_cmp(&b.dist));
- ranked
+ self.select_neighbors(ranked, max_neighbors)
.into_iter()
- .take(max_neighbors)
- .map(|node| node.id)
+ .map(|neighbor| neighbor.id)
.collect()
}
@@ -343,30 +461,19 @@ impl HnswGraph {
mut candidates: Vec<ScoredNode>,
max_neighbors: usize,
) -> Vec<ScoredNode> {
- candidates.sort_by(|a, b| a.dist.total_cmp(&b.dist));
- let mut selected: Vec<ScoredNode> = Vec::with_capacity(max_neighbors);
- let mut backfill: Vec<ScoredNode> = Vec::new();
- for candidate in candidates {
- if selected.len() >= max_neighbors {
- break;
- }
- let closer_to_selected = selected
- .iter()
- .any(|neighbor| self.distance_between(candidate.id,
neighbor.id) < candidate.dist);
- if !closer_to_selected {
- selected.push(candidate);
- } else {
- backfill.push(candidate);
- }
- }
- for candidate in backfill {
- if selected.len() >= max_neighbors {
- break;
- }
- if !selected.iter().any(|neighbor| neighbor.id == candidate.id) {
- selected.push(candidate);
- }
- }
+ candidates.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
+ self.select_neighbors_sorted(&candidates, max_neighbors)
+ }
+
+ fn select_neighbors_sorted(
+ &self,
+ candidates: &[ScoredNode],
+ max_neighbors: usize,
+ ) -> Vec<ScoredNode> {
+ let mut selected =
Vec::with_capacity(max_neighbors.min(candidates.len()));
+ select_neighbors_sorted_into(candidates, max_neighbors, &mut selected,
|a, b| {
+ self.distance_between(a, b)
+ });
selected
}
@@ -434,39 +541,79 @@ impl HnswGraph {
})
}
- fn search_layer_node(
+ fn search_layer_node_with_workspace(
&self,
node: usize,
entry: usize,
+ level: usize,
+ workspace: &mut HnswBuildWorkspace,
+ ) {
+ let visit_mark = workspace.visit_mark;
+ self.search_layer_into(
+ entry,
+ self.params.ef_construction,
+ level,
+ &mut workspace.visited,
+ visit_mark,
+ &mut workspace.candidates,
+ &mut workspace.results,
+ &mut workspace.output,
+ |id| self.distance_between(node, id),
+ );
+ workspace.visit_mark = advance_visit_mark(&mut workspace.visited,
visit_mark);
+ }
+
+ fn search_layer(
+ &self,
+ entry: usize,
ef: usize,
level: usize,
visited: &mut [usize],
visit_mark: usize,
+ distance: impl FnMut(usize) -> f32,
) -> Vec<ScoredNode> {
- self.search_layer(entry, ef, level, visited, visit_mark, |id| {
- self.distance_between(node, id)
- })
+ let mut candidates = BinaryHeap::with_capacity(ef);
+ let mut results = BinaryHeap::with_capacity(ef);
+ let mut output = Vec::with_capacity(ef);
+ self.search_layer_into(
+ entry,
+ ef,
+ level,
+ visited,
+ visit_mark,
+ &mut candidates,
+ &mut results,
+ &mut output,
+ distance,
+ );
+ output
}
- fn search_layer(
+ #[allow(clippy::too_many_arguments)]
+ fn search_layer_into(
&self,
entry: usize,
ef: usize,
level: usize,
visited: &mut [usize],
visit_mark: usize,
+ candidates: &mut BinaryHeap<Reverse<HeapNode>>,
+ results: &mut BinaryHeap<HeapNode>,
+ output: &mut Vec<ScoredNode>,
mut distance: impl FnMut(usize) -> f32,
- ) -> Vec<ScoredNode> {
+ ) {
+ candidates.clear();
+ results.clear();
+ output.clear();
+
let entry_dist = distance(entry);
visited[entry] = visit_mark;
- let mut candidates = BinaryHeap::new();
candidates.push(Reverse(HeapNode {
id: entry,
dist: entry_dist,
}));
- let mut results = BinaryHeap::new();
results.push(HeapNode {
id: entry,
dist: entry_dist,
@@ -501,15 +648,11 @@ impl HnswGraph {
}
}
- let mut result: Vec<ScoredNode> = results
- .into_iter()
- .map(|node| ScoredNode {
- id: node.id,
- dist: node.dist,
- })
- .collect();
- result.sort_by(|a, b| a.dist.total_cmp(&b.dist));
- result
+ output.extend(results.drain().map(|node| ScoredNode {
+ id: node.id,
+ dist: node.dist,
+ }));
+ output.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
}
fn max_neighbors(&self, level: usize) -> usize {
@@ -566,6 +709,339 @@ impl Ord for HeapNode {
}
}
+struct HnswBuildWorkspace {
+ visited: Vec<usize>,
+ visit_mark: usize,
+ candidates: BinaryHeap<Reverse<HeapNode>>,
+ results: BinaryHeap<HeapNode>,
+ output: Vec<ScoredNode>,
+ selected: Vec<ScoredNode>,
+}
+
+impl HnswBuildWorkspace {
+ fn new(n: usize, ef_construction: usize) -> Self {
+ Self {
+ visited: vec![0; n],
+ visit_mark: 1,
+ candidates: BinaryHeap::with_capacity(ef_construction),
+ results: BinaryHeap::with_capacity(ef_construction),
+ output: Vec::with_capacity(ef_construction),
+ selected: Vec::new(),
+ }
+ }
+
+ fn select_output_neighbors_by(
+ &mut self,
+ max_neighbors: usize,
+ distance_between: impl FnMut(usize, usize) -> f32,
+ ) -> &[ScoredNode] {
+ select_neighbors_sorted_into(
+ &self.output,
+ max_neighbors,
+ &mut self.selected,
+ distance_between,
+ );
+ &self.selected
+ }
+}
+
+struct ParallelHnswBuilder<'a> {
+ d: usize,
+ metric: MetricType,
+ vectors: &'a [f32],
+ levels: &'a [usize],
+ nodes: &'a [RwLock<ParallelBuildNode>],
+ params: HnswBuildParams,
+ entry_point: usize,
+ max_observed_level: usize,
+}
+
+impl ParallelHnswBuilder<'_> {
+ fn insert(&self, node: usize, workspace: &mut HnswBuildWorkspace) {
+ let level = self.nodes[node]
+ .read()
+ .expect("parallel HNSW builder lock poisoned")
+ .level();
+ let mut ep = self.entry_point;
+ let mut ep_dist = self.distance_between(node, ep);
+
+ for layer in ((level + 1)..=self.max_observed_level).rev() {
+ let (next, dist) = self.greedy_search_node(node, ep, ep_dist,
layer);
+ ep = next;
+ ep_dist = dist;
+ }
+
+ for layer in (0..=level.min(self.max_observed_level)).rev() {
+ self.search_layer_node(node, ep, layer, workspace);
+ let next_ep = workspace.output.first().map(|candidate|
candidate.id);
+ let selected = workspace
+ .select_output_neighbors_by(self.max_neighbors(layer),
|candidate, neighbor| {
+ self.distance_between(candidate, neighbor)
+ });
+ self.connect_selected(node, selected, layer);
+ if let Some(best) = next_ep {
+ ep = best;
+ }
+ }
+ }
+
+ fn connect_selected(&self, node: usize, selected: &[ScoredNode], level:
usize) {
+ {
+ let mut current = self.nodes[node]
+ .write()
+ .expect("parallel HNSW builder lock poisoned");
+ let current_neighbors = &mut current.levels[level];
+ current_neighbors.clear();
+ current_neighbors.extend_from_slice(selected);
+ }
+
+ for neighbor in selected {
+ let neighbor_id = neighbor.id;
+ if self.levels[neighbor_id] >= level {
+ self.connect_reverse(node, neighbor_id, neighbor.dist, level);
+ }
+ }
+ }
+
+ fn connect_reverse(&self, node: usize, neighbor: usize, distance: f32,
level: usize) {
+ let max_neighbors = self.max_neighbors(level);
+ {
+ let neighbor_node = self.nodes[neighbor]
+ .read()
+ .expect("parallel HNSW builder lock poisoned");
+ let neighbors = &neighbor_node.levels[level];
+ if neighbors.iter().any(|existing| existing.id == node) {
+ return;
+ }
+ if neighbors.len() >= max_neighbors
+ && !neighbors.iter().any(|existing| distance < existing.dist)
+ {
+ return;
+ }
+ }
+
+ let mut neighbor_node = self.nodes[neighbor]
+ .write()
+ .expect("parallel HNSW builder lock poisoned");
+ let neighbors = &mut neighbor_node.levels[level];
+ if neighbors.iter().any(|existing| existing.id == node) {
+ return;
+ }
+ neighbors.push(ScoredNode {
+ id: node,
+ dist: distance,
+ });
+ if neighbors.len() > max_neighbors {
+ let candidates = std::mem::take(neighbors);
+ let pruned = self.select_neighbors(candidates, max_neighbors);
+ *neighbors = pruned;
+ }
+ }
+
+ fn greedy_search_node(
+ &self,
+ node: usize,
+ mut current: usize,
+ mut current_dist: f32,
+ level: usize,
+ ) -> (usize, f32) {
+ loop {
+ let mut best = current;
+ let mut best_dist = current_dist;
+ self.for_each_neighbor(current, level, |neighbor| {
+ let dist = self.distance_between(node, neighbor);
+ if dist < best_dist {
+ best = neighbor;
+ best_dist = dist;
+ }
+ });
+ if best == current {
+ return (current, current_dist);
+ }
+ current = best;
+ current_dist = best_dist;
+ }
+ }
+
+ fn search_layer_node(
+ &self,
+ node: usize,
+ entry: usize,
+ level: usize,
+ workspace: &mut HnswBuildWorkspace,
+ ) {
+ workspace.candidates.clear();
+ workspace.results.clear();
+ workspace.output.clear();
+
+ let visit_mark = workspace.visit_mark;
+ let entry_dist = self.distance_between(node, entry);
+ workspace.visited[entry] = visit_mark;
+ workspace.candidates.push(Reverse(HeapNode {
+ id: entry,
+ dist: entry_dist,
+ }));
+ workspace.results.push(HeapNode {
+ id: entry,
+ dist: entry_dist,
+ });
+
+ while let Some(Reverse(current)) = workspace.candidates.pop() {
+ let worst = workspace
+ .results
+ .peek()
+ .map(|node| node.dist)
+ .unwrap_or(f32::INFINITY);
+ if current.dist > worst && workspace.results.len() >=
self.params.ef_construction {
+ break;
+ }
+
+ self.for_each_neighbor(current.id, level, |neighbor| {
+ if workspace.visited[neighbor] == visit_mark {
+ return;
+ }
+ workspace.visited[neighbor] = visit_mark;
+ let dist = self.distance_between(node, neighbor);
+ let worst = workspace
+ .results
+ .peek()
+ .map(|node| node.dist)
+ .unwrap_or(f32::INFINITY);
+ if workspace.results.len() < self.params.ef_construction ||
dist < worst {
+ workspace
+ .candidates
+ .push(Reverse(HeapNode { id: neighbor, dist }));
+ workspace.results.push(HeapNode { id: neighbor, dist });
+ if workspace.results.len() > self.params.ef_construction {
+ workspace.results.pop();
+ }
+ }
+ });
+ }
+
+ workspace
+ .output
+ .extend(workspace.results.drain().map(|node| ScoredNode {
+ id: node.id,
+ dist: node.dist,
+ }));
+ workspace
+ .output
+ .sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
+ workspace.visit_mark = advance_visit_mark(&mut workspace.visited,
visit_mark);
+ }
+
+ fn select_neighbors(
+ &self,
+ mut candidates: Vec<ScoredNode>,
+ max_neighbors: usize,
+ ) -> Vec<ScoredNode> {
+ candidates.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
+ self.select_neighbors_sorted(&candidates, max_neighbors)
+ }
+
+ fn select_neighbors_sorted(
+ &self,
+ candidates: &[ScoredNode],
+ max_neighbors: usize,
+ ) -> Vec<ScoredNode> {
+ let mut selected =
Vec::with_capacity(max_neighbors.min(candidates.len()));
+ select_neighbors_sorted_into(candidates, max_neighbors, &mut selected,
|a, b| {
+ self.distance_between(a, b)
+ });
+ selected
+ }
+
+ fn for_each_neighbor(&self, node: usize, level: usize, mut f: impl
FnMut(usize)) {
+ let node = self.nodes[node]
+ .read()
+ .expect("parallel HNSW builder lock poisoned");
+ if let Some(neighbors) = node.levels.get(level) {
+ for neighbor in neighbors {
+ f(neighbor.id);
+ }
+ }
+ }
+
+ fn max_neighbors(&self, level: usize) -> usize {
+ if level == 0 {
+ self.params.m * 2
+ } else {
+ self.params.m
+ }
+ }
+
+ fn distance_between(&self, a: usize, b: usize) -> f32 {
+ let va = &self.vectors[a * self.d..(a + 1) * self.d];
+ let vb = &self.vectors[b * self.d..(b + 1) * self.d];
+ fvec_distance(va, vb, self.metric)
+ }
+}
+
+struct ParallelBuildNode {
+ levels: Vec<Vec<ScoredNode>>,
+}
+
+impl ParallelBuildNode {
+ fn new(level: usize) -> Self {
+ Self {
+ levels: vec![Vec::new(); level + 1],
+ }
+ }
+
+ fn level(&self) -> usize {
+ self.levels.len() - 1
+ }
+}
+
+fn parallel_build_levels(n: usize, params: HnswBuildParams) -> Vec<usize> {
+ let mut levels: Vec<_> = (0..n)
+ .map(|node| random_level(node, params.m, params.max_level))
+ .collect();
+ if let Some(first) = levels.first_mut() {
+ // LanceDB keeps the fixed entry point reachable from every configured
+ // layer during parallel build. Mirroring that avoids a serialized
+ // "promote newest max-level node" phase while preserving high-level
+ // search quality.
+ *first = params.max_level - 1;
+ }
+ levels
+}
+
+fn select_neighbors_sorted_into(
+ candidates: &[ScoredNode],
+ max_neighbors: usize,
+ selected: &mut Vec<ScoredNode>,
+ mut distance_between: impl FnMut(usize, usize) -> f32,
+) {
+ selected.clear();
+ if candidates.len() <= max_neighbors {
+ selected.extend_from_slice(candidates);
+ return;
+ }
+
+ selected.reserve(max_neighbors.saturating_sub(selected.len()));
+ for &candidate in candidates {
+ if selected.len() >= max_neighbors {
+ break;
+ }
+ let closer_to_selected = selected
+ .iter()
+ .any(|neighbor| distance_between(candidate.id, neighbor.id) <
candidate.dist);
+ if !closer_to_selected {
+ selected.push(candidate);
+ }
+ }
+ for &candidate in candidates {
+ if selected.len() >= max_neighbors {
+ break;
+ }
+ if !selected.iter().any(|neighbor| neighbor.id == candidate.id) {
+ selected.push(candidate);
+ }
+ }
+}
+
fn random_level(node: usize, m: usize, max_level: usize) -> usize {
if node == 0 || max_level <= 1 {
// Keep the first insertion deterministic. Later higher-level nodes
replace
@@ -695,6 +1171,36 @@ mod tests {
assert!(recall >= 0.95, "recall={}", recall);
}
+ #[test]
+ fn test_hnsw_parallel_build_large_partition_recall_tracks_exact_search() {
+ let d = 16;
+ let n = PARALLEL_BUILD_MIN_N + 512;
+ let nq = 32;
+ let k = 10;
+ let data = generate_clustered_data(n, d, 32);
+ let params = HnswBuildParams {
+ m: 16,
+ ef_construction: 200,
+ max_level: 7,
+ };
+
+ let graph = HnswGraph::build(&data, n, d, MetricType::L2,
params).unwrap();
+ let mut hits = 0usize;
+ for qi in 0..nq {
+ let query = &data[qi * d..(qi + 1) * d];
+ let expected = exact_topk(&data, n, d, query, k);
+ let actual = graph.search(query, k, 200);
+ hits += actual
+ .iter()
+ .filter(|(id, _)| expected.contains(id))
+ .count();
+ }
+
+ let recall = hits as f32 / (nq * k) as f32;
+ assert!(recall >= 0.95, "recall={}", recall);
+ assert!(graph.max_degree() <= params.m * 2);
+ }
+
#[test]
fn test_hnsw_neighbor_selection_backfills_after_diversification() {
let d = 1;
@@ -722,6 +1228,31 @@ mod tests {
assert_eq!(selected.len(), 3);
}
+ #[test]
+ fn test_hnsw_pruning_keeps_diverse_neighbors() {
+ let graph = HnswGraph::from_parts(
+ vec![0.0, 0.0, 1.0, 0.0, 1.1, 0.0, 0.0, 2.0],
+ 4,
+ 2,
+ MetricType::L2,
+ vec![0, 0, 0, 0],
+ vec![
+ vec![vec![1, 2, 3]],
+ vec![vec![]],
+ vec![vec![]],
+ vec![vec![]],
+ ],
+ 0,
+ 0,
+ HnswBuildParams::default(),
+ )
+ .unwrap();
+
+ let selected = graph.pruned_neighbors(0, 0, 2);
+
+ assert_eq!(selected, vec![1, 3]);
+ }
+
#[test]
fn test_hnsw_greedy_search_chooses_best_improving_neighbor() {
let graph = HnswGraph::from_parts(
diff --git a/core/src/hnsw_search.rs b/core/src/hnsw_search.rs
index 74eab88..a975c1c 100644
--- a/core/src/hnsw_search.rs
+++ b/core/src/hnsw_search.rs
@@ -37,6 +37,8 @@ where
F: FnMut(&HnswSearchList<'a, P>, &mut TopKHeap),
{
let mut heap = TopKHeap::new(k);
+ let mut visited = Vec::new();
+ let mut visit_mark = 1usize;
let force_scan = filter
.map(|f| count_filtered(lists, f) <= ef_search.max(k))
.unwrap_or(false);
@@ -47,7 +49,13 @@ where
continue;
}
if let Some(graph) = list.graph {
- let local_results = graph.search(query, ef_search.max(k),
ef_search.max(k));
+ let local_results = graph.search_with_workspace(
+ query,
+ ef_search.max(k),
+ ef_search.max(k),
+ &mut visited,
+ &mut visit_mark,
+ );
for (local_id, dist) in local_results {
let row_id = list.ids[local_id];
if filter.map(|f| f.contains(row_id)).unwrap_or(true) {
diff --git a/core/src/ivfhnswflat.rs b/core/src/ivfhnswflat.rs
index 0d2db81..5ab508a 100644
--- a/core/src/ivfhnswflat.rs
+++ b/core/src/ivfhnswflat.rs
@@ -22,6 +22,7 @@ use crate::ivfflat::IVFFlatIndex;
use crate::ivfpq::RowIdFilter;
use crate::kmeans;
use crate::topk::TopKHeap;
+use rayon::prelude::*;
use std::io;
pub struct IVFHNSWFlatIndex {
@@ -51,20 +52,24 @@ impl IVFHNSWFlatIndex {
}
pub fn build_graphs(&mut self) -> io::Result<()> {
- for list_id in 0..self.flat.nlist {
- let count = self.flat.ids[list_id].len();
- self.graphs[list_id] = if count == 0 {
- None
- } else {
- Some(HnswGraph::build(
- &self.flat.vectors[list_id],
- count,
- self.flat.d,
- self.flat.metric,
- self.hnsw_params,
- )?)
- };
- }
+ self.graphs = (0..self.flat.nlist)
+ .into_par_iter()
+ .map(|list_id| {
+ let count = self.flat.ids[list_id].len();
+ if count == 0 {
+ Ok(None)
+ } else {
+ HnswGraph::build(
+ &self.flat.vectors[list_id],
+ count,
+ self.flat.d,
+ self.flat.metric,
+ self.hnsw_params,
+ )
+ .map(Some)
+ }
+ })
+ .collect::<io::Result<Vec<_>>>()?;
Ok(())
}
diff --git a/core/src/ivfhnswsq.rs b/core/src/ivfhnswsq.rs
index 4ce2a6e..a2cc8f2 100644
--- a/core/src/ivfhnswsq.rs
+++ b/core/src/ivfhnswsq.rs
@@ -22,6 +22,7 @@ use crate::ivfpq::RowIdFilter;
use crate::kmeans::{self, KMeansConfig};
use crate::sq::ScalarQuantizer;
use crate::topk::TopKHeap;
+use rayon::prelude::*;
use std::io;
pub struct IVFHNSWSQIndex {
@@ -82,21 +83,19 @@ impl IVFHNSWSQIndex {
}
pub fn build_graphs(&mut self) -> io::Result<()> {
- for list_id in 0..self.nlist {
- let count = self.ids[list_id].len();
- self.graphs[list_id] = if count == 0 {
- None
- } else {
- let vectors = self.decode_list_vectors(list_id, count);
- Some(HnswGraph::build(
- &vectors,
- count,
- self.d,
- self.metric,
- self.hnsw_params,
- )?)
- };
- }
+ self.graphs = (0..self.nlist)
+ .into_par_iter()
+ .map(|list_id| {
+ let count = self.ids[list_id].len();
+ if count == 0 {
+ Ok(None)
+ } else {
+ let vectors = self.decode_list_vectors(list_id, count);
+ HnswGraph::build_owned(vectors, count, self.d,
self.metric, self.hnsw_params)
+ .map(Some)
+ }
+ })
+ .collect::<io::Result<Vec<_>>>()?;
Ok(())
}