This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new 8bc5cfe Add IVF HNSW SQ index (#20)
8bc5cfe is described below
commit 8bc5cfebc5a3151efefdcc4a76cad374be4a66df
Author: Jingsong Lee <[email protected]>
AuthorDate: Tue Jun 9 15:19:40 2026 +0800
Add IVF HNSW SQ index (#20)
---
core/benches/recall_bench.rs | 39 ++
core/src/hnsw_search.rs | 76 ++++
core/src/index_io_util.rs | 346 ++++++++++++++++++
core/src/ivfhnswflat.rs | 56 +--
core/src/ivfhnswflat_io.rs | 415 ++-------------------
core/src/ivfhnswsq.rs | 289 +++++++++++++++
core/src/ivfhnswsq_io.rs | 851 +++++++++++++++++++++++++++++++++++++++++++
core/src/lib.rs | 5 +
core/src/sq.rs | 229 ++++++++++++
9 files changed, 1879 insertions(+), 427 deletions(-)
diff --git a/core/benches/recall_bench.rs b/core/benches/recall_bench.rs
index 1755402..5103446 100644
--- a/core/benches/recall_bench.rs
+++ b/core/benches/recall_bench.rs
@@ -2,6 +2,7 @@ use paimon_vindex_core::distance::{fvec_distance, MetricType};
use paimon_vindex_core::hnsw::HnswBuildParams;
use paimon_vindex_core::ivfflat::IVFFlatIndex;
use paimon_vindex_core::ivfhnswflat::IVFHNSWFlatIndex;
+use paimon_vindex_core::ivfhnswsq::IVFHNSWSQIndex;
use paimon_vindex_core::ivfpq::IVFPQIndex;
use std::collections::HashSet;
use std::time::Instant;
@@ -99,6 +100,22 @@ fn run_scenario(s: Scenario<'_>) {
ivfhnswflat.build_graphs().unwrap();
println!("build IVF-HNSW-FLAT: {:.2}s", start.elapsed().as_secs_f64());
+ let start = Instant::now();
+ let mut ivfhnswsq = IVFHNSWSQIndex::new(
+ s.d,
+ s.nlist,
+ MetricType::L2,
+ HnswBuildParams {
+ m: 16,
+ ef_construction: s.hnsw_build_ef,
+ max_level: 7,
+ },
+ );
+ ivfhnswsq.train(&data, s.n);
+ ivfhnswsq.add(&data, &ids, s.n);
+ ivfhnswsq.build_graphs().unwrap();
+ println!("build IVF-HNSW-SQ: {:.2}s", start.elapsed().as_secs_f64());
+
println!();
println!(
"index nprobe ef recall@{} query_ms us/query",
@@ -157,6 +174,28 @@ fn run_scenario(s: Scenario<'_>) {
elapsed,
s.nq,
);
+
+ let mut distances = vec![0.0f32; s.nq * s.k];
+ let mut labels = vec![0i64; s.nq * s.k];
+ let start = Instant::now();
+ ivfhnswsq.search(
+ queries,
+ s.nq,
+ s.k,
+ nprobe,
+ ef_search,
+ &mut distances,
+ &mut labels,
+ );
+ let elapsed = start.elapsed();
+ print_row(
+ "IVF-HSQ",
+ nprobe,
+ Some(ef_search),
+ recall_at_k(&labels, &ground_truth, s.nq, s.k),
+ elapsed,
+ s.nq,
+ );
}
}
}
diff --git a/core/src/hnsw_search.rs b/core/src/hnsw_search.rs
new file mode 100644
index 0000000..74eab88
--- /dev/null
+++ b/core/src/hnsw_search.rs
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::hnsw::HnswGraph;
+use crate::ivfpq::RowIdFilter;
+use crate::topk::TopKHeap;
+
+pub(crate) struct HnswSearchList<'a, P> {
+ pub(crate) ids: &'a [i64],
+ pub(crate) graph: Option<&'a HnswGraph>,
+ pub(crate) payload: P,
+}
+
+pub(crate) fn search_hnsw_lists<'a, P, F>(
+ query: &[f32],
+ lists: &[HnswSearchList<'a, P>],
+ k: usize,
+ ef_search: usize,
+ filter: Option<&dyn RowIdFilter>,
+ mut scan_list: F,
+) -> Vec<(f32, i64)>
+where
+ F: FnMut(&HnswSearchList<'a, P>, &mut TopKHeap),
+{
+ let mut heap = TopKHeap::new(k);
+ let force_scan = filter
+ .map(|f| count_filtered(lists, f) <= ef_search.max(k))
+ .unwrap_or(false);
+
+ for list in lists {
+ if force_scan {
+ scan_list(list, &mut heap);
+ continue;
+ }
+ if let Some(graph) = list.graph {
+ let local_results = graph.search(query, ef_search.max(k), ef_search.max(k));
+ for (local_id, dist) in local_results {
+ let row_id = list.ids[local_id];
+ if filter.map(|f| f.contains(row_id)).unwrap_or(true) {
+ heap.push(dist, row_id);
+ }
+ }
+ } else {
+ scan_list(list, &mut heap);
+ }
+ }
+
+ if filter.is_some() && heap.len() < k && !force_scan {
+ for list in lists {
+ scan_list(list, &mut heap);
+ }
+ }
+
+ heap.into_sorted()
+}
+
+fn count_filtered<P>(lists: &[HnswSearchList<'_, P>], filter: &dyn RowIdFilter) -> usize {
+ lists
+ .iter()
+ .map(|list| list.ids.iter().filter(|&&id| filter.contains(id)).count())
+ .sum()
+}
diff --git a/core/src/index_io_util.rs b/core/src/index_io_util.rs
new file mode 100644
index 0000000..292953b
--- /dev/null
+++ b/core/src/index_io_util.rs
@@ -0,0 +1,346 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::MetricType;
+use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::io::{SeekRead, SeekWrite};
+use roaring::RoaringTreemap;
+use std::io;
+
+pub(crate) fn validate_search_inputs(
+ queries: &[f32],
+ nq: usize,
+ d: usize,
+ k: usize,
+ nprobe: usize,
+) -> io::Result<()> {
+ if nq == 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "nq must be greater than 0",
+ ));
+ }
+ let expected_query_len = nq.checked_mul(d).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "nq * dimension overflows usize",
+ )
+ })?;
+ if queries.len() != expected_query_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!(
+ "queries length {} does not match nq * dimension {}",
+ queries.len(),
+ expected_query_len
+ ),
+ ));
+ }
+ if k == 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "k must be greater than 0",
+ ));
+ }
+ if nprobe == 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "nprobe must be greater than 0",
+ ));
+ }
+ Ok(())
+}
+
+pub(crate) fn encode_graph(graph: Option<&HnswGraph>) -> io::Result<Vec<u8>> {
+ let Some(graph) = graph else {
+ return Ok(Vec::new());
+ };
+ let mut buf = Vec::new();
+ write_u32_vec(&mut buf, graph.len())?;
+ write_u32_vec(&mut buf, graph.entry_point())?;
+ write_u32_vec(&mut buf, graph.max_observed_level())?;
+ for &level in graph.levels() {
+ write_u32_vec(&mut buf, level)?;
+ }
+ for node_levels in graph.neighbors() {
+ for level_neighbors in node_levels {
+ write_u32_vec(&mut buf, level_neighbors.len())?;
+ for &neighbor in level_neighbors {
+ write_u32_vec(&mut buf, neighbor)?;
+ }
+ }
+ }
+ Ok(buf)
+}
+
+pub(crate) fn decode_graph(
+ bytes: &[u8],
+ vectors: Vec<f32>,
+ count: usize,
+ d: usize,
+ metric: MetricType,
+ hnsw_params: HnswBuildParams,
+) -> io::Result<Option<HnswGraph>> {
+ if bytes.is_empty() {
+ return Ok(None);
+ }
+ let mut pos = 0usize;
+ let graph_count = read_u32_vec(bytes, &mut pos)? as usize;
+ if graph_count != count {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "graph count {} does not match list count {}",
+ graph_count, count
+ ),
+ ));
+ }
+ let entry_point = read_u32_vec(bytes, &mut pos)? as usize;
+ let max_observed_level = read_u32_vec(bytes, &mut pos)? as usize;
+ let mut levels = Vec::with_capacity(count);
+ for node in 0..count {
+ let level = read_u32_vec(bytes, &mut pos)? as usize;
+ if level >= hnsw_params.max_level {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "graph node {} level {} exceeds max_level {}",
+ node,
+ level,
+ hnsw_params.max_level - 1
+ ),
+ ));
+ }
+ levels.push(level);
+ }
+ let mut neighbors = Vec::with_capacity(count);
+ for (node, &level) in levels.iter().enumerate() {
+ let mut node_levels = Vec::with_capacity(level + 1);
+ for graph_level in 0..=level {
+ let degree = read_u32_vec(bytes, &mut pos)? as usize;
+ let max_degree = if graph_level == 0 {
+ hnsw_params.m.saturating_mul(2)
+ } else {
+ hnsw_params.m
+ };
+ if degree > max_degree {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "graph node {} degree {} at level {} exceeds max degree {}",
+ node, degree, graph_level, max_degree
+ ),
+ ));
+ }
+ let mut level_neighbors = Vec::with_capacity(degree);
+ for _ in 0..degree {
+ level_neighbors.push(read_u32_vec(bytes, &mut pos)? as usize);
+ }
+ node_levels.push(level_neighbors);
+ }
+ neighbors.push(node_levels);
+ }
+ if pos != bytes.len() {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "trailing bytes in HNSW graph section",
+ ));
+ }
+ Ok(Some(HnswGraph::from_parts(
+ vectors,
+ count,
+ d,
+ metric,
+ levels,
+ neighbors,
+ entry_point,
+ max_observed_level,
+ hnsw_params,
+ )?))
+}
+
+fn write_u32_vec(buf: &mut Vec<u8>, value: usize) -> io::Result<()> {
+ if value > u32::MAX as usize {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("value {} exceeds u32 limit", value),
+ ));
+ }
+ buf.extend_from_slice(&(value as u32).to_le_bytes());
+ Ok(())
+}
+
+fn read_u32_vec(bytes: &[u8], pos: &mut usize) -> io::Result<u32> {
+ let end = pos.checked_add(4).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "graph section position overflow",
+ )
+ })?;
+ if end > bytes.len() {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "truncated HNSW graph section",
+ ));
+ }
+ let value = u32::from_le_bytes(bytes[*pos..end].try_into().unwrap());
+ *pos = end;
+ Ok(value)
+}
+
+pub(crate) fn write_u32_le(out: &mut dyn SeekWrite, v: u32) -> io::Result<()> {
+ out.write_all(&v.to_le_bytes())
+}
+
+pub(crate) fn write_i32_le(out: &mut dyn SeekWrite, v: i32) -> io::Result<()> {
+ out.write_all(&v.to_le_bytes())
+}
+
+pub(crate) fn write_i64_le(out: &mut dyn SeekWrite, v: i64) -> io::Result<()> {
+ out.write_all(&v.to_le_bytes())
+}
+
+pub(crate) fn write_f32_slice(out: &mut dyn SeekWrite, data: &[f32]) -> io::Result<()> {
+ let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
+ out.write_all(&bytes)
+}
+
+pub(crate) fn read_u32_le(reader: &mut dyn SeekRead) -> io::Result<u32> {
+ let mut buf = [0u8; 4];
+ reader.read_exact(&mut buf)?;
+ Ok(u32::from_le_bytes(buf))
+}
+
+pub(crate) fn read_i32_le(reader: &mut dyn SeekRead) -> io::Result<i32> {
+ let mut buf = [0u8; 4];
+ reader.read_exact(&mut buf)?;
+ Ok(i32::from_le_bytes(buf))
+}
+
+pub(crate) fn read_i64_le(reader: &mut dyn SeekRead) -> io::Result<i64> {
+ let mut buf = [0u8; 8];
+ reader.read_exact(&mut buf)?;
+ Ok(i64::from_le_bytes(buf))
+}
+
+pub(crate) fn read_f32_vec(reader: &mut dyn SeekRead, count: usize) -> io::Result<Vec<f32>> {
+ let byte_len = count.checked_mul(4).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "f32 section byte length overflow",
+ )
+ })?;
+ let mut buf = vec![0u8; byte_len];
+ reader.read_exact(&mut buf)?;
+ bytes_to_f32_vec(&buf)
+}
+
+pub(crate) fn bytes_to_f32_vec(bytes: &[u8]) -> io::Result<Vec<f32>> {
+ if !bytes.len().is_multiple_of(4) {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "f32 byte section is not 4-byte aligned",
+ ));
+ }
+ Ok(bytes
+ .chunks_exact(4)
+ .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
+ .collect())
+}
+
+pub(crate) fn validate_positive_i32(val: i32, field: &str) -> io::Result<i32> {
+ if val <= 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("invalid header field {}: {} (must be positive)", field, val),
+ ));
+ }
+ Ok(val)
+}
+
+pub(crate) fn usize_to_i32(value: usize, field: &str) -> io::Result<i32> {
+ if value > i32::MAX as usize {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("{} exceeds i32 length limit: {}", field, value),
+ ));
+ }
+ Ok(value as i32)
+}
+
+pub(crate) fn usize_to_i64(value: usize, field: &str) -> io::Result<i64> {
+ if value > i64::MAX as usize {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("{} exceeds i64 length limit: {}", field, value),
+ ));
+ }
+ Ok(value as i64)
+}
+
+pub(crate) fn u64_to_i64(value: u64, field: &str) -> io::Result<i64> {
+ if value > i64::MAX as u64 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("{} exceeds i64 offset limit: {}", field, value),
+ ));
+ }
+ Ok(value as i64)
+}
+
+const MAX_SECTION_ELEMENTS: usize = 1 << 30;
+
+pub(crate) fn checked_section_size(a: usize, b: usize) -> io::Result<usize> {
+ let result = a
+ .checked_mul(b)
+ .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "section size overflow"))?;
+ if result > MAX_SECTION_ELEMENTS {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "section size {} exceeds maximum {}",
+ result, MAX_SECTION_ELEMENTS
+ ),
+ ));
+ }
+ Ok(result)
+}
+
+pub(crate) fn checked_list_offset(offset: i64, list_id: usize) -> io::Result<u64> {
+ if offset < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("negative list offset {} at list {}", offset, list_id),
+ ));
+ }
+ Ok(offset as u64)
+}
+
+pub(crate) fn checked_list_bytes(count: usize, bytes_per_entry: usize) -> io::Result<usize> {
+ count
+ .checked_mul(bytes_per_entry)
+ .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "list byte size overflow"))
+}
+
+pub(crate) fn decode_roaring_filter(bytes: &[u8]) -> io::Result<RoaringTreemap> {
+ RoaringTreemap::deserialize_from(bytes).map_err(|e| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("invalid RoaringTreemap filter: {}", e),
+ )
+ })
+}
diff --git a/core/src/ivfhnswflat.rs b/core/src/ivfhnswflat.rs
index e2f7ccf..0d2db81 100644
--- a/core/src/ivfhnswflat.rs
+++ b/core/src/ivfhnswflat.rs
@@ -17,6 +17,7 @@
use crate::distance::{fvec_distance, MetricType};
use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::hnsw_search::{search_hnsw_lists, HnswSearchList};
use crate::ivfflat::IVFFlatIndex;
use crate::ivfpq::RowIdFilter;
use crate::kmeans;
@@ -114,38 +115,17 @@ impl IVFHNSWFlatIndex {
for qi in 0..nq {
let query = &processed_queries[qi * self.flat.d..(qi + 1) * self.flat.d];
- let mut heap = TopKHeap::new(k);
- let force_flat_scan = filter
- .map(|f| self.count_filtered(&all_probe_indices[qi], f) <= ef_search.max(k))
- .unwrap_or(false);
-
- for &list_id in &all_probe_indices[qi] {
- if force_flat_scan {
- self.scan_flat_list(query, list_id, filter, &mut heap);
- continue;
- }
- if let Some(ref graph) = self.graphs[list_id] {
- let local_results = graph.search(query, ef_search.max(k), ef_search.max(k));
- for (local_id, dist) in local_results {
- let row_id = self.flat.ids[list_id][local_id];
- if let Some(f) = filter {
- if !f.contains(row_id) {
- continue;
- }
- }
- heap.push(dist, row_id);
- }
- } else {
- self.scan_flat_list(query, list_id, filter, &mut heap);
- }
- }
- if filter.is_some() && heap.len() < k {
- for &list_id in &all_probe_indices[qi] {
- self.scan_flat_list(query, list_id, filter, &mut heap);
- }
- }
-
- let sorted = heap.into_sorted();
+ let lists: Vec<_> = all_probe_indices[qi]
+ .iter()
+ .map(|&list_id| HnswSearchList {
+ ids: self.flat.ids[list_id].as_slice(),
+ graph: self.graphs[list_id].as_ref(),
+ payload: list_id,
+ })
+ .collect();
+ let sorted = search_hnsw_lists(query, &lists, k, ef_search, filter, |list, heap| {
+ self.scan_flat_list(query, list.payload, filter, heap);
+ });
let out_base = qi * k;
for (i, &(dist, id)) in sorted.iter().enumerate() {
result_distances[out_base + i] = dist;
@@ -158,18 +138,6 @@ impl IVFHNSWFlatIndex {
}
}
- fn count_filtered(&self, probe_indices: &[usize], filter: &dyn RowIdFilter) -> usize {
- probe_indices
- .iter()
- .map(|&list_id| {
- self.flat.ids[list_id]
- .iter()
- .filter(|&&id| filter.contains(id))
- .count()
- })
- .sum()
- }
-
fn scan_flat_list(
&self,
query: &[f32],
diff --git a/core/src/ivfhnswflat_io.rs b/core/src/ivfhnswflat_io.rs
index d504d43..ebd155c 100644
--- a/core/src/ivfhnswflat_io.rs
+++ b/core/src/ivfhnswflat_io.rs
@@ -17,12 +17,18 @@
use crate::distance::{fvec_distance, preprocess_vectors, MetricType};
use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::hnsw_search::{search_hnsw_lists, HnswSearchList};
+use crate::index_io_util::{
+ bytes_to_f32_vec, checked_list_bytes, checked_list_offset, checked_section_size, decode_graph,
+ decode_roaring_filter, encode_graph, read_f32_vec, read_i32_le, read_i64_le, read_u32_le,
+ u64_to_i64, usize_to_i32, usize_to_i64, validate_positive_i32, validate_search_inputs,
+ write_f32_slice, write_i32_le, write_i64_le, write_u32_le,
+};
use crate::io::{SeekRead, SeekWrite};
use crate::ivfhnswflat::IVFHNSWFlatIndex;
use crate::ivfpq::RowIdFilter;
use crate::kmeans;
use crate::topk::TopKHeap;
-use roaring::RoaringTreemap;
use std::io;
pub const IVF_HNSW_FLAT_MAGIC: u32 = 0x4948464C; // "IHFL"
@@ -356,47 +362,26 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
loaded_lists.push(list);
}
}
- let mut heap = TopKHeap::new(k);
- let force_flat_scan = filter
- .map(|f| count_filtered(&loaded_lists, f) <= ef_search.max(k))
- .unwrap_or(false);
-
- for list in &loaded_lists {
- if force_flat_scan {
- scan_flat_list(
- &q,
- &list.ids,
- list.graph.vectors(),
- self.d,
- self.metric,
- filter,
- &mut heap,
- );
- } else {
- let local_results = list.graph.search(&q, ef_search.max(k), ef_search.max(k));
- for (local_id, dist) in local_results {
- let row_id = list.ids[local_id];
- if filter.map(|f| f.contains(row_id)).unwrap_or(true) {
- heap.push(dist, row_id);
- }
- }
- }
- }
- if filter.is_some() && heap.len() < k {
- for list in &loaded_lists {
- scan_flat_list(
- &q,
- &list.ids,
- list.graph.vectors(),
- self.d,
- self.metric,
- filter,
- &mut heap,
- );
- }
- }
-
- let sorted = heap.into_sorted();
+ let search_lists: Vec<_> = loaded_lists
+ .iter()
+ .map(|list| HnswSearchList {
+ ids: list.ids.as_slice(),
+ graph: Some(&list.graph),
+ payload: list,
+ })
+ .collect();
+ let sorted = search_hnsw_lists(&q, &search_lists, k, ef_search, filter, |list, heap| {
+ let list = list.payload;
+ scan_flat_list(
+ &q,
+ &list.ids,
+ list.graph.vectors(),
+ self.d,
+ self.metric,
+ filter,
+ heap,
+ );
+ });
let mut labels: Vec<i64> = sorted.iter().map(|&(_, id)| id).collect();
let mut distances: Vec<f32> = sorted.iter().map(|&(dist, _)| dist).collect();
labels.resize(k, -1);
@@ -573,13 +558,6 @@ struct LoadedBatchList {
graph: HnswGraph,
}
-fn count_filtered(lists: &[GraphList], filter: &dyn RowIdFilter) -> usize {
- lists
- .iter()
- .map(|list| list.ids.iter().filter(|&&id| filter.contains(id)).count())
- .sum()
-}
-
fn scan_flat_list(
query: &[f32],
ids: &[i64],
@@ -598,257 +576,6 @@ fn scan_flat_list(
}
}
-fn validate_search_inputs(
- queries: &[f32],
- nq: usize,
- d: usize,
- k: usize,
- nprobe: usize,
-) -> io::Result<()> {
- if nq == 0 {
- return Err(io::Error::new(
- io::ErrorKind::InvalidInput,
- "nq must be greater than 0",
- ));
- }
- let expected_query_len = nq.checked_mul(d).ok_or_else(|| {
- io::Error::new(
- io::ErrorKind::InvalidInput,
- "nq * dimension overflows usize",
- )
- })?;
- if queries.len() != expected_query_len {
- return Err(io::Error::new(
- io::ErrorKind::InvalidInput,
- format!(
- "queries length {} does not match nq * dimension {}",
- queries.len(),
- expected_query_len
- ),
- ));
- }
- if k == 0 {
- return Err(io::Error::new(
- io::ErrorKind::InvalidInput,
- "k must be greater than 0",
- ));
- }
- if nprobe == 0 {
- return Err(io::Error::new(
- io::ErrorKind::InvalidInput,
- "nprobe must be greater than 0",
- ));
- }
- Ok(())
-}
-
-fn encode_graph(graph: Option<&HnswGraph>) -> io::Result<Vec<u8>> {
- let Some(graph) = graph else {
- return Ok(Vec::new());
- };
- let mut buf = Vec::new();
- write_u32_vec(&mut buf, graph.len())?;
- write_u32_vec(&mut buf, graph.entry_point())?;
- write_u32_vec(&mut buf, graph.max_observed_level())?;
- for &level in graph.levels() {
- write_u32_vec(&mut buf, level)?;
- }
- for node_levels in graph.neighbors() {
- for level_neighbors in node_levels {
- write_u32_vec(&mut buf, level_neighbors.len())?;
- for &neighbor in level_neighbors {
- write_u32_vec(&mut buf, neighbor)?;
- }
- }
- }
- Ok(buf)
-}
-
-fn decode_graph(
- bytes: &[u8],
- vectors: Vec<f32>,
- count: usize,
- d: usize,
- metric: MetricType,
- hnsw_params: HnswBuildParams,
-) -> io::Result<Option<HnswGraph>> {
- if bytes.is_empty() {
- return Ok(None);
- }
- let mut pos = 0usize;
- let graph_count = read_u32_vec(bytes, &mut pos)? as usize;
- if graph_count != count {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- format!(
- "graph count {} does not match list count {}",
- graph_count, count
- ),
- ));
- }
- let entry_point = read_u32_vec(bytes, &mut pos)? as usize;
- let max_observed_level = read_u32_vec(bytes, &mut pos)? as usize;
- let mut levels = Vec::with_capacity(count);
- for node in 0..count {
- let level = read_u32_vec(bytes, &mut pos)? as usize;
- if level >= hnsw_params.max_level {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- format!(
- "graph node {} level {} exceeds max_level {}",
- node,
- level,
- hnsw_params.max_level - 1
- ),
- ));
- }
- levels.push(level);
- }
- let mut neighbors = Vec::with_capacity(count);
- for (node, &level) in levels.iter().enumerate() {
- let mut node_levels = Vec::with_capacity(level + 1);
- for graph_level in 0..=level {
- let degree = read_u32_vec(bytes, &mut pos)? as usize;
- let max_degree = if graph_level == 0 {
- hnsw_params.m.saturating_mul(2)
- } else {
- hnsw_params.m
- };
- if degree > max_degree {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- format!(
- "graph node {} degree {} at level {} exceeds max degree {}",
- node, degree, graph_level, max_degree
- ),
- ));
- }
- let mut level_neighbors = Vec::with_capacity(degree);
- for _ in 0..degree {
- level_neighbors.push(read_u32_vec(bytes, &mut pos)? as usize);
- }
- node_levels.push(level_neighbors);
- }
- neighbors.push(node_levels);
- }
- if pos != bytes.len() {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- "trailing bytes in HNSW graph section",
- ));
- }
- Ok(Some(HnswGraph::from_parts(
- vectors,
- count,
- d,
- metric,
- levels,
- neighbors,
- entry_point,
- max_observed_level,
- hnsw_params,
- )?))
-}
-
-fn write_u32_vec(buf: &mut Vec<u8>, value: usize) -> io::Result<()> {
- if value > u32::MAX as usize {
- return Err(io::Error::new(
- io::ErrorKind::InvalidInput,
- format!("value {} exceeds u32 limit", value),
- ));
- }
- buf.extend_from_slice(&(value as u32).to_le_bytes());
- Ok(())
-}
-
-fn read_u32_vec(bytes: &[u8], pos: &mut usize) -> io::Result<u32> {
- let end = pos.checked_add(4).ok_or_else(|| {
- io::Error::new(
- io::ErrorKind::InvalidData,
- "graph section position overflow",
- )
- })?;
- if end > bytes.len() {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- "truncated HNSW graph section",
- ));
- }
- let value = u32::from_le_bytes(bytes[*pos..end].try_into().unwrap());
- *pos = end;
- Ok(value)
-}
-
-fn write_u32_le(out: &mut dyn SeekWrite, v: u32) -> io::Result<()> {
- out.write_all(&v.to_le_bytes())
-}
-
-fn write_i32_le(out: &mut dyn SeekWrite, v: i32) -> io::Result<()> {
- out.write_all(&v.to_le_bytes())
-}
-
-fn write_i64_le(out: &mut dyn SeekWrite, v: i64) -> io::Result<()> {
- out.write_all(&v.to_le_bytes())
-}
-
-fn write_f32_slice(out: &mut dyn SeekWrite, data: &[f32]) -> io::Result<()> {
- let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
- out.write_all(&bytes)
-}
-
-fn read_u32_le(reader: &mut dyn SeekRead) -> io::Result<u32> {
- let mut buf = [0u8; 4];
- reader.read_exact(&mut buf)?;
- Ok(u32::from_le_bytes(buf))
-}
-
-fn read_i32_le(reader: &mut dyn SeekRead) -> io::Result<i32> {
- let mut buf = [0u8; 4];
- reader.read_exact(&mut buf)?;
- Ok(i32::from_le_bytes(buf))
-}
-
-fn read_i64_le(reader: &mut dyn SeekRead) -> io::Result<i64> {
- let mut buf = [0u8; 8];
- reader.read_exact(&mut buf)?;
- Ok(i64::from_le_bytes(buf))
-}
-
-fn read_f32_vec(reader: &mut dyn SeekRead, count: usize) -> io::Result<Vec<f32>> {
- let byte_len = count.checked_mul(4).ok_or_else(|| {
- io::Error::new(
- io::ErrorKind::InvalidData,
- "f32 section byte length overflow",
- )
- })?;
- let mut buf = vec![0u8; byte_len];
- reader.read_exact(&mut buf)?;
- bytes_to_f32_vec(&buf)
-}
-
-fn bytes_to_f32_vec(bytes: &[u8]) -> io::Result<Vec<f32>> {
- if !bytes.len().is_multiple_of(4) {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- "f32 byte section is not 4-byte aligned",
- ));
- }
- Ok(bytes
- .chunks_exact(4)
- .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
- .collect())
-}
-
-fn validate_positive_i32(val: i32, field: &str) -> io::Result<i32> {
- if val <= 0 {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- format!("invalid header field {}: {} (must be positive)", field, val),
- ));
- }
- Ok(val)
-}
-
fn validate_index_shape(index: &IVFHNSWFlatIndex) -> io::Result<()> {
if index.flat.d == 0 {
return Err(io::Error::new(
@@ -923,76 +650,6 @@ fn validate_index_shape(index: &IVFHNSWFlatIndex) -> io::Result<()> {
Ok(())
}
-fn usize_to_i32(value: usize, field: &str) -> io::Result<i32> {
- if value > i32::MAX as usize {
- return Err(io::Error::new(
- io::ErrorKind::InvalidInput,
- format!("{} exceeds i32 length limit: {}", field, value),
- ));
- }
- Ok(value as i32)
-}
-
-fn usize_to_i64(value: usize, field: &str) -> io::Result<i64> {
- if value > i64::MAX as usize {
- return Err(io::Error::new(
- io::ErrorKind::InvalidInput,
- format!("{} exceeds i64 length limit: {}", field, value),
- ));
- }
- Ok(value as i64)
-}
-
-fn u64_to_i64(value: u64, field: &str) -> io::Result<i64> {
- if value > i64::MAX as u64 {
- return Err(io::Error::new(
- io::ErrorKind::InvalidInput,
- format!("{} exceeds i64 offset limit: {}", field, value),
- ));
- }
- Ok(value as i64)
-}
-
-const MAX_SECTION_ELEMENTS: usize = 1 << 30;
-
-fn checked_section_size(a: usize, b: usize) -> io::Result<usize> {
- let result = a.checked_mul(b).ok_or_else(|| {
- io::Error::new(
- io::ErrorKind::InvalidData,
- "section size overflow in IVF-HNSW-FLAT header",
- )
- })?;
- if result > MAX_SECTION_ELEMENTS {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- format!(
- "section size {} exceeds maximum {}",
- result, MAX_SECTION_ELEMENTS
- ),
- ));
- }
- Ok(result)
-}
-
-fn checked_list_offset(offset: i64, list_id: usize) -> io::Result<u64> {
- if offset < 0 {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- format!("negative list offset {} at list {}", offset, list_id),
- ));
- }
- Ok(offset as u64)
-}
-
-fn checked_list_bytes(count: usize, bytes_per_entry: usize) -> io::Result<usize> {
- count.checked_mul(bytes_per_entry).ok_or_else(|| {
- io::Error::new(
- io::ErrorKind::InvalidData,
- "IVF-HNSW-FLAT list byte size overflow",
- )
- })
-}
-
fn list_payload_len(count: usize, d: usize, graph_bytes_len: usize) -> io::Result<usize> {
let id_bytes = checked_list_bytes(count, 8)?;
let vector_bytes = checked_list_bytes(
@@ -1015,19 +672,11 @@ fn list_payload_len(count: usize, d: usize, graph_bytes_len: usize) -> io::Resul
})
}
-fn decode_roaring_filter(bytes: &[u8]) -> io::Result<RoaringTreemap> {
- RoaringTreemap::deserialize_from(bytes).map_err(|e| {
- io::Error::new(
- io::ErrorKind::InvalidInput,
- format!("invalid RoaringTreemap filter: {}", e),
- )
- })
-}
-
#[cfg(test)]
mod tests {
use crate::distance::MetricType;
use crate::hnsw::HnswBuildParams;
+ use crate::index_io_util::decode_graph;
use crate::io::{PosWriter, SeekRead};
use crate::ivfhnswflat::IVFHNSWFlatIndex;
use crate::ivfhnswflat_io::{
@@ -1353,8 +1002,8 @@ mod tests {
append_u32(&mut graph_bytes, 0);
append_u32(&mut graph_bytes, params.max_level as u32 + 1);
- let err = super::decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, MetricType::L2, params)
- .unwrap_err();
+ let err =
+ decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, MetricType::L2, params).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
assert!(err.to_string().contains("level"));
@@ -1374,8 +1023,8 @@ mod tests {
append_u32(&mut graph_bytes, 0);
append_u32(&mut graph_bytes, (params.m * 2) as u32 + 1);
- let err = super::decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, MetricType::L2, params)
- .unwrap_err();
+ let err =
+ decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, MetricType::L2, params).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
assert!(err.to_string().contains("degree"));
diff --git a/core/src/ivfhnswsq.rs b/core/src/ivfhnswsq.rs
new file mode 100644
index 0000000..a945388
--- /dev/null
+++ b/core/src/ivfhnswsq.rs
@@ -0,0 +1,289 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{preprocess_vectors, MetricType};
+use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::hnsw_search::{search_hnsw_lists, HnswSearchList};
+use crate::ivfpq::RowIdFilter;
+use crate::kmeans::{self, KMeansConfig};
+use crate::sq::ScalarQuantizer;
+use crate::topk::TopKHeap;
+use std::io;
+
+/// IVF index whose inverted lists hold scalar-quantized codes and, once
+/// `build_graphs` has run, one HNSW graph per non-empty list.
+pub struct IVFHNSWSQIndex {
+ /// Vector dimensionality.
+ pub d: usize,
+ /// Number of inverted lists.
+ pub nlist: usize,
+ /// Metric used for preprocessing, graph construction and code distances.
+ pub metric: MetricType,
+ /// Coarse centroids produced by `train`; empty until then.
+ pub quantizer_centroids: Vec<f32>,
+ /// Scalar quantizer used to encode and decode the stored vectors.
+ pub sq: ScalarQuantizer,
+ /// Row ids of each inverted list.
+ pub ids: Vec<Vec<i64>>,
+ /// Concatenated SQ codes of each inverted list, in the same order as `ids`.
+ pub codes: Vec<Vec<u8>>,
+ /// Per-list graph; reset to `None` by `add` and filled by `build_graphs`.
+ pub graphs: Vec<Option<HnswGraph>>,
+ /// Parameters handed to `HnswGraph::build`.
+ pub hnsw_params: HnswBuildParams,
+}
+
+impl IVFHNSWSQIndex {
+ /// Creates an untrained index with `nlist` empty inverted lists and no graphs.
+ pub fn new(d: usize, nlist: usize, metric: MetricType, hnsw_params: HnswBuildParams) -> Self {
+     let ids: Vec<Vec<i64>> = (0..nlist).map(|_| Vec::new()).collect();
+     let codes: Vec<Vec<u8>> = (0..nlist).map(|_| Vec::new()).collect();
+     let graphs: Vec<Option<HnswGraph>> = (0..nlist).map(|_| None).collect();
+     let sq = ScalarQuantizer::new(d);
+     Self {
+         d,
+         nlist,
+         metric,
+         quantizer_centroids: Vec::new(),
+         sq,
+         ids,
+         codes,
+         graphs,
+         hnsw_params,
+     }
+ }
+
+ /// Trains the coarse centroids and the scalar quantizer on the preprocessed
+ /// form of the `n` vectors in `data`.
+ pub fn train(&mut self, data: &[f32], n: usize) {
+     let prepared = preprocess_vectors(data, n, self.d, self.metric);
+     let config = KMeansConfig::default();
+     self.quantizer_centroids = kmeans::kmeans_train(&config, &prepared, n, self.d, self.nlist);
+     self.sq.train(&prepared, n);
+ }
+
+ /// Encodes `n` vectors from `data`, appends each one with its row id to the
+ /// inverted list of its nearest coarse centroid, and drops every graph that
+ /// was built earlier.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `ids` holds fewer than `n` entries. The check runs before any
+ /// list is touched, so a short `ids` slice cannot leave a partial insert
+ /// behind with stale graphs still attached.
+ pub fn add(&mut self, data: &[f32], ids: &[i64], n: usize) {
+     assert!(
+         ids.len() >= n,
+         "ids has {} entries but {} vectors were requested",
+         ids.len(),
+         n
+     );
+     let processed = self.preprocess_vectors(data, n);
+     let code_size = self.sq.code_size();
+     let mut encoded = vec![0u8; n * code_size];
+     self.sq.encode_batch(&processed, n, &mut encoded);
+
+     for (i, &row_id) in ids[..n].iter().enumerate() {
+         let vector = &processed[i * self.d..(i + 1) * self.d];
+         let list_id =
+             kmeans::find_nearest(vector, &self.quantizer_centroids, self.nlist, self.d);
+         self.ids[list_id].push(row_id);
+         self.codes[list_id].extend_from_slice(&encoded[i * code_size..(i + 1) * code_size]);
+     }
+     // Graphs built before this call do not cover the rows just appended.
+     self.graphs.fill(None);
+ }
+
+ /// Returns the number of row ids stored across all inverted lists.
+ pub fn total_vectors(&self) -> usize {
+     self.ids.iter().fold(0, |total, list| total + list.len())
+ }
+
+ /// Rebuilds the graph of every inverted list from its decoded SQ codes.
+ ///
+ /// Empty lists are assigned `None`. The first build error is returned
+ /// immediately; lists after the failing one keep their previous graph slot.
+ pub fn build_graphs(&mut self) -> io::Result<()> {
+     for list_id in 0..self.nlist {
+         let count = self.ids[list_id].len();
+         if count == 0 {
+             self.graphs[list_id] = None;
+             continue;
+         }
+         // The graph is built over the dequantized vectors, not the raw input.
+         let mut decoded = vec![0.0f32; count * self.d];
+         self.sq.decode_batch(&self.codes[list_id], count, &mut decoded);
+         let graph = HnswGraph::build(&decoded, count, self.d, self.metric, self.hnsw_params)?;
+         self.graphs[list_id] = Some(graph);
+     }
+     Ok(())
+ }
+
+ /// Unfiltered batch search: forwards every argument to `search_with_filter`
+ /// with `None` as the filter.
+ #[allow(clippy::too_many_arguments)]
+ pub fn search(
+ &self,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+ ef_search: usize,
+ result_distances: &mut [f32],
+ result_labels: &mut [i64],
+ ) {
+ self.search_with_filter(
+ queries,
+ nq,
+ k,
+ nprobe,
+ ef_search,
+ None,
+ result_distances,
+ result_labels,
+ );
+ }
+
+ /// Searches `nq` queries, probing `nprobe` inverted lists per query.
+ ///
+ /// Hits for query `qi` are written starting at index `qi * k` of
+ /// `result_distances` and `result_labels`; slots without a hit are set to
+ /// `f32::MAX` and `-1`. `filter`, when present, is handed both to the list
+ /// search and to the SQ code scan.
+ #[allow(clippy::too_many_arguments)]
+ pub fn search_with_filter(
+     &self,
+     queries: &[f32],
+     nq: usize,
+     k: usize,
+     nprobe: usize,
+     ef_search: usize,
+     filter: Option<&dyn RowIdFilter>,
+     result_distances: &mut [f32],
+     result_labels: &mut [i64],
+ ) {
+     let prepared = preprocess_vectors(queries, nq, self.d, self.metric);
+     let (probes_per_query, _) = kmeans::find_topk_batch(
+         &prepared,
+         nq,
+         &self.quantizer_centroids,
+         self.nlist,
+         self.d,
+         nprobe,
+     );
+
+     for qi in 0..nq {
+         let query = &prepared[qi * self.d..(qi + 1) * self.d];
+         let mut lists = Vec::new();
+         for &list_id in probes_per_query[qi].iter() {
+             lists.push(HnswSearchList {
+                 ids: self.ids[list_id].as_slice(),
+                 graph: self.graphs[list_id].as_ref(),
+                 payload: list_id,
+             });
+         }
+         let sorted = search_hnsw_lists(query, &lists, k, ef_search, filter, |list, heap| {
+             self.scan_sq_list(query, list.payload, filter, heap);
+         });
+
+         let out_base = qi * k;
+         let hits = sorted.len();
+         for (rank, &(dist, id)) in sorted.iter().enumerate() {
+             let slot = out_base + rank;
+             result_distances[slot] = dist;
+             result_labels[slot] = id;
+         }
+         // Pad the remaining slots so callers can tell hits from filler.
+         for rank in hits..k {
+             let slot = out_base + rank;
+             result_distances[slot] = f32::MAX;
+             result_labels[slot] = -1;
+         }
+     }
+ }
+
+ /// Runs the crate-level `preprocess_vectors` on `n` vectors with this index's
+ /// dimension and metric.
+ pub(crate) fn preprocess_vectors(&self, data: &[f32], n: usize) ->
Vec<f32> {
+ preprocess_vectors(data, n, self.d, self.metric)
+ }
+
+ /// Pushes the SQ-code distance of every row in list `list_id` that passes
+ /// `filter` (or every row when there is no filter) into `heap`.
+ fn scan_sq_list(
+     &self,
+     query: &[f32],
+     list_id: usize,
+     filter: Option<&dyn RowIdFilter>,
+     heap: &mut TopKHeap,
+ ) {
+     let context = self.sq.distance_context(query, self.metric);
+     let code_size = self.sq.code_size();
+     for (local_id, &row_id) in self.ids[list_id].iter().enumerate() {
+         let rejected = match filter {
+             Some(f) => !f.contains(row_id),
+             None => false,
+         };
+         if rejected {
+             continue;
+         }
+         // Codes are stored back to back, one `code_size` chunk per row.
+         let start = local_id * code_size;
+         let code = &self.codes[list_id][start..start + code_size];
+         let distance = self.sq.distance_to_code_with_context(query, code, context);
+         heap.push(distance, row_id);
+     }
+ }
+}
+
+// In-memory tests for IVFHNSWSQIndex: graph search, SQ-scan fallback, filtering.
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::hnsw::HnswBuildParams;
+ use std::collections::HashSet;
+
+ // Searching with a stored vector as the query must return that vector's id first.
+ #[test]
+ fn test_ivfhnswsq_recalls_query_vector() {
+ let d = 4;
+ let nlist = 4;
+ let n = 128;
+ let data: Vec<f32> = (0..n)
+ .flat_map(|i| {
+ let cluster = (i % nlist) as f32 * 100.0;
+ [
+ cluster + i as f32 * 2.0,
+ 10.0 + i as f32,
+ 20.0 + i as f32,
+ 30.0 + i as f32,
+ ]
+ })
+ .collect();
+ let ids: Vec<i64> = (1000..1000 + n as i64).collect();
+
+ let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2,
HnswBuildParams::default());
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+ index.build_graphs().unwrap();
+
+ let query_id = 23;
+ let mut distances = vec![0.0; 5];
+ let mut labels = vec![0; 5];
+ index.search(
+ &data[query_id * d..(query_id + 1) * d],
+ 1,
+ 5,
+ nlist,
+ 32,
+ &mut distances,
+ &mut labels,
+ );
+
+ assert_eq!(labels[0], ids[query_id]);
+ assert!(distances[0].is_finite());
+ }
+
+ // `build_graphs` is deliberately not called; search must still find the nearest row.
+ #[test]
+ fn test_ivfhnswsq_without_built_graphs_falls_back_to_sq_scan() {
+ let d = 2;
+ let nlist = 1;
+ let data = vec![0.0, 0.0, 0.1, 0.0, 10.0, 10.0];
+ let ids = vec![10, 11, 12];
+ let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2,
HnswBuildParams::default());
+ index.train(&data, 3);
+ index.add(&data, &ids, 3);
+
+ let mut distances = vec![0.0; 2];
+ let mut labels = vec![0; 2];
+ index.search(&[0.0, 0.0], 1, 2, nlist, 8, &mut distances, &mut labels);
+
+ assert_eq!(labels[0], 10);
+ }
+
+ // Only id 12 passes the filter, so it is the single hit and the second slot is padding.
+ #[test]
+ fn test_ivfhnswsq_search_with_filter() {
+ let d = 2;
+ let nlist = 1;
+ let data = vec![0.0, 0.0, 0.1, 0.0, 10.0, 10.0];
+ let ids = vec![10, 11, 12];
+ let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2,
HnswBuildParams::default());
+ index.train(&data, 3);
+ index.add(&data, &ids, 3);
+ index.build_graphs().unwrap();
+
+ let filter: HashSet<i64> = [12].into_iter().collect();
+ let mut distances = vec![0.0; 2];
+ let mut labels = vec![0; 2];
+ index.search_with_filter(
+ &[0.0, 0.0],
+ 1,
+ 2,
+ nlist,
+ 8,
+ Some(&filter),
+ &mut distances,
+ &mut labels,
+ );
+
+ assert_eq!(labels[0], 12);
+ assert_eq!(labels[1], -1);
+ }
+}
diff --git a/core/src/ivfhnswsq_io.rs b/core/src/ivfhnswsq_io.rs
new file mode 100644
index 0000000..cb06eef
--- /dev/null
+++ b/core/src/ivfhnswsq_io.rs
@@ -0,0 +1,851 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{preprocess_vectors, MetricType};
+use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::hnsw_search::{search_hnsw_lists, HnswSearchList};
+use crate::index_io_util::{
+ checked_list_bytes, checked_list_offset, checked_section_size,
decode_graph,
+ decode_roaring_filter, encode_graph, read_f32_vec, read_i32_le,
read_i64_le, read_u32_le,
+ u64_to_i64, usize_to_i32, usize_to_i64, validate_positive_i32,
validate_search_inputs,
+ write_f32_slice, write_i32_le, write_i64_le, write_u32_le,
+};
+use crate::io::{SeekRead, SeekWrite};
+use crate::ivfhnswsq::IVFHNSWSQIndex;
+use crate::ivfpq::RowIdFilter;
+use crate::kmeans;
+use crate::sq::ScalarQuantizer;
+use crate::topk::TopKHeap;
+use std::io;
+
+/// Magic number written as the first header field and checked by the reader.
+pub const IVF_HNSW_SQ_MAGIC: u32 = 0x49485351; // "IHSQ"
+/// The only format version the reader accepts.
+pub const IVF_HNSW_SQ_VERSION: u32 = 1;
+/// Offset at which the reader starts loading the coarse centroids.
+pub const IVF_HNSW_SQ_HEADER_SIZE: usize = 64;
+
+/// Serializes `index` to `out`.
+///
+/// Layout, in write order: header fields (magic, version, d, nlist, metric,
+/// total vector count, HNSW m / ef_construction / max_level, SQ min, SQ max,
+/// 16 zero bytes), the coarse centroids, one offset-table entry per list
+/// (offset, count, graph byte length, a zero i64), then for every non-empty
+/// list its row ids, SQ codes and encoded graph.
+///
+/// # Errors
+///
+/// Fails before writing anything if `validate_index_shape` rejects the index
+/// or a graph cannot be encoded; also fails on size/offset overflow, on values
+/// the integer conversion helpers reject, and on write errors.
+pub fn write_ivfhnswsq_index(index: &IVFHNSWSQIndex, out: &mut dyn SeekWrite)
-> io::Result<()> {
+ validate_index_shape(index)?;
+ // Sum of all list lengths, with overflow reported as an error.
+ let total_vectors = index.ids.iter().try_fold(0i64, |sum, ids| {
+ let count = usize_to_i64(ids.len(), "total vector count")?;
+ sum.checked_add(count).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "total vector count exceeds i64 length limit",
+ )
+ })
+ })?;
+ // Graphs are encoded up front because their byte lengths go into the offset table.
+ let graph_bytes: Vec<Vec<u8>> = (0..index.nlist)
+ .map(|list_id| {
+ if index.ids[list_id].is_empty() {
+ Ok(Vec::new())
+ } else {
+ encode_graph(index.graphs[list_id].as_ref())
+ }
+ })
+ .collect::<io::Result<_>>()?;
+
+ // Header.
+ write_u32_le(out, IVF_HNSW_SQ_MAGIC)?;
+ write_u32_le(out, IVF_HNSW_SQ_VERSION)?;
+ write_i32_le(out, usize_to_i32(index.d, "dimension")?)?;
+ write_i32_le(out, usize_to_i32(index.nlist, "nlist")?)?;
+ write_u32_le(out, index.metric as u32)?;
+ write_i64_le(out, total_vectors)?;
+ let params = index.hnsw_params.sanitized();
+ write_i32_le(out, usize_to_i32(params.m, "hnsw m")?)?;
+ write_i32_le(
+ out,
+ usize_to_i32(params.ef_construction, "hnsw ef_construction")?,
+ )?;
+ write_i32_le(out, usize_to_i32(params.max_level, "hnsw max_level")?)?;
+ out.write_all(&index.sq.min.to_le_bytes())?;
+ out.write_all(&index.sq.max.to_le_bytes())?;
+ // Reserved header bytes.
+ out.write_all(&[0u8; 16])?;
+
+ write_f32_slice(out, &index.quantizer_centroids)?;
+
+ // Each offset-table entry is 8 + 4 + 4 + 8 = 24 bytes (see the write loop below).
+ let offset_table_size = index.nlist.checked_mul(24).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "IVF-HNSW-SQ offset table size overflow",
+ )
+ })?;
+ // List payloads begin right after the offset table.
+ let data_start = out
+ .pos()
+ .checked_add(offset_table_size as u64)
+ .ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "IVF-HNSW-SQ data start offset overflow",
+ )
+ })?;
+ let mut list_offsets = vec![0i64; index.nlist];
+ let mut list_counts = vec![0i32; index.nlist];
+ let mut list_graph_bytes_lens = vec![0i32; index.nlist];
+ let mut current_offset = data_start;
+
+ // Lay out the payloads: empty lists take no space and share the next list's offset.
+ for list_id in 0..index.nlist {
+ list_offsets[list_id] = u64_to_i64(current_offset, "list offset")?;
+ let count = index.ids[list_id].len();
+ list_counts[list_id] = usize_to_i32(count, "list count")?;
+ list_graph_bytes_lens[list_id] =
usize_to_i32(graph_bytes[list_id].len(), "graph bytes")?;
+ if count > 0 {
+ let payload_len =
+ list_payload_len(count, index.sq.code_size(),
graph_bytes[list_id].len())?;
+ current_offset = current_offset
+ .checked_add(payload_len as u64)
+ .ok_or_else(|| {
+ io::Error::new(io::ErrorKind::InvalidInput, "IVF-HNSW-SQ
offset overflow")
+ })?;
+ }
+ }
+
+ // Offset table; the trailing zero i64 is a reserved field.
+ for list_id in 0..index.nlist {
+ write_i64_le(out, list_offsets[list_id])?;
+ write_i32_le(out, list_counts[list_id])?;
+ write_i32_le(out, list_graph_bytes_lens[list_id])?;
+ write_i64_le(out, 0)?;
+ }
+
+ // List payloads: ids, then codes, then graph bytes.
+ for list_id in 0..index.nlist {
+ if index.ids[list_id].is_empty() {
+ continue;
+ }
+ for &id in &index.ids[list_id] {
+ write_i64_le(out, id)?;
+ }
+ out.write_all(&index.codes[list_id])?;
+ out.write_all(&graph_bytes[list_id])?;
+ }
+
+ Ok(())
+}
+
+/// Reader over a serialized IVF-HNSW-SQ index that loads inverted lists on demand.
+pub struct IVFHNSWSQIndexReader<R: SeekRead> {
+ /// Underlying source the index is read from.
+ reader: R,
+ /// Vector dimensionality from the header.
+ pub d: usize,
+ /// Number of inverted lists from the header.
+ pub nlist: usize,
+ /// Metric decoded from the header's metric code.
+ pub metric: MetricType,
+ /// Total vector count as stored in the header.
+ pub total_vectors: i64,
+ /// HNSW parameters from the header, after `sanitized()`.
+ pub hnsw_params: HnswBuildParams,
+ /// Scalar quantizer rebuilt from the header's min/max bounds.
+ pub sq: ScalarQuantizer,
+ /// Coarse centroids; empty until `ensure_loaded` has run.
+ pub quantizer_centroids: Vec<f32>,
+ /// Payload offset of each list; empty until `ensure_loaded` has run.
+ pub list_offsets: Vec<i64>,
+ /// Row count of each list; empty until `ensure_loaded` has run.
+ pub list_counts: Vec<i32>,
+ /// Encoded graph length of each list; empty until `ensure_loaded` has run.
+ pub list_graph_bytes_lens: Vec<i32>,
+ /// Set once the centroids and offset table have been read.
+ loaded: bool,
+}
+
+impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
+ /// Reads and validates the header at the start of `reader`.
+ ///
+ /// Centroids and the list offset table are not read here; `ensure_loaded`
+ /// fetches them on first use.
+ ///
+ /// # Errors
+ ///
+ /// Returns `InvalidData` for a wrong magic or version, an unknown metric
+ /// code, a negative total vector count, or non-finite SQ bounds. Errors from
+ /// the positive-value checks on `d`, `nlist` and the HNSW parameters, and
+ /// from the underlying reads, are propagated unchanged.
+ pub fn open(mut reader: R) -> io::Result<Self> {
+     reader.seek(0)?;
+
+     let magic = read_u32_le(&mut reader)?;
+     if magic != IVF_HNSW_SQ_MAGIC {
+         return Err(io::Error::new(
+             io::ErrorKind::InvalidData,
+             format!("Invalid IVF_HNSW_SQ magic: 0x{:08X}", magic),
+         ));
+     }
+     let version = read_u32_le(&mut reader)?;
+     if version != IVF_HNSW_SQ_VERSION {
+         return Err(io::Error::new(
+             io::ErrorKind::InvalidData,
+             format!("Unsupported IVF_HNSW_SQ version: {}", version),
+         ));
+     }
+
+     let d = validate_positive_i32(read_i32_le(&mut reader)?, "d")? as usize;
+     let nlist = validate_positive_i32(read_i32_le(&mut reader)?, "nlist")? as usize;
+     let metric_code = read_u32_le(&mut reader)?;
+     let metric = MetricType::from_code(metric_code).ok_or_else(|| {
+         io::Error::new(
+             io::ErrorKind::InvalidData,
+             format!("Unknown metric type: {}", metric_code),
+         )
+     })?;
+     let total_vectors = read_i64_le(&mut reader)?;
+     // The writer stores a sum of list lengths, so a negative value means the
+     // header is corrupt; reject it like every other out-of-range header field.
+     if total_vectors < 0 {
+         return Err(io::Error::new(
+             io::ErrorKind::InvalidData,
+             format!("negative total vector count: {}", total_vectors),
+         ));
+     }
+     let hnsw_params = HnswBuildParams {
+         m: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw m")? as usize,
+         ef_construction: validate_positive_i32(
+             read_i32_le(&mut reader)?,
+             "hnsw ef_construction",
+         )? as usize,
+         max_level: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw max_level")? as usize,
+     }
+     .sanitized();
+
+     let mut min_bytes = [0u8; 4];
+     let mut max_bytes = [0u8; 4];
+     reader.read_exact(&mut min_bytes)?;
+     reader.read_exact(&mut max_bytes)?;
+     let sq_min = f32::from_le_bytes(min_bytes);
+     let sq_max = f32::from_le_bytes(max_bytes);
+     if !sq_min.is_finite() || !sq_max.is_finite() {
+         return Err(io::Error::new(
+             io::ErrorKind::InvalidData,
+             "SQ bounds must be finite",
+         ));
+     }
+     // Consume the reserved header bytes; their content is ignored.
+     let mut reserved = [0u8; 16];
+     reader.read_exact(&mut reserved)?;
+
+     Ok(Self {
+         reader,
+         d,
+         nlist,
+         metric,
+         total_vectors,
+         hnsw_params,
+         sq: ScalarQuantizer::with_bounds(d, sq_min, sq_max),
+         quantizer_centroids: Vec::new(),
+         list_offsets: Vec::new(),
+         list_counts: Vec::new(),
+         list_graph_bytes_lens: Vec::new(),
+         loaded: false,
+     })
+ }
+
+ /// Loads the coarse centroids and the per-list offset table on first call;
+ /// later calls return immediately.
+ ///
+ /// # Errors
+ ///
+ /// Returns `InvalidData` when a list count or graph byte length is negative,
+ /// and propagates read and size-computation errors.
+ pub fn ensure_loaded(&mut self) -> io::Result<()> {
+     // Shared check for the two signed table fields that must not be negative.
+     fn non_negative(value: i32, what: &str, list_id: usize) -> io::Result<i32> {
+         if value >= 0 {
+             return Ok(value);
+         }
+         Err(io::Error::new(
+             io::ErrorKind::InvalidData,
+             format!("negative {} {} at list {}", what, value, list_id),
+         ))
+     }
+
+     if !self.loaded {
+         self.reader.seek(IVF_HNSW_SQ_HEADER_SIZE as u64)?;
+         let centroid_len = checked_section_size(self.nlist, self.d)?;
+         self.quantizer_centroids = read_f32_vec(&mut self.reader, centroid_len)?;
+         self.list_offsets = vec![0; self.nlist];
+         self.list_counts = vec![0; self.nlist];
+         self.list_graph_bytes_lens = vec![0; self.nlist];
+         for list_id in 0..self.nlist {
+             let offset = read_i64_le(&mut self.reader)?;
+             self.list_offsets[list_id] = offset;
+             let count = non_negative(read_i32_le(&mut self.reader)?, "list count", list_id)?;
+             self.list_counts[list_id] = count;
+             let graph_bytes_len =
+                 non_negative(read_i32_le(&mut self.reader)?, "graph_bytes_len", list_id)?;
+             self.list_graph_bytes_lens[list_id] = graph_bytes_len;
+             // Trailing reserved field of the table entry; read and discarded.
+             let _reserved = read_i64_le(&mut self.reader)?;
+         }
+         self.loaded = true;
+     }
+     Ok(())
+ }
+
+ /// Reads list `list_id` and returns its row ids, SQ codes and graph.
+ ///
+ /// An empty list yields two empty vectors and `None`.
+ pub fn read_inverted_list(
+     &mut self,
+     list_id: usize,
+ ) -> io::Result<(Vec<i64>, Vec<u8>, Option<HnswGraph>)> {
+     match self.read_graph_list(list_id)? {
+         Some(GraphList { ids, codes, graph }) => Ok((ids, codes, Some(graph))),
+         None => Ok((Vec::new(), Vec::new(), None)),
+     }
+ }
+
+ /// Reads and decodes one inverted list; returns `Ok(None)` for an empty list.
+ ///
+ /// # Errors
+ ///
+ /// Returns `InvalidInput` when `list_id` is out of range and `InvalidData`
+ /// when a non-empty list has no graph bytes or decodes to no graph. Errors
+ /// from offset/size checks, reading and graph decoding are propagated.
+ fn read_graph_list(&mut self, list_id: usize) -> io::Result<Option<GraphList>> {
+     self.ensure_loaded()?;
+     if list_id >= self.nlist {
+         return Err(io::Error::new(
+             io::ErrorKind::InvalidInput,
+             format!("list_id {} out of range (nlist={})", list_id, self.nlist),
+         ));
+     }
+     let count = self.list_counts[list_id] as usize;
+     if count == 0 {
+         return Ok(None);
+     }
+
+     let missing_graph = || {
+         io::Error::new(
+             io::ErrorKind::InvalidData,
+             format!("list {} is missing HNSW graph", list_id),
+         )
+     };
+     let offset = checked_list_offset(self.list_offsets[list_id], list_id)?;
+     let graph_bytes_len = self.list_graph_bytes_lens[list_id] as usize;
+     if graph_bytes_len == 0 {
+         return Err(missing_graph());
+     }
+     let code_size = self.sq.code_size();
+     let payload_len = list_payload_len(count, code_size, graph_bytes_len)?;
+     let mut payload = vec![0u8; payload_len];
+     self.reader.pread(offset, &mut payload)?;
+
+     // The payload is ids, then codes, then the encoded graph.
+     let ids_bytes_len = checked_list_bytes(count, 8)?;
+     let codes_bytes_len = checked_list_bytes(count, code_size)?;
+     let (id_bytes, rest) = payload.split_at(ids_bytes_len);
+     let (code_bytes, graph_bytes) = rest.split_at(codes_bytes_len);
+     let ids: Vec<i64> = id_bytes
+         .chunks_exact(8)
+         .map(|chunk| i64::from_le_bytes(chunk.try_into().unwrap()))
+         .collect();
+     let codes = code_bytes.to_vec();
+     // The graph works on dequantized vectors, so decode the codes first.
+     let mut vectors = vec![0.0f32; count * self.d];
+     self.sq.decode_batch(&codes, count, &mut vectors);
+     let graph = decode_graph(
+         graph_bytes,
+         vectors,
+         count,
+         self.d,
+         self.metric,
+         self.hnsw_params,
+     )?
+     .ok_or_else(missing_graph)?;
+     Ok(Some(GraphList { ids, codes, graph }))
+ }
+
+ /// Unfiltered single-query search: forwards to `search_with_filter` with
+ /// `None` as the filter and returns its `(labels, distances)` pair.
+ pub fn search(
+ &mut self,
+ query: &[f32],
+ k: usize,
+ nprobe: usize,
+ ef_search: usize,
+ ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ self.search_with_filter(query, k, nprobe, ef_search, None)
+ }
+
+ /// Searches one query against the `nprobe` nearest inverted lists.
+ ///
+ /// Returns `(labels, distances)`, each of length `k`; missing hits are padded
+ /// with `-1` and `f32::MAX`. Probed lists are read from the underlying reader
+ /// on every call.
+ ///
+ /// # Errors
+ ///
+ /// Propagates errors from loading the index metadata, from input validation
+ /// and from reading the probed lists.
+ pub fn search_with_filter(
+     &mut self,
+     query: &[f32],
+     k: usize,
+     nprobe: usize,
+     ef_search: usize,
+     filter: Option<&dyn RowIdFilter>,
+ ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+     self.ensure_loaded()?;
+     validate_search_inputs(query, 1, self.d, k, nprobe)?;
+
+     let prepared = preprocess_vectors(query, 1, self.d, self.metric);
+     let (probe_indices, _) = kmeans::find_topk(
+         &prepared,
+         &self.quantizer_centroids,
+         self.nlist,
+         self.d,
+         nprobe,
+     );
+     let mut loaded_lists = Vec::with_capacity(probe_indices.len());
+     for list_id in probe_indices {
+         // Empty lists come back as `None` and contribute nothing.
+         loaded_lists.extend(self.read_graph_list(list_id)?);
+     }
+
+     let mut search_lists = Vec::with_capacity(loaded_lists.len());
+     for list in &loaded_lists {
+         search_lists.push(HnswSearchList {
+             ids: list.ids.as_slice(),
+             graph: Some(&list.graph),
+             payload: list,
+         });
+     }
+     let sq = &self.sq;
+     let metric = self.metric;
+     let sorted = search_hnsw_lists(
+         &prepared,
+         &search_lists,
+         k,
+         ef_search,
+         filter,
+         |list, heap| {
+             let candidate = list.payload;
+             scan_sq_list(
+                 &prepared,
+                 &candidate.ids,
+                 &candidate.codes,
+                 sq,
+                 metric,
+                 filter,
+                 heap,
+             );
+         },
+     );
+
+     let (mut labels, mut distances): (Vec<i64>, Vec<f32>) =
+         sorted.iter().map(|&(dist, id)| (id, dist)).unzip();
+     labels.resize(k, -1);
+     distances.resize(k, f32::MAX);
+     Ok((labels, distances))
+ }
+
+ /// Decodes `roaring_filter_bytes` with `decode_roaring_filter` and runs
+ /// `search_with_filter` with the decoded filter.
+ ///
+ /// Fails before searching if the filter bytes cannot be decoded.
+ pub fn search_with_roaring_filter(
+ &mut self,
+ query: &[f32],
+ k: usize,
+ nprobe: usize,
+ ef_search: usize,
+ roaring_filter_bytes: &[u8],
+ ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ let filter = decode_roaring_filter(roaring_filter_bytes)?;
+ self.search_with_filter(query, k, nprobe, ef_search, Some(&filter))
+ }
+}
+
+/// Unfiltered batch search: forwards to `search_batch_ivfhnswsq_reader_filter`
+/// with `None` as the filter.
+pub fn search_batch_ivfhnswsq_reader<R: SeekRead>(
+ reader: &mut IVFHNSWSQIndexReader<R>,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+ ef_search: usize,
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ search_batch_ivfhnswsq_reader_filter(reader, queries, nq, k, nprobe,
ef_search, None)
+}
+
+/// Searches `nq` queries in one pass, reading each probed list once and
+/// sharing it among all queries that probe it.
+///
+/// Returns `(labels, distances)` of length `nq * k`; the hits of query `qi`
+/// start at index `qi * k`, and unused slots hold `-1` and `f32::MAX`.
+///
+/// With a filter, a query whose probed lists contain at most
+/// `max(ef_search, k)` passing rows is answered by scanning SQ codes directly.
+/// Other queries use the graph with post-filtering, followed by a code scan
+/// if fewer than `k` hits were found.
+pub fn search_batch_ivfhnswsq_reader_filter<R: SeekRead>(
+ reader: &mut IVFHNSWSQIndexReader<R>,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+ ef_search: usize,
+ filter: Option<&dyn RowIdFilter>,
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ reader.ensure_loaded()?;
+ validate_search_inputs(queries, nq, reader.d, k, nprobe)?;
+
+ let processed = preprocess_vectors(queries, nq, reader.d, reader.metric);
+ let (all_probe_indices, _) = kmeans::find_topk_batch(
+ &processed,
+ nq,
+ &reader.quantizer_centroids,
+ reader.nlist,
+ reader.d,
+ nprobe,
+ );
+
+ // Invert the probe lists: for each list, which queries probe it.
+ // `unique_lists` records lists in first-seen order.
+ let mut list_to_queries = vec![Vec::new(); reader.nlist];
+ let mut unique_lists = Vec::new();
+ for (qi, probe_indices) in all_probe_indices.iter().enumerate() {
+ for &list_id in probe_indices {
+ if list_to_queries[list_id].is_empty() {
+ unique_lists.push(list_id);
+ }
+ list_to_queries[list_id].push(qi);
+ }
+ }
+
+ let mut heaps: Vec<TopKHeap> = (0..nq).map(|_| TopKHeap::new(k)).collect();
+ // Per query: number of rows in its probed lists that pass the filter.
+ let mut query_filtered_counts = vec![0usize; nq];
+ let mut loaded_lists = Vec::with_capacity(unique_lists.len());
+ for list_id in unique_lists {
+ let Some(list) = reader.read_graph_list(list_id)? else {
+ continue;
+ };
+ if let Some(f) = filter {
+ let filtered = list.ids.iter().filter(|&&id|
f.contains(id)).count();
+ for &qi in &list_to_queries[list_id] {
+ query_filtered_counts[qi] = query_filtered_counts[qi]
+ .checked_add(filtered)
+ .ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "filtered vector count overflows usize",
+ )
+ })?;
+ }
+ }
+ loaded_lists.push(LoadedBatchList {
+ query_ids: std::mem::take(&mut list_to_queries[list_id]),
+ ids: list.ids,
+ codes: list.codes,
+ graph: list.graph,
+ });
+ }
+
+ // First pass: per (list, query) pair, either scan the codes or search the graph.
+ for list in &loaded_lists {
+ for &qi in &list.query_ids {
+ let query = &processed[qi * reader.d..(qi + 1) * reader.d];
+ // Few passing rows: scanning the codes replaces the graph search.
+ let force_sq_scan = filter
+ .map(|_| query_filtered_counts[qi] <= ef_search.max(k))
+ .unwrap_or(false);
+ if force_sq_scan {
+ scan_sq_list(
+ query,
+ &list.ids,
+ &list.codes,
+ &reader.sq,
+ reader.metric,
+ filter,
+ &mut heaps[qi],
+ );
+ } else {
+ let local_results = list.graph.search(query, ef_search.max(k),
ef_search.max(k));
+ // Graph hits are filtered after the search, so fewer than k may survive.
+ for (local_id, dist) in local_results {
+ let row_id = list.ids[local_id];
+ if filter.map(|f| f.contains(row_id)).unwrap_or(true) {
+ heaps[qi].push(dist, row_id);
+ }
+ }
+ }
+ }
+ }
+ // Second pass (filtered searches only): code-scan for graph-path queries
+ // whose heap is still short of k.
+ // NOTE(review): this scan pushes rows the graph pass already pushed;
+ // confirm TopKHeap tolerates repeated row ids. The length check is also
+ // re-evaluated per list, so lists after the one that fills the heap are skipped.
+ if filter.is_some() {
+ for list in &loaded_lists {
+ for &qi in &list.query_ids {
+ if heaps[qi].len() >= k {
+ continue;
+ }
+ // These queries were already fully scanned in the first pass.
+ if query_filtered_counts[qi] <= ef_search.max(k) {
+ continue;
+ }
+ let query = &processed[qi * reader.d..(qi + 1) * reader.d];
+ scan_sq_list(
+ query,
+ &list.ids,
+ &list.codes,
+ &reader.sq,
+ reader.metric,
+ filter,
+ &mut heaps[qi],
+ );
+ }
+ }
+ }
+
+ // Flatten the heaps; slots without a hit keep the -1 / f32::MAX defaults.
+ let mut result_ids = vec![-1i64; nq * k];
+ let mut result_dists = vec![f32::MAX; nq * k];
+ for qi in 0..nq {
+ let sorted = std::mem::replace(&mut heaps[qi],
TopKHeap::new(0)).into_sorted();
+ let base = qi * k;
+ for (i, &(dist, id)) in sorted.iter().enumerate() {
+ result_ids[base + i] = id;
+ result_dists[base + i] = dist;
+ }
+ }
+
+ Ok((result_ids, result_dists))
+}
+
+/// Decodes `roaring_filter_bytes` with `decode_roaring_filter` and runs
+/// `search_batch_ivfhnswsq_reader_filter` with the decoded filter.
+///
+/// Fails before searching if the filter bytes cannot be decoded.
+pub fn search_batch_ivfhnswsq_reader_roaring_filter<R: SeekRead>(
+ reader: &mut IVFHNSWSQIndexReader<R>,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+ ef_search: usize,
+ roaring_filter_bytes: &[u8],
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ let filter = decode_roaring_filter(roaring_filter_bytes)?;
+ search_batch_ivfhnswsq_reader_filter(reader, queries, nq, k, nprobe,
ef_search, Some(&filter))
+}
+
+/// One non-empty inverted list as decoded by `read_graph_list`.
+struct GraphList {
+ /// Row ids in stored order.
+ ids: Vec<i64>,
+ /// Concatenated SQ codes, one per row id.
+ codes: Vec<u8>,
+ /// Graph decoded over the dequantized vectors of this list.
+ graph: HnswGraph,
+}
+
+/// A loaded inverted list plus the batch queries that probe it.
+struct LoadedBatchList {
+ /// Indices of the queries whose probe set contains this list.
+ query_ids: Vec<usize>,
+ /// Row ids in stored order.
+ ids: Vec<i64>,
+ /// Concatenated SQ codes, one per row id.
+ codes: Vec<u8>,
+ /// Graph of this list.
+ graph: HnswGraph,
+}
+
+/// Pushes the SQ-code distance of every row in `ids` that passes `filter`
+/// (or every row when there is no filter) into `heap`.
+///
+/// `codes` holds one `sq.code_size()`-byte code per entry of `ids`, back to back.
+fn scan_sq_list(
+    query: &[f32],
+    ids: &[i64],
+    codes: &[u8],
+    sq: &ScalarQuantizer,
+    metric: MetricType,
+    filter: Option<&dyn RowIdFilter>,
+    heap: &mut TopKHeap,
+) {
+    let context = sq.distance_context(query, metric);
+    let code_size = sq.code_size();
+    for (local_id, &row_id) in ids.iter().enumerate() {
+        let rejected = match filter {
+            Some(f) => !f.contains(row_id),
+            None => false,
+        };
+        if rejected {
+            continue;
+        }
+        let start = local_id * code_size;
+        let code = &codes[start..start + code_size];
+        let distance = sq.distance_to_code_with_context(query, code, context);
+        heap.push(distance, row_id);
+    }
+}
+
+/// Checks that `index` is internally consistent before it is serialized.
+///
+/// Verified in order: non-zero `d` and `nlist`, SQ dimension, centroid length,
+/// per-list container counts, and for each list its code length and graph.
+/// Non-empty lists need a graph whose vectors equal the decoded codes; empty
+/// lists must have no graph. Every violation is reported as `InvalidInput`.
+fn validate_index_shape(index: &IVFHNSWSQIndex) -> io::Result<()> {
+    let invalid = |message: String| io::Error::new(io::ErrorKind::InvalidInput, message);
+
+    if index.d == 0 {
+        return Err(invalid("dimension must be greater than 0".to_string()));
+    }
+    if index.nlist == 0 {
+        return Err(invalid("nlist must be greater than 0".to_string()));
+    }
+    if index.sq.d != index.d {
+        return Err(invalid(
+            "SQ dimension does not match index dimension".to_string(),
+        ));
+    }
+    let centroid_len = checked_section_size(index.nlist, index.d)?;
+    if index.quantizer_centroids.len() != centroid_len {
+        return Err(invalid(format!(
+            "quantizer centroid length {} does not match nlist*d {}",
+            index.quantizer_centroids.len(),
+            centroid_len
+        )));
+    }
+    let container_lens = [index.ids.len(), index.codes.len(), index.graphs.len()];
+    if container_lens.iter().any(|&len| len != index.nlist) {
+        return Err(invalid(
+            "inverted list count does not match nlist".to_string(),
+        ));
+    }
+
+    for list_id in 0..index.nlist {
+        let count = index.ids[list_id].len();
+        let expected_codes_len = checked_list_bytes(count, index.sq.code_size())?;
+        let codes = &index.codes[list_id];
+        if codes.len() != expected_codes_len {
+            return Err(invalid(format!(
+                "list {} SQ code length {} does not match count*d {}",
+                list_id,
+                codes.len(),
+                expected_codes_len
+            )));
+        }
+
+        let graph = index.graphs[list_id].as_ref();
+        if count == 0 {
+            if graph.is_some() {
+                return Err(invalid(format!(
+                    "list {} has graph for an empty list",
+                    list_id
+                )));
+            }
+            continue;
+        }
+        let Some(graph) = graph else {
+            return Err(invalid(format!(
+                "list {} is missing HNSW graph; call build_graphs first",
+                list_id
+            )));
+        };
+        // A graph built from anything but the current codes is stale.
+        let mut decoded = vec![0.0f32; count * index.d];
+        index.sq.decode_batch(codes, count, &mut decoded);
+        if graph.len() != count || graph.vectors() != decoded.as_slice() {
+            return Err(invalid(format!(
+                "list {} graph does not match SQ code storage",
+                list_id
+            )));
+        }
+    }
+    Ok(())
+}
+
+/// Returns the byte length of one list payload: `count` 8-byte ids, `count`
+/// codes of `code_size` bytes, and `graph_bytes_len` graph bytes.
+///
+/// Fails with `InvalidInput` when the sum overflows `usize`; errors from
+/// `checked_list_bytes` are propagated.
+fn list_payload_len(count: usize, code_size: usize, graph_bytes_len: usize) -> io::Result<usize> {
+    let id_bytes = checked_list_bytes(count, 8)?;
+    let code_bytes = checked_list_bytes(count, code_size)?;
+    let total = match id_bytes.checked_add(code_bytes) {
+        Some(ids_and_codes) => ids_and_codes.checked_add(graph_bytes_len),
+        None => None,
+    };
+    match total {
+        Some(len) => Ok(len),
+        None => Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "IVF-HNSW-SQ list payload length overflow",
+        )),
+    }
+}
+
+// Serialization tests: write an index to a buffer, reopen it, and search.
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::hnsw::HnswBuildParams;
+ use crate::io::PosWriter;
+ use roaring::RoaringTreemap;
+ use std::io::Cursor;
+
+ // A stored vector used as the query must come back first after a write/read cycle.
+ #[test]
+ fn test_ivfhnswsq_write_read_search_roundtrip() {
+ let d = 4;
+ let nlist = 4;
+ let n = 128;
+ let data: Vec<f32> = (0..n)
+ .flat_map(|i| {
+ let cluster = (i % nlist) as f32 * 100.0;
+ [
+ cluster + i as f32 * 2.0,
+ 10.0 + i as f32,
+ 20.0 + i as f32,
+ 30.0 + i as f32,
+ ]
+ })
+ .collect();
+ let ids: Vec<i64> = (10_000..10_000 + n as i64).collect();
+
+ let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2,
HnswBuildParams::default());
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+ index.build_graphs().unwrap();
+
+ let mut buf = Vec::new();
+ let mut writer = PosWriter::new(&mut buf);
+ write_ivfhnswsq_index(&index, &mut writer).unwrap();
+
+ let mut reader = IVFHNSWSQIndexReader::open(Cursor::new(buf)).unwrap();
+ let query_id = 23;
+ let (labels, distances) = reader
+ .search(&data[query_id * d..(query_id + 1) * d], 5, nlist, 32)
+ .unwrap();
+
+ assert_eq!(labels[0], ids[query_id]);
+ assert!(distances[0].is_finite());
+ }
+
+ // A serialized RoaringTreemap containing only id 12 limits the result to that row.
+ #[test]
+ fn test_ivfhnswsq_reader_search_with_roaring_filter() {
+ let d = 2;
+ let nlist = 1;
+ let data = vec![0.0, 0.0, 0.1, 0.0, 10.0, 10.0];
+ let ids = vec![10, 11, 12];
+ let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2,
HnswBuildParams::default());
+ index.train(&data, 3);
+ index.add(&data, &ids, 3);
+ index.build_graphs().unwrap();
+
+ let mut buf = Vec::new();
+ write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut buf)).unwrap();
+ let mut reader = IVFHNSWSQIndexReader::open(Cursor::new(buf)).unwrap();
+
+ let mut filter = RoaringTreemap::new();
+ filter.insert(12);
+ let mut filter_bytes = Vec::new();
+ filter.serialize_into(&mut filter_bytes).unwrap();
+
+ let (labels, _) = reader
+ .search_with_roaring_filter(&[0.0, 0.0], 2, nlist, 8,
&filter_bytes)
+ .unwrap();
+
+ assert_eq!(labels, vec![12, -1]);
+ }
+
+ // Same round trip with the cosine metric.
+ #[test]
+ fn test_ivfhnswsq_write_read_search_roundtrip_cosine() {
+ let d = 3;
+ let nlist = 2;
+ let data = vec![1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 1.0, 0.0, 0.0, 0.9,
0.1];
+ let ids = vec![10, 11, 12, 13];
+ let mut index =
+ IVFHNSWSQIndex::new(d, nlist, MetricType::Cosine,
HnswBuildParams::default());
+ index.train(&data, 4);
+ index.add(&data, &ids, 4);
+ index.build_graphs().unwrap();
+
+ let mut buf = Vec::new();
+ write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut buf)).unwrap();
+
+ let mut reader = IVFHNSWSQIndexReader::open(Cursor::new(buf)).unwrap();
+ let (labels, distances) = reader.search(&[1.0, 0.0, 0.0], 2, nlist,
16).unwrap();
+
+ assert_eq!(labels[0], 10);
+ assert!(distances[0].is_finite());
+ }
+
+ // The unfiltered batch search must agree exactly with per-query searches.
+ #[test]
+ fn test_ivfhnswsq_batch_matches_single_search() {
+ let d = 4;
+ let nlist = 4;
+ let n = 128;
+ let data: Vec<f32> = (0..n)
+ .flat_map(|i| {
+ let cluster = (i % nlist) as f32 * 100.0;
+ [cluster + i as f32 * 0.01, 1.0, 2.0, 3.0]
+ })
+ .collect();
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2,
HnswBuildParams::default());
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+ index.build_graphs().unwrap();
+
+ let mut buf = Vec::new();
+ write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut buf)).unwrap();
+ let queries = [&data[7 * d..8 * d], &data[23 * d..24 * d]].concat();
+
+ let mut batch_reader =
IVFHNSWSQIndexReader::open(Cursor::new(buf.clone())).unwrap();
+ let (batch_labels, batch_distances) =
+ search_batch_ivfhnswsq_reader(&mut batch_reader, &queries, 2, 3,
nlist, 32).unwrap();
+
+ for qi in 0..2 {
+ let mut single_reader =
IVFHNSWSQIndexReader::open(Cursor::new(buf.clone())).unwrap();
+ let (single_labels, single_distances) = single_reader
+ .search(&queries[qi * d..(qi + 1) * d], 3, nlist, 32)
+ .unwrap();
+ assert_eq!(
+ &batch_labels[qi * 3..(qi + 1) * 3],
+ single_labels.as_slice()
+ );
+ assert_eq!(
+ &batch_distances[qi * 3..(qi + 1) * 3],
+ single_distances.as_slice()
+ );
+ }
+ }
+
+ // Writing a non-empty index whose graphs were never built must be rejected.
+ #[test]
+ fn test_ivfhnswsq_write_requires_graphs() {
+ let mut index = IVFHNSWSQIndex::new(2, 1, MetricType::L2,
HnswBuildParams::default());
+ let data = vec![0.0, 0.0];
+ index.train(&data, 1);
+ index.add(&data, &[1], 1);
+
+ let mut buf = Vec::new();
+ let err = write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut
buf)).unwrap_err();
+
+ assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+ assert!(err.to_string().contains("build_graphs"));
+ }
+
+ // A graph built from vectors other than the decoded codes must be rejected.
+ #[test]
+ fn test_ivfhnswsq_writer_rejects_stale_graph() {
+ let mut index = IVFHNSWSQIndex::new(2, 1, MetricType::L2,
HnswBuildParams::default());
+ let data = vec![0.0, 0.0, 1.0, 0.0];
+ index.train(&data, 2);
+ index.add(&data, &[10, 11], 2);
+ index.build_graphs().unwrap();
+ index.graphs[0] = Some(
+ HnswGraph::build(
+ &[10.0, 10.0, 11.0, 11.0],
+ 2,
+ 2,
+ MetricType::L2,
+ HnswBuildParams::default(),
+ )
+ .unwrap(),
+ );
+
+ let mut buf = Vec::new();
+ let err = write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut
buf)).unwrap_err();
+
+ assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+ assert!(err.to_string().contains("graph does not match"));
+ }
+}
diff --git a/core/src/lib.rs b/core/src/lib.rs
index 8f2f033..fd2b62d 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -22,14 +22,19 @@ pub mod blas;
pub mod distance;
pub mod fastscan;
pub mod hnsw;
+pub(crate) mod hnsw_search;
+pub(crate) mod index_io_util;
pub mod io;
pub mod ivfflat;
pub mod ivfflat_io;
pub mod ivfhnswflat;
pub mod ivfhnswflat_io;
+pub mod ivfhnswsq;
+pub mod ivfhnswsq_io;
pub mod ivfpq;
pub mod kmeans;
pub mod opq;
pub mod pq;
pub mod shuffler;
+pub mod sq;
pub mod topk;
diff --git a/core/src/sq.rs b/core/src/sq.rs
new file mode 100644
index 0000000..f1eb5da
--- /dev/null
+++ b/core/src/sq.rs
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{fvec_norm_l2sqr, MetricType};
+
+/// Uniform 8-bit scalar quantizer that shares one `[min, max]` range across all dimensions.
+///
+/// Every component is stored as a single `u8`: code `0` decodes to `min`, code `255` decodes
+/// to `max`, and the codes in between are evenly spaced. A quantizer with `min >= max` is
+/// degenerate: it encodes everything as `0` and decodes everything as `min`.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct ScalarQuantizer {
+    /// Number of components per vector (also the number of code bytes per vector).
+    pub d: usize,
+    /// Value represented by code `0`.
+    pub min: f32,
+    /// Value represented by code `255`.
+    pub max: f32,
+}
+
+impl ScalarQuantizer {
+    /// Creates an untrained quantizer for `d`-dimensional vectors with both bounds at `0.0`.
+    pub fn new(d: usize) -> Self {
+        Self::with_bounds(d, 0.0, 0.0)
+    }
+
+    /// Creates a quantizer with explicit bounds. The bounds are stored without validation.
+    pub fn with_bounds(d: usize, min: f32, max: f32) -> Self {
+        Self { d, min, max }
+    }
+
+    /// Learns the bounds as the smallest and largest finite component among the first `n`
+    /// vectors of `data`.
+    ///
+    /// Non-finite components (NaN, `±inf`) are skipped: an infinite bound would make the
+    /// quantization step infinite and turn decoded values into NaN. When no finite component
+    /// is seen (including `n == 0`), both bounds are reset to `0.0`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `data` holds fewer than `n * d` values.
+    pub fn train(&mut self, data: &[f32], n: usize) {
+        let (lo, hi) = data[..n * self.d]
+            .iter()
+            .copied()
+            .filter(|value| value.is_finite())
+            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), value| {
+                (lo.min(value), hi.max(value))
+            });
+
+        // `lo > hi` only holds for the untouched fold seed, i.e. nothing finite was observed.
+        if lo > hi {
+            self.min = 0.0;
+            self.max = 0.0;
+        } else {
+            self.min = lo;
+            self.max = hi;
+        }
+    }
+
+    /// Number of code bytes produced per vector (one byte per component).
+    pub fn code_size(&self) -> usize {
+        self.d
+    }
+
+    /// Encodes the first `n` vectors of `data` into the first `n * d` bytes of `codes`.
+    ///
+    /// Each component is mapped to the nearest of the 256 representable levels; values outside
+    /// `[min, max]` saturate to `0` or `255`, and NaN encodes as `0`. A degenerate quantizer
+    /// writes `0` everywhere.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `data` or `codes` holds fewer than `n * d` elements.
+    pub fn encode_batch(&self, data: &[f32], n: usize, codes: &mut [u8]) {
+        let len = n * self.d;
+        assert!(data.len() >= len);
+        assert!(codes.len() >= len);
+        let (data, codes) = (&data[..len], &mut codes[..len]);
+
+        if self.min >= self.max {
+            codes.fill(0);
+            return;
+        }
+
+        let scale = 255.0 / (self.max - self.min);
+        for (code, &value) in codes.iter_mut().zip(data) {
+            // Decoding places code `c` exactly at `min + c * step`, so the closest level is
+            // found by rounding; truncating would bias every value downward and double the
+            // worst-case error. The float-to-int `as` cast saturates and maps NaN to 0.
+            *code = ((value - self.min) * scale).clamp(0.0, 255.0).round() as u8;
+        }
+    }
+
+    /// Encodes a single vector; see [`ScalarQuantizer::encode_batch`].
+    pub fn encode(&self, vector: &[f32], code: &mut [u8]) {
+        self.encode_batch(vector, 1, code);
+    }
+
+    /// Decodes the first `n` codes (`n * d` bytes) of `codes` into `vectors`.
+    ///
+    /// A degenerate quantizer writes `min` everywhere.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `codes` or `vectors` holds fewer than `n * d` elements.
+    pub fn decode_batch(&self, codes: &[u8], n: usize, vectors: &mut [f32]) {
+        let len = n * self.d;
+        assert!(codes.len() >= len);
+        assert!(vectors.len() >= len);
+        let (codes, vectors) = (&codes[..len], &mut vectors[..len]);
+
+        if self.min >= self.max {
+            vectors.fill(self.min);
+            return;
+        }
+
+        let step = self.step();
+        for (value, &code) in vectors.iter_mut().zip(codes) {
+            *value = self.min + f32::from(code) * step;
+        }
+    }
+
+    /// Decodes a single code; see [`ScalarQuantizer::decode_batch`].
+    pub fn decode(&self, code: &[u8], vector: &mut [f32]) {
+        self.decode_batch(code, 1, vector);
+    }
+
+    /// Distance between `query` and the vector stored in `code`, building a fresh
+    /// [`DistanceContext`] for this one comparison.
+    pub fn distance_to_code(&self, query: &[f32], code: &[u8], metric: MetricType) -> f32 {
+        self.distance_to_code_with_context(query, code, self.distance_context(query, metric))
+    }
+
+    /// Builds the per-query context from the first `d` components of `query`, so it can be
+    /// reused across many calls to [`ScalarQuantizer::distance_to_code_with_context`].
+    ///
+    /// # Panics
+    ///
+    /// Panics if `query` holds fewer than `d` values.
+    pub fn distance_context(&self, query: &[f32], metric: MetricType) -> DistanceContext {
+        debug_assert!(query.len() >= self.d);
+        DistanceContext::new(&query[..self.d], metric)
+    }
+
+    /// Distance between `query` and the vector stored in `code`, decoded on the fly.
+    ///
+    /// `context` should be the value returned by [`ScalarQuantizer::distance_context`] for the
+    /// same query. The result depends on the context's metric:
+    ///
+    /// * `L2`: sum of squared differences (no square root is taken).
+    /// * `InnerProduct`: the negated dot product.
+    /// * `Cosine`: `1 - dot / (query_norm * code_norm)`, or `1.0` when that denominator is not
+    ///   strictly positive.
+    ///
+    /// Components are decoded with the same arithmetic as [`ScalarQuantizer::decode_batch`],
+    /// so the result matches the distance computed on the decoded vector.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `query` or `code` holds fewer than `d` elements.
+    pub fn distance_to_code_with_context(
+        &self,
+        query: &[f32],
+        code: &[u8],
+        context: DistanceContext,
+    ) -> f32 {
+        debug_assert!(query.len() >= self.d);
+        debug_assert!(code.len() >= self.d);
+
+        // Resolve the degenerate-range check and the division once per call instead of once
+        // per component.
+        let (min, step) = (self.min, self.step());
+        let pairs = query[..self.d]
+            .iter()
+            .zip(&code[..self.d])
+            .map(|(&q, &c)| (q, min + f32::from(c) * step));
+
+        match context.metric {
+            MetricType::L2 => pairs.fold(0.0f32, |sum, (q, value)| {
+                let diff = q - value;
+                sum + diff * diff
+            }),
+            MetricType::InnerProduct => -pairs.fold(0.0f32, |dot, (q, value)| dot + q * value),
+            MetricType::Cosine => {
+                let (dot, norm_sqr) = pairs.fold((0.0f32, 0.0f32), |(dot, norm_sqr), (q, value)| {
+                    (dot + q * value, norm_sqr + value * value)
+                });
+                let denom = context.query_norm * norm_sqr.sqrt();
+                if denom > 0.0 {
+                    1.0 - dot / denom
+                } else {
+                    1.0
+                }
+            }
+        }
+    }
+
+    /// Spacing between two consecutive decoded levels; `0.0` for a degenerate quantizer so
+    /// that every code decodes to `min`.
+    fn step(&self) -> f32 {
+        if self.min >= self.max {
+            0.0
+        } else {
+            (self.max - self.min) / 255.0
+        }
+    }
+
+    /// Decodes one code byte using the same arithmetic as [`ScalarQuantizer::decode_batch`].
+    fn decode_value(&self, code: u8) -> f32 {
+        self.min + f32::from(code) * self.step()
+    }
+}
+
+/// Per-query state that is computed once and reused for every code compared to that query.
+///
+/// It records the metric and, for [`MetricType::Cosine`] only, the square root of
+/// `fvec_norm_l2sqr(query)`. For the other metrics the norm field is set to `0.0`.
+#[derive(Debug, Clone, Copy)]
+pub struct DistanceContext {
+    metric: MetricType,
+    query_norm: f32,
+}
+
+impl DistanceContext {
+    /// Builds the context for `query` under `metric`; the norm is only computed for cosine.
+    pub fn new(query: &[f32], metric: MetricType) -> Self {
+        let query_norm = match metric {
+            MetricType::Cosine => fvec_norm_l2sqr(query).sqrt(),
+            MetricType::L2 | MetricType::InnerProduct => 0.0,
+        };
+        Self { metric, query_norm }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Training picks up the global extremes, which must map to codes 0 and 255 and decode
+    /// back to (approximately) themselves.
+    #[test]
+    fn test_scalar_quantizer_round_trips_bounds() {
+        let data = [-1.0f32, 0.0, 1.0, 3.0];
+        let mut sq = ScalarQuantizer::new(2);
+        sq.train(&data, 2);
+        assert_eq!(sq.min, -1.0);
+        assert_eq!(sq.max, 3.0);
+
+        let mut codes = [0u8; 4];
+        sq.encode_batch(&data, 2, &mut codes);
+        assert_eq!(codes[0], 0);
+        assert_eq!(codes[3], 255);
+
+        let mut decoded = [0.0f32; 4];
+        sq.decode_batch(&codes, 2, &mut decoded);
+        assert!((decoded[0] + 1.0).abs() < 1e-6);
+        assert!((decoded[3] - 3.0).abs() < 1e-6);
+    }
+
+    /// With constant training data the range collapses: every code is 0 and every decoded
+    /// value equals the constant.
+    #[test]
+    fn test_scalar_quantizer_constant_input() {
+        let data = [5.0f32; 4];
+        let mut sq = ScalarQuantizer::new(2);
+        sq.train(&data, 2);
+
+        // Start from non-zero bytes so the test proves the codes are actually overwritten.
+        let mut codes = [7u8; 4];
+        sq.encode_batch(&data, 2, &mut codes);
+        assert_eq!(codes, [0u8; 4]);
+
+        let mut decoded = [0.0f32; 4];
+        sq.decode_batch(&codes, 2, &mut decoded);
+        assert_eq!(decoded, data);
+    }
+
+    /// A vector lying exactly on representable levels has (near-)zero L2 distance to its code.
+    #[test]
+    fn test_scalar_quantizer_distance_to_code() {
+        let query = [1.0f32, 0.0];
+        let sq = ScalarQuantizer::with_bounds(2, 0.0, 1.0);
+
+        let mut code = [0u8; 2];
+        sq.encode(&query, &mut code);
+
+        assert!(sq.distance_to_code(&query, &code, MetricType::L2) < 1e-6);
+    }
+}