This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git


The following commit(s) were added to refs/heads/main by this push:
     new 8bc5cfe  Add IVF HNSW SQ index (#20)
8bc5cfe is described below

commit 8bc5cfebc5a3151efefdcc4a76cad374be4a66df
Author: Jingsong Lee <[email protected]>
AuthorDate: Tue Jun 9 15:19:40 2026 +0800

    Add IVF HNSW SQ index (#20)
---
 core/benches/recall_bench.rs |  39 ++
 core/src/hnsw_search.rs      |  76 ++++
 core/src/index_io_util.rs    | 346 ++++++++++++++++++
 core/src/ivfhnswflat.rs      |  56 +--
 core/src/ivfhnswflat_io.rs   | 415 ++-------------------
 core/src/ivfhnswsq.rs        | 289 +++++++++++++++
 core/src/ivfhnswsq_io.rs     | 851 +++++++++++++++++++++++++++++++++++++++++++
 core/src/lib.rs              |   5 +
 core/src/sq.rs               | 229 ++++++++++++
 9 files changed, 1879 insertions(+), 427 deletions(-)

diff --git a/core/benches/recall_bench.rs b/core/benches/recall_bench.rs
index 1755402..5103446 100644
--- a/core/benches/recall_bench.rs
+++ b/core/benches/recall_bench.rs
@@ -2,6 +2,7 @@ use paimon_vindex_core::distance::{fvec_distance, MetricType};
 use paimon_vindex_core::hnsw::HnswBuildParams;
 use paimon_vindex_core::ivfflat::IVFFlatIndex;
 use paimon_vindex_core::ivfhnswflat::IVFHNSWFlatIndex;
+use paimon_vindex_core::ivfhnswsq::IVFHNSWSQIndex;
 use paimon_vindex_core::ivfpq::IVFPQIndex;
 use std::collections::HashSet;
 use std::time::Instant;
@@ -99,6 +100,22 @@ fn run_scenario(s: Scenario<'_>) {
     ivfhnswflat.build_graphs().unwrap();
     println!("build IVF-HNSW-FLAT: {:.2}s", start.elapsed().as_secs_f64());
 
+    let start = Instant::now();
+    let mut ivfhnswsq = IVFHNSWSQIndex::new(
+        s.d,
+        s.nlist,
+        MetricType::L2,
+        HnswBuildParams {
+            m: 16,
+            ef_construction: s.hnsw_build_ef,
+            max_level: 7,
+        },
+    );
+    ivfhnswsq.train(&data, s.n);
+    ivfhnswsq.add(&data, &ids, s.n);
+    ivfhnswsq.build_graphs().unwrap();
+    println!("build IVF-HNSW-SQ: {:.2}s", start.elapsed().as_secs_f64());
+
     println!();
     println!(
         "index      nprobe  ef      recall@{}  query_ms  us/query",
@@ -157,6 +174,28 @@ fn run_scenario(s: Scenario<'_>) {
                 elapsed,
                 s.nq,
             );
+
+            let mut distances = vec![0.0f32; s.nq * s.k];
+            let mut labels = vec![0i64; s.nq * s.k];
+            let start = Instant::now();
+            ivfhnswsq.search(
+                queries,
+                s.nq,
+                s.k,
+                nprobe,
+                ef_search,
+                &mut distances,
+                &mut labels,
+            );
+            let elapsed = start.elapsed();
+            print_row(
+                "IVF-HSQ",
+                nprobe,
+                Some(ef_search),
+                recall_at_k(&labels, &ground_truth, s.nq, s.k),
+                elapsed,
+                s.nq,
+            );
         }
     }
 }
diff --git a/core/src/hnsw_search.rs b/core/src/hnsw_search.rs
new file mode 100644
index 0000000..74eab88
--- /dev/null
+++ b/core/src/hnsw_search.rs
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::hnsw::HnswGraph;
+use crate::ivfpq::RowIdFilter;
+use crate::topk::TopKHeap;
+
+pub(crate) struct HnswSearchList<'a, P> {
+    pub(crate) ids: &'a [i64],
+    pub(crate) graph: Option<&'a HnswGraph>,
+    pub(crate) payload: P,
+}
+
+pub(crate) fn search_hnsw_lists<'a, P, F>(
+    query: &[f32],
+    lists: &[HnswSearchList<'a, P>],
+    k: usize,
+    ef_search: usize,
+    filter: Option<&dyn RowIdFilter>,
+    mut scan_list: F,
+) -> Vec<(f32, i64)>
+where
+    F: FnMut(&HnswSearchList<'a, P>, &mut TopKHeap),
+{
+    let mut heap = TopKHeap::new(k);
+    let force_scan = filter
+        .map(|f| count_filtered(lists, f) <= ef_search.max(k))
+        .unwrap_or(false);
+
+    for list in lists {
+        if force_scan {
+            scan_list(list, &mut heap);
+            continue;
+        }
+        if let Some(graph) = list.graph {
+            let local_results = graph.search(query, ef_search.max(k), ef_search.max(k));
+            for (local_id, dist) in local_results {
+                let row_id = list.ids[local_id];
+                if filter.map(|f| f.contains(row_id)).unwrap_or(true) {
+                    heap.push(dist, row_id);
+                }
+            }
+        } else {
+            scan_list(list, &mut heap);
+        }
+    }
+
+    if filter.is_some() && heap.len() < k && !force_scan {
+        for list in lists {
+            scan_list(list, &mut heap);
+        }
+    }
+
+    heap.into_sorted()
+}
+
+fn count_filtered<P>(lists: &[HnswSearchList<'_, P>], filter: &dyn RowIdFilter) -> usize {
+    lists
+        .iter()
+        .map(|list| list.ids.iter().filter(|&&id| filter.contains(id)).count())
+        .sum()
+}
diff --git a/core/src/index_io_util.rs b/core/src/index_io_util.rs
new file mode 100644
index 0000000..292953b
--- /dev/null
+++ b/core/src/index_io_util.rs
@@ -0,0 +1,346 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::MetricType;
+use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::io::{SeekRead, SeekWrite};
+use roaring::RoaringTreemap;
+use std::io;
+
+pub(crate) fn validate_search_inputs(
+    queries: &[f32],
+    nq: usize,
+    d: usize,
+    k: usize,
+    nprobe: usize,
+) -> io::Result<()> {
+    if nq == 0 {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "nq must be greater than 0",
+        ));
+    }
+    let expected_query_len = nq.checked_mul(d).ok_or_else(|| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "nq * dimension overflows usize",
+        )
+    })?;
+    if queries.len() != expected_query_len {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!(
+                "queries length {} does not match nq * dimension {}",
+                queries.len(),
+                expected_query_len
+            ),
+        ));
+    }
+    if k == 0 {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "k must be greater than 0",
+        ));
+    }
+    if nprobe == 0 {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "nprobe must be greater than 0",
+        ));
+    }
+    Ok(())
+}
+
+pub(crate) fn encode_graph(graph: Option<&HnswGraph>) -> io::Result<Vec<u8>> {
+    let Some(graph) = graph else {
+        return Ok(Vec::new());
+    };
+    let mut buf = Vec::new();
+    write_u32_vec(&mut buf, graph.len())?;
+    write_u32_vec(&mut buf, graph.entry_point())?;
+    write_u32_vec(&mut buf, graph.max_observed_level())?;
+    for &level in graph.levels() {
+        write_u32_vec(&mut buf, level)?;
+    }
+    for node_levels in graph.neighbors() {
+        for level_neighbors in node_levels {
+            write_u32_vec(&mut buf, level_neighbors.len())?;
+            for &neighbor in level_neighbors {
+                write_u32_vec(&mut buf, neighbor)?;
+            }
+        }
+    }
+    Ok(buf)
+}
+
+pub(crate) fn decode_graph(
+    bytes: &[u8],
+    vectors: Vec<f32>,
+    count: usize,
+    d: usize,
+    metric: MetricType,
+    hnsw_params: HnswBuildParams,
+) -> io::Result<Option<HnswGraph>> {
+    if bytes.is_empty() {
+        return Ok(None);
+    }
+    let mut pos = 0usize;
+    let graph_count = read_u32_vec(bytes, &mut pos)? as usize;
+    if graph_count != count {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!(
+                "graph count {} does not match list count {}",
+                graph_count, count
+            ),
+        ));
+    }
+    let entry_point = read_u32_vec(bytes, &mut pos)? as usize;
+    let max_observed_level = read_u32_vec(bytes, &mut pos)? as usize;
+    let mut levels = Vec::with_capacity(count);
+    for node in 0..count {
+        let level = read_u32_vec(bytes, &mut pos)? as usize;
+        if level >= hnsw_params.max_level {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!(
+                    "graph node {} level {} exceeds max_level {}",
+                    node,
+                    level,
+                    hnsw_params.max_level - 1
+                ),
+            ));
+        }
+        levels.push(level);
+    }
+    let mut neighbors = Vec::with_capacity(count);
+    for (node, &level) in levels.iter().enumerate() {
+        let mut node_levels = Vec::with_capacity(level + 1);
+        for graph_level in 0..=level {
+            let degree = read_u32_vec(bytes, &mut pos)? as usize;
+            let max_degree = if graph_level == 0 {
+                hnsw_params.m.saturating_mul(2)
+            } else {
+                hnsw_params.m
+            };
+            if degree > max_degree {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    format!(
+                        "graph node {} degree {} at level {} exceeds max degree {}",
+                        node, degree, graph_level, max_degree
+                    ),
+                ));
+            }
+            let mut level_neighbors = Vec::with_capacity(degree);
+            for _ in 0..degree {
+                level_neighbors.push(read_u32_vec(bytes, &mut pos)? as usize);
+            }
+            node_levels.push(level_neighbors);
+        }
+        neighbors.push(node_levels);
+    }
+    if pos != bytes.len() {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "trailing bytes in HNSW graph section",
+        ));
+    }
+    Ok(Some(HnswGraph::from_parts(
+        vectors,
+        count,
+        d,
+        metric,
+        levels,
+        neighbors,
+        entry_point,
+        max_observed_level,
+        hnsw_params,
+    )?))
+}
+
+fn write_u32_vec(buf: &mut Vec<u8>, value: usize) -> io::Result<()> {
+    if value > u32::MAX as usize {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("value {} exceeds u32 limit", value),
+        ));
+    }
+    buf.extend_from_slice(&(value as u32).to_le_bytes());
+    Ok(())
+}
+
+fn read_u32_vec(bytes: &[u8], pos: &mut usize) -> io::Result<u32> {
+    let end = pos.checked_add(4).ok_or_else(|| {
+        io::Error::new(
+            io::ErrorKind::InvalidData,
+            "graph section position overflow",
+        )
+    })?;
+    if end > bytes.len() {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "truncated HNSW graph section",
+        ));
+    }
+    let value = u32::from_le_bytes(bytes[*pos..end].try_into().unwrap());
+    *pos = end;
+    Ok(value)
+}
+
+pub(crate) fn write_u32_le(out: &mut dyn SeekWrite, v: u32) -> io::Result<()> {
+    out.write_all(&v.to_le_bytes())
+}
+
+pub(crate) fn write_i32_le(out: &mut dyn SeekWrite, v: i32) -> io::Result<()> {
+    out.write_all(&v.to_le_bytes())
+}
+
+pub(crate) fn write_i64_le(out: &mut dyn SeekWrite, v: i64) -> io::Result<()> {
+    out.write_all(&v.to_le_bytes())
+}
+
+pub(crate) fn write_f32_slice(out: &mut dyn SeekWrite, data: &[f32]) -> io::Result<()> {
+    let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
+    out.write_all(&bytes)
+}
+
+pub(crate) fn read_u32_le(reader: &mut dyn SeekRead) -> io::Result<u32> {
+    let mut buf = [0u8; 4];
+    reader.read_exact(&mut buf)?;
+    Ok(u32::from_le_bytes(buf))
+}
+
+pub(crate) fn read_i32_le(reader: &mut dyn SeekRead) -> io::Result<i32> {
+    let mut buf = [0u8; 4];
+    reader.read_exact(&mut buf)?;
+    Ok(i32::from_le_bytes(buf))
+}
+
+pub(crate) fn read_i64_le(reader: &mut dyn SeekRead) -> io::Result<i64> {
+    let mut buf = [0u8; 8];
+    reader.read_exact(&mut buf)?;
+    Ok(i64::from_le_bytes(buf))
+}
+
+pub(crate) fn read_f32_vec(reader: &mut dyn SeekRead, count: usize) -> io::Result<Vec<f32>> {
+    let byte_len = count.checked_mul(4).ok_or_else(|| {
+        io::Error::new(
+            io::ErrorKind::InvalidData,
+            "f32 section byte length overflow",
+        )
+    })?;
+    let mut buf = vec![0u8; byte_len];
+    reader.read_exact(&mut buf)?;
+    bytes_to_f32_vec(&buf)
+}
+
+pub(crate) fn bytes_to_f32_vec(bytes: &[u8]) -> io::Result<Vec<f32>> {
+    if !bytes.len().is_multiple_of(4) {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "f32 byte section is not 4-byte aligned",
+        ));
+    }
+    Ok(bytes
+        .chunks_exact(4)
+        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
+        .collect())
+}
+
+pub(crate) fn validate_positive_i32(val: i32, field: &str) -> io::Result<i32> {
+    if val <= 0 {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!("invalid header field {}: {} (must be positive)", field, val),
+        ));
+    }
+    Ok(val)
+}
+
+pub(crate) fn usize_to_i32(value: usize, field: &str) -> io::Result<i32> {
+    if value > i32::MAX as usize {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("{} exceeds i32 length limit: {}", field, value),
+        ));
+    }
+    Ok(value as i32)
+}
+
+pub(crate) fn usize_to_i64(value: usize, field: &str) -> io::Result<i64> {
+    if value > i64::MAX as usize {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("{} exceeds i64 length limit: {}", field, value),
+        ));
+    }
+    Ok(value as i64)
+}
+
+pub(crate) fn u64_to_i64(value: u64, field: &str) -> io::Result<i64> {
+    if value > i64::MAX as u64 {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("{} exceeds i64 offset limit: {}", field, value),
+        ));
+    }
+    Ok(value as i64)
+}
+
+const MAX_SECTION_ELEMENTS: usize = 1 << 30;
+
+pub(crate) fn checked_section_size(a: usize, b: usize) -> io::Result<usize> {
+    let result = a
+        .checked_mul(b)
+        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "section size overflow"))?;
+    if result > MAX_SECTION_ELEMENTS {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!(
+                "section size {} exceeds maximum {}",
+                result, MAX_SECTION_ELEMENTS
+            ),
+        ));
+    }
+    Ok(result)
+}
+
+pub(crate) fn checked_list_offset(offset: i64, list_id: usize) -> io::Result<u64> {
+    if offset < 0 {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!("negative list offset {} at list {}", offset, list_id),
+        ));
+    }
+    Ok(offset as u64)
+}
+
+pub(crate) fn checked_list_bytes(count: usize, bytes_per_entry: usize) -> io::Result<usize> {
+    count
+        .checked_mul(bytes_per_entry)
+        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "list byte size overflow"))
+}
+
+pub(crate) fn decode_roaring_filter(bytes: &[u8]) -> io::Result<RoaringTreemap> {
+    RoaringTreemap::deserialize_from(bytes).map_err(|e| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("invalid RoaringTreemap filter: {}", e),
+        )
+    })
+}
diff --git a/core/src/ivfhnswflat.rs b/core/src/ivfhnswflat.rs
index e2f7ccf..0d2db81 100644
--- a/core/src/ivfhnswflat.rs
+++ b/core/src/ivfhnswflat.rs
@@ -17,6 +17,7 @@
 
 use crate::distance::{fvec_distance, MetricType};
 use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::hnsw_search::{search_hnsw_lists, HnswSearchList};
 use crate::ivfflat::IVFFlatIndex;
 use crate::ivfpq::RowIdFilter;
 use crate::kmeans;
@@ -114,38 +115,17 @@ impl IVFHNSWFlatIndex {
 
         for qi in 0..nq {
             let query = &processed_queries[qi * self.flat.d..(qi + 1) * self.flat.d];
-            let mut heap = TopKHeap::new(k);
-            let force_flat_scan = filter
-                .map(|f| self.count_filtered(&all_probe_indices[qi], f) <= ef_search.max(k))
-                .unwrap_or(false);
-
-            for &list_id in &all_probe_indices[qi] {
-                if force_flat_scan {
-                    self.scan_flat_list(query, list_id, filter, &mut heap);
-                    continue;
-                }
-                if let Some(ref graph) = self.graphs[list_id] {
-                    let local_results = graph.search(query, ef_search.max(k), ef_search.max(k));
-                    for (local_id, dist) in local_results {
-                        let row_id = self.flat.ids[list_id][local_id];
-                        if let Some(f) = filter {
-                            if !f.contains(row_id) {
-                                continue;
-                            }
-                        }
-                        heap.push(dist, row_id);
-                    }
-                } else {
-                    self.scan_flat_list(query, list_id, filter, &mut heap);
-                }
-            }
-            if filter.is_some() && heap.len() < k {
-                for &list_id in &all_probe_indices[qi] {
-                    self.scan_flat_list(query, list_id, filter, &mut heap);
-                }
-            }
-
-            let sorted = heap.into_sorted();
+            let lists: Vec<_> = all_probe_indices[qi]
+                .iter()
+                .map(|&list_id| HnswSearchList {
+                    ids: self.flat.ids[list_id].as_slice(),
+                    graph: self.graphs[list_id].as_ref(),
+                    payload: list_id,
+                })
+                .collect();
+            let sorted = search_hnsw_lists(query, &lists, k, ef_search, filter, |list, heap| {
+                self.scan_flat_list(query, list.payload, filter, heap);
+            });
             let out_base = qi * k;
             for (i, &(dist, id)) in sorted.iter().enumerate() {
                 result_distances[out_base + i] = dist;
@@ -158,18 +138,6 @@ impl IVFHNSWFlatIndex {
         }
     }
 
-    fn count_filtered(&self, probe_indices: &[usize], filter: &dyn RowIdFilter) -> usize {
-        probe_indices
-            .iter()
-            .map(|&list_id| {
-                self.flat.ids[list_id]
-                    .iter()
-                    .filter(|&&id| filter.contains(id))
-                    .count()
-            })
-            .sum()
-    }
-
     fn scan_flat_list(
         &self,
         query: &[f32],
diff --git a/core/src/ivfhnswflat_io.rs b/core/src/ivfhnswflat_io.rs
index d504d43..ebd155c 100644
--- a/core/src/ivfhnswflat_io.rs
+++ b/core/src/ivfhnswflat_io.rs
@@ -17,12 +17,18 @@
 
 use crate::distance::{fvec_distance, preprocess_vectors, MetricType};
 use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::hnsw_search::{search_hnsw_lists, HnswSearchList};
+use crate::index_io_util::{
+    bytes_to_f32_vec, checked_list_bytes, checked_list_offset, checked_section_size, decode_graph,
+    decode_roaring_filter, encode_graph, read_f32_vec, read_i32_le, read_i64_le, read_u32_le,
+    u64_to_i64, usize_to_i32, usize_to_i64, validate_positive_i32, validate_search_inputs,
+    write_f32_slice, write_i32_le, write_i64_le, write_u32_le,
+};
 use crate::io::{SeekRead, SeekWrite};
 use crate::ivfhnswflat::IVFHNSWFlatIndex;
 use crate::ivfpq::RowIdFilter;
 use crate::kmeans;
 use crate::topk::TopKHeap;
-use roaring::RoaringTreemap;
 use std::io;
 
 pub const IVF_HNSW_FLAT_MAGIC: u32 = 0x4948464C; // "IHFL"
@@ -356,47 +362,26 @@ impl<R: SeekRead> IVFHNSWFlatIndexReader<R> {
                 loaded_lists.push(list);
             }
         }
-        let mut heap = TopKHeap::new(k);
-        let force_flat_scan = filter
-            .map(|f| count_filtered(&loaded_lists, f) <= ef_search.max(k))
-            .unwrap_or(false);
-
-        for list in &loaded_lists {
-            if force_flat_scan {
-                scan_flat_list(
-                    &q,
-                    &list.ids,
-                    list.graph.vectors(),
-                    self.d,
-                    self.metric,
-                    filter,
-                    &mut heap,
-                );
-            } else {
-                let local_results = list.graph.search(&q, ef_search.max(k), ef_search.max(k));
-                for (local_id, dist) in local_results {
-                    let row_id = list.ids[local_id];
-                    if filter.map(|f| f.contains(row_id)).unwrap_or(true) {
-                        heap.push(dist, row_id);
-                    }
-                }
-            }
-        }
-        if filter.is_some() && heap.len() < k {
-            for list in &loaded_lists {
-                scan_flat_list(
-                    &q,
-                    &list.ids,
-                    list.graph.vectors(),
-                    self.d,
-                    self.metric,
-                    filter,
-                    &mut heap,
-                );
-            }
-        }
-
-        let sorted = heap.into_sorted();
+        let search_lists: Vec<_> = loaded_lists
+            .iter()
+            .map(|list| HnswSearchList {
+                ids: list.ids.as_slice(),
+                graph: Some(&list.graph),
+                payload: list,
+            })
+            .collect();
+        let sorted = search_hnsw_lists(&q, &search_lists, k, ef_search, filter, |list, heap| {
+            let list = list.payload;
+            scan_flat_list(
+                &q,
+                &list.ids,
+                list.graph.vectors(),
+                self.d,
+                self.metric,
+                filter,
+                heap,
+            );
+        });
         let mut labels: Vec<i64> = sorted.iter().map(|&(_, id)| id).collect();
         let mut distances: Vec<f32> = sorted.iter().map(|&(dist, _)| dist).collect();
         labels.resize(k, -1);
@@ -573,13 +558,6 @@ struct LoadedBatchList {
     graph: HnswGraph,
 }
 
-fn count_filtered(lists: &[GraphList], filter: &dyn RowIdFilter) -> usize {
-    lists
-        .iter()
-        .map(|list| list.ids.iter().filter(|&&id| filter.contains(id)).count())
-        .sum()
-}
-
 fn scan_flat_list(
     query: &[f32],
     ids: &[i64],
@@ -598,257 +576,6 @@ fn scan_flat_list(
     }
 }
 
-fn validate_search_inputs(
-    queries: &[f32],
-    nq: usize,
-    d: usize,
-    k: usize,
-    nprobe: usize,
-) -> io::Result<()> {
-    if nq == 0 {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidInput,
-            "nq must be greater than 0",
-        ));
-    }
-    let expected_query_len = nq.checked_mul(d).ok_or_else(|| {
-        io::Error::new(
-            io::ErrorKind::InvalidInput,
-            "nq * dimension overflows usize",
-        )
-    })?;
-    if queries.len() != expected_query_len {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidInput,
-            format!(
-                "queries length {} does not match nq * dimension {}",
-                queries.len(),
-                expected_query_len
-            ),
-        ));
-    }
-    if k == 0 {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidInput,
-            "k must be greater than 0",
-        ));
-    }
-    if nprobe == 0 {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidInput,
-            "nprobe must be greater than 0",
-        ));
-    }
-    Ok(())
-}
-
-fn encode_graph(graph: Option<&HnswGraph>) -> io::Result<Vec<u8>> {
-    let Some(graph) = graph else {
-        return Ok(Vec::new());
-    };
-    let mut buf = Vec::new();
-    write_u32_vec(&mut buf, graph.len())?;
-    write_u32_vec(&mut buf, graph.entry_point())?;
-    write_u32_vec(&mut buf, graph.max_observed_level())?;
-    for &level in graph.levels() {
-        write_u32_vec(&mut buf, level)?;
-    }
-    for node_levels in graph.neighbors() {
-        for level_neighbors in node_levels {
-            write_u32_vec(&mut buf, level_neighbors.len())?;
-            for &neighbor in level_neighbors {
-                write_u32_vec(&mut buf, neighbor)?;
-            }
-        }
-    }
-    Ok(buf)
-}
-
-fn decode_graph(
-    bytes: &[u8],
-    vectors: Vec<f32>,
-    count: usize,
-    d: usize,
-    metric: MetricType,
-    hnsw_params: HnswBuildParams,
-) -> io::Result<Option<HnswGraph>> {
-    if bytes.is_empty() {
-        return Ok(None);
-    }
-    let mut pos = 0usize;
-    let graph_count = read_u32_vec(bytes, &mut pos)? as usize;
-    if graph_count != count {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidData,
-            format!(
-                "graph count {} does not match list count {}",
-                graph_count, count
-            ),
-        ));
-    }
-    let entry_point = read_u32_vec(bytes, &mut pos)? as usize;
-    let max_observed_level = read_u32_vec(bytes, &mut pos)? as usize;
-    let mut levels = Vec::with_capacity(count);
-    for node in 0..count {
-        let level = read_u32_vec(bytes, &mut pos)? as usize;
-        if level >= hnsw_params.max_level {
-            return Err(io::Error::new(
-                io::ErrorKind::InvalidData,
-                format!(
-                    "graph node {} level {} exceeds max_level {}",
-                    node,
-                    level,
-                    hnsw_params.max_level - 1
-                ),
-            ));
-        }
-        levels.push(level);
-    }
-    let mut neighbors = Vec::with_capacity(count);
-    for (node, &level) in levels.iter().enumerate() {
-        let mut node_levels = Vec::with_capacity(level + 1);
-        for graph_level in 0..=level {
-            let degree = read_u32_vec(bytes, &mut pos)? as usize;
-            let max_degree = if graph_level == 0 {
-                hnsw_params.m.saturating_mul(2)
-            } else {
-                hnsw_params.m
-            };
-            if degree > max_degree {
-                return Err(io::Error::new(
-                    io::ErrorKind::InvalidData,
-                    format!(
-                        "graph node {} degree {} at level {} exceeds max degree {}",
-                        node, degree, graph_level, max_degree
-                    ),
-                ));
-            }
-            let mut level_neighbors = Vec::with_capacity(degree);
-            for _ in 0..degree {
-                level_neighbors.push(read_u32_vec(bytes, &mut pos)? as usize);
-            }
-            node_levels.push(level_neighbors);
-        }
-        neighbors.push(node_levels);
-    }
-    if pos != bytes.len() {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidData,
-            "trailing bytes in HNSW graph section",
-        ));
-    }
-    Ok(Some(HnswGraph::from_parts(
-        vectors,
-        count,
-        d,
-        metric,
-        levels,
-        neighbors,
-        entry_point,
-        max_observed_level,
-        hnsw_params,
-    )?))
-}
-
-fn write_u32_vec(buf: &mut Vec<u8>, value: usize) -> io::Result<()> {
-    if value > u32::MAX as usize {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidInput,
-            format!("value {} exceeds u32 limit", value),
-        ));
-    }
-    buf.extend_from_slice(&(value as u32).to_le_bytes());
-    Ok(())
-}
-
-fn read_u32_vec(bytes: &[u8], pos: &mut usize) -> io::Result<u32> {
-    let end = pos.checked_add(4).ok_or_else(|| {
-        io::Error::new(
-            io::ErrorKind::InvalidData,
-            "graph section position overflow",
-        )
-    })?;
-    if end > bytes.len() {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidData,
-            "truncated HNSW graph section",
-        ));
-    }
-    let value = u32::from_le_bytes(bytes[*pos..end].try_into().unwrap());
-    *pos = end;
-    Ok(value)
-}
-
-fn write_u32_le(out: &mut dyn SeekWrite, v: u32) -> io::Result<()> {
-    out.write_all(&v.to_le_bytes())
-}
-
-fn write_i32_le(out: &mut dyn SeekWrite, v: i32) -> io::Result<()> {
-    out.write_all(&v.to_le_bytes())
-}
-
-fn write_i64_le(out: &mut dyn SeekWrite, v: i64) -> io::Result<()> {
-    out.write_all(&v.to_le_bytes())
-}
-
-fn write_f32_slice(out: &mut dyn SeekWrite, data: &[f32]) -> io::Result<()> {
-    let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
-    out.write_all(&bytes)
-}
-
-fn read_u32_le(reader: &mut dyn SeekRead) -> io::Result<u32> {
-    let mut buf = [0u8; 4];
-    reader.read_exact(&mut buf)?;
-    Ok(u32::from_le_bytes(buf))
-}
-
-fn read_i32_le(reader: &mut dyn SeekRead) -> io::Result<i32> {
-    let mut buf = [0u8; 4];
-    reader.read_exact(&mut buf)?;
-    Ok(i32::from_le_bytes(buf))
-}
-
-fn read_i64_le(reader: &mut dyn SeekRead) -> io::Result<i64> {
-    let mut buf = [0u8; 8];
-    reader.read_exact(&mut buf)?;
-    Ok(i64::from_le_bytes(buf))
-}
-
-fn read_f32_vec(reader: &mut dyn SeekRead, count: usize) -> io::Result<Vec<f32>> {
-    let byte_len = count.checked_mul(4).ok_or_else(|| {
-        io::Error::new(
-            io::ErrorKind::InvalidData,
-            "f32 section byte length overflow",
-        )
-    })?;
-    let mut buf = vec![0u8; byte_len];
-    reader.read_exact(&mut buf)?;
-    bytes_to_f32_vec(&buf)
-}
-
-fn bytes_to_f32_vec(bytes: &[u8]) -> io::Result<Vec<f32>> {
-    if !bytes.len().is_multiple_of(4) {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidData,
-            "f32 byte section is not 4-byte aligned",
-        ));
-    }
-    Ok(bytes
-        .chunks_exact(4)
-        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
-        .collect())
-}
-
-fn validate_positive_i32(val: i32, field: &str) -> io::Result<i32> {
-    if val <= 0 {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidData,
-            format!("invalid header field {}: {} (must be positive)", field, val),
-        ));
-    }
-    Ok(val)
-}
-
 fn validate_index_shape(index: &IVFHNSWFlatIndex) -> io::Result<()> {
     if index.flat.d == 0 {
         return Err(io::Error::new(
@@ -923,76 +650,6 @@ fn validate_index_shape(index: &IVFHNSWFlatIndex) -> 
io::Result<()> {
     Ok(())
 }
 
-fn usize_to_i32(value: usize, field: &str) -> io::Result<i32> {
-    if value > i32::MAX as usize {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidInput,
-            format!("{} exceeds i32 length limit: {}", field, value),
-        ));
-    }
-    Ok(value as i32)
-}
-
-fn usize_to_i64(value: usize, field: &str) -> io::Result<i64> {
-    if value > i64::MAX as usize {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidInput,
-            format!("{} exceeds i64 length limit: {}", field, value),
-        ));
-    }
-    Ok(value as i64)
-}
-
-fn u64_to_i64(value: u64, field: &str) -> io::Result<i64> {
-    if value > i64::MAX as u64 {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidInput,
-            format!("{} exceeds i64 offset limit: {}", field, value),
-        ));
-    }
-    Ok(value as i64)
-}
-
-const MAX_SECTION_ELEMENTS: usize = 1 << 30;
-
-fn checked_section_size(a: usize, b: usize) -> io::Result<usize> {
-    let result = a.checked_mul(b).ok_or_else(|| {
-        io::Error::new(
-            io::ErrorKind::InvalidData,
-            "section size overflow in IVF-HNSW-FLAT header",
-        )
-    })?;
-    if result > MAX_SECTION_ELEMENTS {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidData,
-            format!(
-                "section size {} exceeds maximum {}",
-                result, MAX_SECTION_ELEMENTS
-            ),
-        ));
-    }
-    Ok(result)
-}
-
-fn checked_list_offset(offset: i64, list_id: usize) -> io::Result<u64> {
-    if offset < 0 {
-        return Err(io::Error::new(
-            io::ErrorKind::InvalidData,
-            format!("negative list offset {} at list {}", offset, list_id),
-        ));
-    }
-    Ok(offset as u64)
-}
-
-fn checked_list_bytes(count: usize, bytes_per_entry: usize) -> 
io::Result<usize> {
-    count.checked_mul(bytes_per_entry).ok_or_else(|| {
-        io::Error::new(
-            io::ErrorKind::InvalidData,
-            "IVF-HNSW-FLAT list byte size overflow",
-        )
-    })
-}
-
 fn list_payload_len(count: usize, d: usize, graph_bytes_len: usize) -> 
io::Result<usize> {
     let id_bytes = checked_list_bytes(count, 8)?;
     let vector_bytes = checked_list_bytes(
@@ -1015,19 +672,11 @@ fn list_payload_len(count: usize, d: usize, 
graph_bytes_len: usize) -> io::Resul
         })
 }
 
-fn decode_roaring_filter(bytes: &[u8]) -> io::Result<RoaringTreemap> {
-    RoaringTreemap::deserialize_from(bytes).map_err(|e| {
-        io::Error::new(
-            io::ErrorKind::InvalidInput,
-            format!("invalid RoaringTreemap filter: {}", e),
-        )
-    })
-}
-
 #[cfg(test)]
 mod tests {
     use crate::distance::MetricType;
     use crate::hnsw::HnswBuildParams;
+    use crate::index_io_util::decode_graph;
     use crate::io::{PosWriter, SeekRead};
     use crate::ivfhnswflat::IVFHNSWFlatIndex;
     use crate::ivfhnswflat_io::{
@@ -1353,8 +1002,8 @@ mod tests {
         append_u32(&mut graph_bytes, 0);
         append_u32(&mut graph_bytes, params.max_level as u32 + 1);
 
-        let err = super::decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, 
MetricType::L2, params)
-            .unwrap_err();
+        let err =
+            decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, MetricType::L2, 
params).unwrap_err();
 
         assert_eq!(err.kind(), io::ErrorKind::InvalidData);
         assert!(err.to_string().contains("level"));
@@ -1374,8 +1023,8 @@ mod tests {
         append_u32(&mut graph_bytes, 0);
         append_u32(&mut graph_bytes, (params.m * 2) as u32 + 1);
 
-        let err = super::decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, 
MetricType::L2, params)
-            .unwrap_err();
+        let err =
+            decode_graph(&graph_bytes, vec![0.0, 0.0], 1, 2, MetricType::L2, 
params).unwrap_err();
 
         assert_eq!(err.kind(), io::ErrorKind::InvalidData);
         assert!(err.to_string().contains("degree"));
diff --git a/core/src/ivfhnswsq.rs b/core/src/ivfhnswsq.rs
new file mode 100644
index 0000000..a945388
--- /dev/null
+++ b/core/src/ivfhnswsq.rs
@@ -0,0 +1,289 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{preprocess_vectors, MetricType};
+use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::hnsw_search::{search_hnsw_lists, HnswSearchList};
+use crate::ivfpq::RowIdFilter;
+use crate::kmeans::{self, KMeansConfig};
+use crate::sq::ScalarQuantizer;
+use crate::topk::TopKHeap;
+use std::io;
+
+/// In-memory IVF index whose inverted lists hold scalar-quantized (SQ) codes
+/// and, once `build_graphs` has run, one HNSW graph per non-empty list.
+pub struct IVFHNSWSQIndex {
+    /// Vector dimension.
+    pub d: usize,
+    /// Number of inverted lists (coarse k-means centroids).
+    pub nlist: usize,
+    /// Metric used for list assignment, graph construction and search.
+    pub metric: MetricType,
+    /// Coarse quantizer centroids, flattened (`nlist * d` floats after `train`).
+    pub quantizer_centroids: Vec<f32>,
+    /// Scalar quantizer shared by every list.
+    pub sq: ScalarQuantizer,
+    /// Row ids per list; entry `i` of `ids[l]` is the `i`-th vector of list `l`.
+    pub ids: Vec<Vec<i64>>,
+    /// SQ codes per list, `sq.code_size()` bytes per entry, ordered like `ids`.
+    pub codes: Vec<Vec<u8>>,
+    /// Per-list HNSW graph; `None` for empty lists, before `build_graphs`, and
+    /// after any `add` (which clears every graph).
+    pub graphs: Vec<Option<HnswGraph>>,
+    /// Parameters used when building the per-list graphs.
+    pub hnsw_params: HnswBuildParams,
+}
+
+impl IVFHNSWSQIndex {
+    /// Creates an untrained index with `nlist` empty inverted lists and no graphs.
+    pub fn new(d: usize, nlist: usize, metric: MetricType, hnsw_params: HnswBuildParams) -> Self {
+        let ids = (0..nlist).map(|_| Vec::new()).collect();
+        let codes = (0..nlist).map(|_| Vec::new()).collect();
+        let graphs = (0..nlist).map(|_| None).collect();
+        Self {
+            d,
+            nlist,
+            metric,
+            quantizer_centroids: Vec::new(),
+            sq: ScalarQuantizer::new(d),
+            ids,
+            codes,
+            graphs,
+            hnsw_params,
+        }
+    }
+
+    /// Trains the coarse k-means quantizer and the scalar quantizer on `n`
+    /// vectors from `data`, after metric-specific preprocessing.
+    pub fn train(&mut self, data: &[f32], n: usize) {
+        let prepared = self.preprocess_vectors(data, n);
+        let config = KMeansConfig::default();
+        self.quantizer_centroids = kmeans::kmeans_train(&config, &prepared, n, self.d, self.nlist);
+        self.sq.train(&prepared, n);
+    }
+
+    /// SQ-encodes `n` vectors and appends each, with its id, to the list of
+    /// its nearest centroid.
+    ///
+    /// Every graph is cleared afterwards; call `build_graphs` again before
+    /// relying on graph search.
+    pub fn add(&mut self, data: &[f32], ids: &[i64], n: usize) {
+        let prepared = self.preprocess_vectors(data, n);
+        let code_size = self.sq.code_size();
+        let mut encoded = vec![0u8; n * code_size];
+        self.sq.encode_batch(&prepared, n, &mut encoded);
+
+        for row in 0..n {
+            let vector_range = row * self.d..(row + 1) * self.d;
+            let code_range = row * code_size..(row + 1) * code_size;
+            let target = kmeans::find_nearest(
+                &prepared[vector_range],
+                &self.quantizer_centroids,
+                self.nlist,
+                self.d,
+            );
+            self.ids[target].push(ids[row]);
+            self.codes[target].extend_from_slice(&encoded[code_range]);
+        }
+        // Existing graphs no longer cover all list entries.
+        self.graphs.iter_mut().for_each(|graph| *graph = None);
+    }
+
+    /// Number of vectors stored across all inverted lists.
+    pub fn total_vectors(&self) -> usize {
+        self.ids.iter().fold(0, |acc, list| acc + list.len())
+    }
+
+    /// (Re)builds one HNSW graph per non-empty list from its SQ-decoded
+    /// vectors. Empty lists get `None`.
+    ///
+    /// # Errors
+    /// Propagates the first `HnswGraph::build` failure; lists after it are
+    /// left unchanged.
+    pub fn build_graphs(&mut self) -> io::Result<()> {
+        for list_id in 0..self.nlist {
+            let count = self.ids[list_id].len();
+            if count == 0 {
+                self.graphs[list_id] = None;
+                continue;
+            }
+            let mut decoded = vec![0.0f32; count * self.d];
+            self.sq
+                .decode_batch(&self.codes[list_id], count, &mut decoded);
+            let graph = HnswGraph::build(&decoded, count, self.d, self.metric, self.hnsw_params)?;
+            self.graphs[list_id] = Some(graph);
+        }
+        Ok(())
+    }
+
+    /// Unfiltered search; see `search_with_filter` for the output layout.
+    #[allow(clippy::too_many_arguments)]
+    pub fn search(
+        &self,
+        queries: &[f32],
+        nq: usize,
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+        result_distances: &mut [f32],
+        result_labels: &mut [i64],
+    ) {
+        let no_filter: Option<&dyn RowIdFilter> = None;
+        self.search_with_filter(
+            queries,
+            nq,
+            k,
+            nprobe,
+            ef_search,
+            no_filter,
+            result_distances,
+            result_labels,
+        );
+    }
+
+    /// Searches `nq` queries, probing `nprobe` lists each, and writes `k`
+    /// results per query into row `qi` of `result_distances` / `result_labels`.
+    ///
+    /// Missing results are padded with `f32::MAX` / `-1`. Lists without a
+    /// built graph are handed to `search_hnsw_lists` with `graph: None`
+    /// together with an SQ-scan callback.
+    #[allow(clippy::too_many_arguments)]
+    pub fn search_with_filter(
+        &self,
+        queries: &[f32],
+        nq: usize,
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+        filter: Option<&dyn RowIdFilter>,
+        result_distances: &mut [f32],
+        result_labels: &mut [i64],
+    ) {
+        let prepared_queries = self.preprocess_vectors(queries, nq);
+        let (probes_per_query, _) = kmeans::find_topk_batch(
+            &prepared_queries,
+            nq,
+            &self.quantizer_centroids,
+            self.nlist,
+            self.d,
+            nprobe,
+        );
+
+        let d = self.d;
+        for qi in 0..nq {
+            let query = &prepared_queries[qi * d..(qi + 1) * d];
+            let mut candidate_lists = Vec::with_capacity(probes_per_query[qi].len());
+            for &list_id in probes_per_query[qi].iter() {
+                candidate_lists.push(HnswSearchList {
+                    ids: self.ids[list_id].as_slice(),
+                    graph: self.graphs[list_id].as_ref(),
+                    payload: list_id,
+                });
+            }
+            let ranked = search_hnsw_lists(
+                query,
+                &candidate_lists,
+                k,
+                ef_search,
+                filter,
+                |list, heap| {
+                    self.scan_sq_list(query, list.payload, filter, heap);
+                },
+            );
+
+            let base = qi * k;
+            for (slot, &(dist, label)) in ranked.iter().enumerate() {
+                result_distances[base + slot] = dist;
+                result_labels[base + slot] = label;
+            }
+            for slot in ranked.len()..k {
+                result_distances[base + slot] = f32::MAX;
+                result_labels[base + slot] = -1;
+            }
+        }
+    }
+
+    /// Applies the metric-specific preprocessing used for training, adding
+    /// and querying.
+    pub(crate) fn preprocess_vectors(&self, data: &[f32], n: usize) -> Vec<f32> {
+        preprocess_vectors(data, n, self.d, self.metric)
+    }
+
+    /// Exhaustively scores every (filter-passing) entry of `list_id` against
+    /// `query` using SQ code distances and pushes them into `heap`.
+    fn scan_sq_list(
+        &self,
+        query: &[f32],
+        list_id: usize,
+        filter: Option<&dyn RowIdFilter>,
+        heap: &mut TopKHeap,
+    ) {
+        let ctx = self.sq.distance_context(query, self.metric);
+        let code_size = self.sq.code_size();
+        for (local_id, &row_id) in self.ids[list_id].iter().enumerate() {
+            let allowed = match filter {
+                Some(f) => f.contains(row_id),
+                None => true,
+            };
+            if !allowed {
+                continue;
+            }
+            let start = local_id * code_size;
+            let code = &self.codes[list_id][start..start + code_size];
+            let dist = self.sq.distance_to_code_with_context(query, code, ctx);
+            heap.push(dist, row_id);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::hnsw::HnswBuildParams;
+    use std::collections::HashSet;
+
+    /// Single-list 2-d index over three points: two at/near the origin
+    /// (ids 10, 11) and one far away (id 12).
+    fn tiny_index(with_graphs: bool) -> IVFHNSWSQIndex {
+        let data = vec![0.0, 0.0, 0.1, 0.0, 10.0, 10.0];
+        let ids = vec![10, 11, 12];
+        let mut index = IVFHNSWSQIndex::new(2, 1, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, 3);
+        index.add(&data, &ids, 3);
+        if with_graphs {
+            index.build_graphs().unwrap();
+        }
+        index
+    }
+
+    #[test]
+    fn test_ivfhnswsq_recalls_query_vector() {
+        let (d, nlist, n) = (4, 4, 128);
+        let mut data: Vec<f32> = Vec::with_capacity(n * d);
+        for i in 0..n {
+            let x = i as f32;
+            let cluster = (i % nlist) as f32 * 100.0;
+            data.extend_from_slice(&[cluster + x * 2.0, 10.0 + x, 20.0 + x, 30.0 + x]);
+        }
+        let ids: Vec<i64> = (0..n as i64).map(|i| 1000 + i).collect();
+
+        let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+        index.build_graphs().unwrap();
+
+        // Querying with a stored vector must return that vector first.
+        let query_id = 23;
+        let query = &data[query_id * d..(query_id + 1) * d];
+        let mut distances = [0.0f32; 5];
+        let mut labels = [0i64; 5];
+        index.search(query, 1, 5, nlist, 32, &mut distances, &mut labels);
+
+        assert_eq!(labels[0], ids[query_id]);
+        assert!(distances[0].is_finite());
+    }
+
+    #[test]
+    fn test_ivfhnswsq_without_built_graphs_falls_back_to_sq_scan() {
+        let index = tiny_index(false);
+
+        let mut distances = [0.0f32; 2];
+        let mut labels = [0i64; 2];
+        index.search(&[0.0, 0.0], 1, 2, 1, 8, &mut distances, &mut labels);
+
+        assert_eq!(labels[0], 10);
+    }
+
+    #[test]
+    fn test_ivfhnswsq_search_with_filter() {
+        let index = tiny_index(true);
+        let filter: HashSet<i64> = HashSet::from([12]);
+
+        let mut distances = [0.0f32; 2];
+        let mut labels = [0i64; 2];
+        index.search_with_filter(
+            &[0.0, 0.0],
+            1,
+            2,
+            1,
+            8,
+            Some(&filter),
+            &mut distances,
+            &mut labels,
+        );
+
+        // Only id 12 passes the filter; the second slot stays padded.
+        assert_eq!(labels[0], 12);
+        assert_eq!(labels[1], -1);
+    }
+}
diff --git a/core/src/ivfhnswsq_io.rs b/core/src/ivfhnswsq_io.rs
new file mode 100644
index 0000000..cb06eef
--- /dev/null
+++ b/core/src/ivfhnswsq_io.rs
@@ -0,0 +1,851 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{preprocess_vectors, MetricType};
+use crate::hnsw::{HnswBuildParams, HnswGraph};
+use crate::hnsw_search::{search_hnsw_lists, HnswSearchList};
+use crate::index_io_util::{
+    checked_list_bytes, checked_list_offset, checked_section_size, 
decode_graph,
+    decode_roaring_filter, encode_graph, read_f32_vec, read_i32_le, 
read_i64_le, read_u32_le,
+    u64_to_i64, usize_to_i32, usize_to_i64, validate_positive_i32, 
validate_search_inputs,
+    write_f32_slice, write_i32_le, write_i64_le, write_u32_le,
+};
+use crate::io::{SeekRead, SeekWrite};
+use crate::ivfhnswsq::IVFHNSWSQIndex;
+use crate::ivfpq::RowIdFilter;
+use crate::kmeans;
+use crate::sq::ScalarQuantizer;
+use crate::topk::TopKHeap;
+use std::io;
+
+/// Magic number at offset 0. Its hex digits spell "IHSQ" in ASCII; because it
+/// is written with `write_u32_le`, the on-disk bytes read "QSHI".
+pub const IVF_HNSW_SQ_MAGIC: u32 = 0x49485351; // "IHSQ"
+/// The only format version `IVFHNSWSQIndexReader::open` accepts.
+pub const IVF_HNSW_SQ_VERSION: u32 = 1;
+/// Size in bytes of the fixed header; the centroid section starts here.
+pub const IVF_HNSW_SQ_HEADER_SIZE: usize = 64;
+
+/// Serializes `index` to `out` in the IVF-HNSW-SQ v1 layout:
+///
+/// 1. a 64-byte header: magic, version, `d`, `nlist`, metric code, total
+///    vector count (i64), HNSW `m` / `ef_construction` / `max_level`, SQ
+///    `min` / `max` (f32), and 16 reserved zero bytes;
+/// 2. the coarse quantizer centroids as f32 values;
+/// 3. an offset table with one 24-byte entry per list (i64 offset, i32
+///    count, i32 graph byte length, i64 reserved zero);
+/// 4. for each non-empty list, in list order: its i64 row ids, its SQ code
+///    bytes, then its encoded HNSW graph.
+///
+/// List offsets are computed up front as `out.pos()` plus the table size, so
+/// they are positions in the stream `out` writes to.
+/// NOTE(review): this assumes `pos()` is the absolute stream position the
+/// reader later passes to `pread`.
+///
+/// # Errors
+/// Fails when `validate_index_shape` rejects the index, or when a count,
+/// length or offset does not fit its on-disk integer width. Graph encoding
+/// and write errors are propagated.
+pub fn write_ivfhnswsq_index(index: &IVFHNSWSQIndex, out: &mut dyn SeekWrite) -> io::Result<()> {
+    validate_index_shape(index)?;
+    // Sum of all list lengths, checked against i64 overflow for the header.
+    let total_vectors = index.ids.iter().try_fold(0i64, |sum, ids| {
+        let count = usize_to_i64(ids.len(), "total vector count")?;
+        sum.checked_add(count).ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "total vector count exceeds i64 length limit",
+            )
+        })
+    })?;
+    // Graphs are encoded before anything is written so their lengths can go
+    // into the offset table. Empty lists get no graph bytes.
+    // NOTE(review): for a non-empty list whose graph was never built,
+    // `encode_graph(None)` decides the outcome; the reader rejects non-empty
+    // lists with zero graph bytes.
+    let graph_bytes: Vec<Vec<u8>> = (0..index.nlist)
+        .map(|list_id| {
+            if index.ids[list_id].is_empty() {
+                Ok(Vec::new())
+            } else {
+                encode_graph(index.graphs[list_id].as_ref())
+            }
+        })
+        .collect::<io::Result<_>>()?;
+
+    // --- 64-byte header ---
+    write_u32_le(out, IVF_HNSW_SQ_MAGIC)?;
+    write_u32_le(out, IVF_HNSW_SQ_VERSION)?;
+    write_i32_le(out, usize_to_i32(index.d, "dimension")?)?;
+    write_i32_le(out, usize_to_i32(index.nlist, "nlist")?)?;
+    write_u32_le(out, index.metric as u32)?;
+    write_i64_le(out, total_vectors)?;
+    // The sanitized params are persisted; the reader sanitizes them again.
+    let params = index.hnsw_params.sanitized();
+    write_i32_le(out, usize_to_i32(params.m, "hnsw m")?)?;
+    write_i32_le(
+        out,
+        usize_to_i32(params.ef_construction, "hnsw ef_construction")?,
+    )?;
+    write_i32_le(out, usize_to_i32(params.max_level, "hnsw max_level")?)?;
+    out.write_all(&index.sq.min.to_le_bytes())?;
+    out.write_all(&index.sq.max.to_le_bytes())?;
+    // Reserved padding bringing the header to IVF_HNSW_SQ_HEADER_SIZE bytes.
+    out.write_all(&[0u8; 16])?;
+
+    // --- coarse centroids ---
+    write_f32_slice(out, &index.quantizer_centroids)?;
+
+    // --- offset table (24 bytes per list) ---
+    let offset_table_size = index.nlist.checked_mul(24).ok_or_else(|| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "IVF-HNSW-SQ offset table size overflow",
+        )
+    })?;
+    // List payloads begin right after the offset table.
+    let data_start = out
+        .pos()
+        .checked_add(offset_table_size as u64)
+        .ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "IVF-HNSW-SQ data start offset overflow",
+            )
+        })?;
+    let mut list_offsets = vec![0i64; index.nlist];
+    let mut list_counts = vec![0i32; index.nlist];
+    let mut list_graph_bytes_lens = vec![0i32; index.nlist];
+    let mut current_offset = data_start;
+
+    // Empty lists record the current offset but take up no payload bytes.
+    for list_id in 0..index.nlist {
+        list_offsets[list_id] = u64_to_i64(current_offset, "list offset")?;
+        let count = index.ids[list_id].len();
+        list_counts[list_id] = usize_to_i32(count, "list count")?;
+        list_graph_bytes_lens[list_id] = usize_to_i32(graph_bytes[list_id].len(), "graph bytes")?;
+        if count > 0 {
+            let payload_len =
+                list_payload_len(count, index.sq.code_size(), graph_bytes[list_id].len())?;
+            current_offset = current_offset
+                .checked_add(payload_len as u64)
+                .ok_or_else(|| {
+                    io::Error::new(io::ErrorKind::InvalidInput, "IVF-HNSW-SQ offset overflow")
+                })?;
+        }
+    }
+
+    for list_id in 0..index.nlist {
+        write_i64_le(out, list_offsets[list_id])?;
+        write_i32_le(out, list_counts[list_id])?;
+        write_i32_le(out, list_graph_bytes_lens[list_id])?;
+        // Reserved table slot.
+        write_i64_le(out, 0)?;
+    }
+
+    // --- list payloads: ids, SQ codes, graph ---
+    // Must be written in the same order and sizes used for the offsets above.
+    for list_id in 0..index.nlist {
+        if index.ids[list_id].is_empty() {
+            continue;
+        }
+        for &id in &index.ids[list_id] {
+            write_i64_le(out, id)?;
+        }
+        out.write_all(&index.codes[list_id])?;
+        out.write_all(&graph_bytes[list_id])?;
+    }
+
+    Ok(())
+}
+
+/// Reader for files produced by `write_ivfhnswsq_index`.
+///
+/// `open` parses only the fixed header. Centroids and the list offset table
+/// are read on first use (`ensure_loaded`), and inverted lists are read from
+/// `reader` on demand for each search.
+pub struct IVFHNSWSQIndexReader<R: SeekRead> {
+    reader: R,
+    /// Vector dimension.
+    pub d: usize,
+    /// Number of inverted lists.
+    pub nlist: usize,
+    pub metric: MetricType,
+    /// Total vector count as recorded in the header.
+    pub total_vectors: i64,
+    /// HNSW parameters from the header, after `sanitized()`.
+    pub hnsw_params: HnswBuildParams,
+    /// Scalar quantizer rebuilt from the stored `min` / `max` bounds.
+    pub sq: ScalarQuantizer,
+    /// Coarse centroids (`nlist * d` floats); empty until `ensure_loaded`.
+    pub quantizer_centroids: Vec<f32>,
+    /// Per-list payload offsets; empty until `ensure_loaded`.
+    pub list_offsets: Vec<i64>,
+    /// Per-list vector counts (validated non-negative); empty until `ensure_loaded`.
+    pub list_counts: Vec<i32>,
+    /// Per-list encoded graph sizes in bytes; empty until `ensure_loaded`.
+    pub list_graph_bytes_lens: Vec<i32>,
+    /// Set once centroids and the offset table have been read successfully.
+    loaded: bool,
+}
+
+impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
+    /// Reads and validates the 64-byte header at offset 0.
+    ///
+    /// Centroids and the offset table are not read here (see `ensure_loaded`).
+    ///
+    /// # Errors
+    /// Returns `InvalidData` for a wrong magic, an unsupported version, an
+    /// unknown metric code or non-finite SQ bounds. Non-positive `d` / `nlist`
+    /// / HNSW parameters are rejected by `validate_positive_i32`. I/O errors
+    /// from `reader` are propagated.
+    pub fn open(mut reader: R) -> io::Result<Self> {
+        reader.seek(0)?;
+
+        // Fields are read in exactly the order `write_ivfhnswsq_index` writes them.
+        let magic = read_u32_le(&mut reader)?;
+        if magic != IVF_HNSW_SQ_MAGIC {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("Invalid IVF_HNSW_SQ magic: 0x{:08X}", magic),
+            ));
+        }
+        let version = read_u32_le(&mut reader)?;
+        if version != IVF_HNSW_SQ_VERSION {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("Unsupported IVF_HNSW_SQ version: {}", version),
+            ));
+        }
+
+        let d = validate_positive_i32(read_i32_le(&mut reader)?, "d")? as usize;
+        let nlist = validate_positive_i32(read_i32_le(&mut reader)?, "nlist")? as usize;
+        let metric_code = read_u32_le(&mut reader)?;
+        let metric = MetricType::from_code(metric_code).ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("Unknown metric type: {}", metric_code),
+            )
+        })?;
+        // Stored as-is; not cross-checked against the per-list counts.
+        let total_vectors = read_i64_le(&mut reader)?;
+        let hnsw_params = HnswBuildParams {
+            m: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw m")? as usize,
+            ef_construction: validate_positive_i32(
+                read_i32_le(&mut reader)?,
+                "hnsw ef_construction",
+            )? as usize,
+            max_level: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw max_level")? as usize,
+        }
+        .sanitized();
+        // SQ bounds: two little-endian f32 values (min, then max).
+        let mut min_bytes = [0u8; 4];
+        let mut max_bytes = [0u8; 4];
+        reader.read_exact(&mut min_bytes)?;
+        reader.read_exact(&mut max_bytes)?;
+        let sq_min = f32::from_le_bytes(min_bytes);
+        let sq_max = f32::from_le_bytes(max_bytes);
+        if !sq_min.is_finite() || !sq_max.is_finite() {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "SQ bounds must be finite",
+            ));
+        }
+        // Consume the 16 reserved bytes that end the header; contents are ignored.
+        let mut reserved = [0u8; 16];
+        reader.read_exact(&mut reserved)?;
+
+        Ok(Self {
+            reader,
+            d,
+            nlist,
+            metric,
+            total_vectors,
+            hnsw_params,
+            sq: ScalarQuantizer::with_bounds(d, sq_min, sq_max),
+            quantizer_centroids: Vec::new(),
+            list_offsets: Vec::new(),
+            list_counts: Vec::new(),
+            list_graph_bytes_lens: Vec::new(),
+            loaded: false,
+        })
+    }
+
+    /// Reads the centroids (`nlist * d` f32 values right after the header)
+    /// and the offset table on the first call; later calls return immediately.
+    ///
+    /// Each table entry is 24 bytes: i64 offset, i32 count, i32 graph byte
+    /// length, i64 reserved. `loaded` is set only after every entry parses,
+    /// so a failed load is retried on the next call.
+    ///
+    /// # Errors
+    /// `InvalidData` for a negative count or graph length, errors from
+    /// `checked_section_size`, and propagated I/O errors.
+    pub fn ensure_loaded(&mut self) -> io::Result<()> {
+        if self.loaded {
+            return Ok(());
+        }
+
+        self.reader.seek(IVF_HNSW_SQ_HEADER_SIZE as u64)?;
+        self.quantizer_centroids =
+            read_f32_vec(&mut self.reader, checked_section_size(self.nlist, self.d)?)?;
+        self.list_offsets = vec![0; self.nlist];
+        self.list_counts = vec![0; self.nlist];
+        self.list_graph_bytes_lens = vec![0; self.nlist];
+        for list_id in 0..self.nlist {
+            // Offsets are validated lazily per list (`checked_list_offset`).
+            self.list_offsets[list_id] = read_i64_le(&mut self.reader)?;
+            let count = read_i32_le(&mut self.reader)?;
+            if count < 0 {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    format!("negative list count {} at list {}", count, list_id),
+                ));
+            }
+            self.list_counts[list_id] = count;
+            let graph_bytes_len = read_i32_le(&mut self.reader)?;
+            if graph_bytes_len < 0 {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    format!(
+                        "negative graph_bytes_len {} at list {}",
+                        graph_bytes_len, list_id
+                    ),
+                ));
+            }
+            self.list_graph_bytes_lens[list_id] = graph_bytes_len;
+            let _reserved = read_i64_le(&mut self.reader)?;
+        }
+
+        self.loaded = true;
+        Ok(())
+    }
+
+    /// Returns the row ids, SQ codes and HNSW graph stored for `list_id`.
+    ///
+    /// An empty list yields empty vectors and `None`; a non-empty list always
+    /// yields `Some(graph)` (a missing graph is an error).
+    pub fn read_inverted_list(
+        &mut self,
+        list_id: usize,
+    ) -> io::Result<(Vec<i64>, Vec<u8>, Option<HnswGraph>)> {
+        let Some(list) = self.read_graph_list(list_id)? else {
+            return Ok((Vec::new(), Vec::new(), None));
+        };
+        Ok((list.ids, list.codes, Some(list.graph)))
+    }
+
+    /// Loads one list with a single positioned read at its stored offset.
+    ///
+    /// The payload holds `count` i64 ids, then `count * code_size` SQ code
+    /// bytes, then the encoded graph. Codes are decoded back to f32 vectors,
+    /// which `decode_graph` uses to rebuild the graph. Returns `Ok(None)` for
+    /// an empty list.
+    ///
+    /// # Errors
+    /// `InvalidInput` if `list_id >= nlist`; `InvalidData` for a non-empty list
+    /// with no graph, plus errors from offset/size checks, `pread` and
+    /// `decode_graph`.
+    fn read_graph_list(&mut self, list_id: usize) -> io::Result<Option<GraphList>> {
+        self.ensure_loaded()?;
+        if list_id >= self.nlist {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!("list_id {} out of range (nlist={})", list_id, self.nlist),
+            ));
+        }
+        // Non-negative: validated in `ensure_loaded`.
+        let count = self.list_counts[list_id] as usize;
+        if count == 0 {
+            return Ok(None);
+        }
+
+        let offset = checked_list_offset(self.list_offsets[list_id], list_id)?;
+        let graph_bytes_len = self.list_graph_bytes_lens[list_id] as usize;
+        if graph_bytes_len == 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("list {} is missing HNSW graph", list_id),
+            ));
+        }
+        let payload_len = list_payload_len(count, self.sq.code_size(), graph_bytes_len)?;
+        let mut payload = vec![0u8; payload_len];
+        self.reader.pread(offset, &mut payload)?;
+
+        // Split the payload into its ids / codes / graph sections.
+        let ids_bytes_len = checked_list_bytes(count, 8)?;
+        let code_size = self.sq.code_size();
+        let codes_bytes_len = checked_list_bytes(count, code_size)?;
+        let ids = payload[..ids_bytes_len]
+            .chunks_exact(8)
+            .map(|c| i64::from_le_bytes(c.try_into().unwrap()))
+            .collect();
+        let codes = payload[ids_bytes_len..ids_bytes_len + codes_bytes_len].to_vec();
+        let mut vectors = vec![0.0f32; count * self.d];
+        self.sq.decode_batch(&codes, count, &mut vectors);
+        let graph = decode_graph(
+            &payload[ids_bytes_len + codes_bytes_len..],
+            vectors,
+            count,
+            self.d,
+            self.metric,
+            self.hnsw_params,
+        )?;
+        let graph = graph.ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("list {} is missing HNSW graph", list_id),
+            )
+        })?;
+        Ok(Some(GraphList { ids, codes, graph }))
+    }
+
+    /// Unfiltered single-query search; see `search_with_filter`.
+    pub fn search(
+        &mut self,
+        query: &[f32],
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+    ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        self.search_with_filter(query, k, nprobe, ef_search, None)
+    }
+
+    /// Searches one query against its `nprobe` nearest lists and returns
+    /// `(labels, distances)`, each padded to length `k` with `-1` / `f32::MAX`.
+    ///
+    /// Probed lists are read from storage on every call. The closure given to
+    /// `search_hnsw_lists` does an exhaustive SQ-code scan of one list; when
+    /// it is used instead of the graph is decided by `search_hnsw_lists`.
+    pub fn search_with_filter(
+        &mut self,
+        query: &[f32],
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+        filter: Option<&dyn RowIdFilter>,
+    ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        self.ensure_loaded()?;
+        validate_search_inputs(query, 1, self.d, k, nprobe)?;
+
+        let q = preprocess_vectors(query, 1, self.d, self.metric);
+        let (probe_indices, _) =
+            kmeans::find_topk(&q, &self.quantizer_centroids, self.nlist, self.d, nprobe);
+        // Empty lists (`None`) are skipped.
+        let mut loaded_lists = Vec::with_capacity(probe_indices.len());
+        for list_id in probe_indices {
+            if let Some(list) = self.read_graph_list(list_id)? {
+                loaded_lists.push(list);
+            }
+        }
+        let search_lists: Vec<_> = loaded_lists
+            .iter()
+            .map(|list| HnswSearchList {
+                ids: list.ids.as_slice(),
+                graph: Some(&list.graph),
+                payload: list,
+            })
+            .collect();
+        let sorted = search_hnsw_lists(&q, &search_lists, k, ef_search, filter, |list, heap| {
+            let list = list.payload;
+            scan_sq_list(
+                &q,
+                &list.ids,
+                &list.codes,
+                &self.sq,
+                self.metric,
+                filter,
+                heap,
+            );
+        });
+        let mut labels: Vec<i64> = sorted.iter().map(|&(_, id)| id).collect();
+        let mut distances: Vec<f32> = sorted.iter().map(|&(dist, _)| dist).collect();
+        labels.resize(k, -1);
+        distances.resize(k, f32::MAX);
+        Ok((labels, distances))
+    }
+
+    /// Like `search_with_filter`, with the filter supplied as serialized
+    /// roaring-bitmap bytes (decoded by `decode_roaring_filter`).
+    pub fn search_with_roaring_filter(
+        &mut self,
+        query: &[f32],
+        k: usize,
+        nprobe: usize,
+        ef_search: usize,
+        roaring_filter_bytes: &[u8],
+    ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        let filter = decode_roaring_filter(roaring_filter_bytes)?;
+        self.search_with_filter(query, k, nprobe, ef_search, Some(&filter))
+    }
+}
+
+/// Unfiltered batch search over `nq` row-major queries; see
+/// `search_batch_ivfhnswsq_reader_filter` for the result layout.
+pub fn search_batch_ivfhnswsq_reader<R: SeekRead>(
+    reader: &mut IVFHNSWSQIndexReader<R>,
+    queries: &[f32],
+    nq: usize,
+    k: usize,
+    nprobe: usize,
+    ef_search: usize,
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+    let no_filter: Option<&dyn RowIdFilter> = None;
+    search_batch_ivfhnswsq_reader_filter(reader, queries, nq, k, nprobe, ef_search, no_filter)
+}
+
+/// Batch search: runs `nq` queries (row-major in `queries`, `d` floats each)
+/// and returns `nq * k` labels and distances, padded with `-1` / `f32::MAX`.
+///
+/// Each probed inverted list is read from storage once and shared by every
+/// query that probes it.
+///
+/// With a filter, a query whose probed lists contain at most
+/// `max(ef_search, k)` matching rows is answered by an exhaustive SQ scan.
+/// Otherwise the per-list HNSW graphs are searched. If the filtered graph
+/// results hold fewer than `k` rows, that query's partial result is discarded
+/// and all of its probed lists are rescanned exhaustively.
+///
+/// # Errors
+/// Fails if the reader cannot load its metadata or a probed list, if
+/// `validate_search_inputs` rejects the inputs, or if a per-query filtered
+/// row count overflows `usize`.
+pub fn search_batch_ivfhnswsq_reader_filter<R: SeekRead>(
+    reader: &mut IVFHNSWSQIndexReader<R>,
+    queries: &[f32],
+    nq: usize,
+    k: usize,
+    nprobe: usize,
+    ef_search: usize,
+    filter: Option<&dyn RowIdFilter>,
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+    reader.ensure_loaded()?;
+    validate_search_inputs(queries, nq, reader.d, k, nprobe)?;
+
+    let processed = preprocess_vectors(queries, nq, reader.d, reader.metric);
+    let (all_probe_indices, _) = kmeans::find_topk_batch(
+        &processed,
+        nq,
+        &reader.quantizer_centroids,
+        reader.nlist,
+        reader.d,
+        nprobe,
+    );
+
+    // Invert the probe assignment so each list is loaded once for all of the
+    // queries that probe it.
+    let mut list_to_queries = vec![Vec::new(); reader.nlist];
+    let mut unique_lists = Vec::new();
+    for (qi, probe_indices) in all_probe_indices.iter().enumerate() {
+        for &list_id in probe_indices {
+            if list_to_queries[list_id].is_empty() {
+                unique_lists.push(list_id);
+            }
+            list_to_queries[list_id].push(qi);
+        }
+    }
+
+    // At or below this many filtered candidates, an exhaustive scan is used
+    // instead of a filtered graph walk.
+    let scan_threshold = ef_search.max(k);
+    let mut heaps: Vec<TopKHeap> = (0..nq).map(|_| TopKHeap::new(k)).collect();
+    let mut query_filtered_counts = vec![0usize; nq];
+    let mut loaded_lists = Vec::with_capacity(unique_lists.len());
+    for list_id in unique_lists {
+        let Some(list) = reader.read_graph_list(list_id)? else {
+            continue;
+        };
+        if let Some(f) = filter {
+            let filtered = list.ids.iter().filter(|&&id| f.contains(id)).count();
+            for &qi in &list_to_queries[list_id] {
+                query_filtered_counts[qi] = query_filtered_counts[qi]
+                    .checked_add(filtered)
+                    .ok_or_else(|| {
+                        io::Error::new(
+                            io::ErrorKind::InvalidData,
+                            "filtered vector count overflows usize",
+                        )
+                    })?;
+            }
+        }
+        loaded_lists.push(LoadedBatchList {
+            query_ids: std::mem::take(&mut list_to_queries[list_id]),
+            ids: list.ids,
+            codes: list.codes,
+            graph: list.graph,
+        });
+    }
+
+    for list in &loaded_lists {
+        for &qi in &list.query_ids {
+            let query = &processed[qi * reader.d..(qi + 1) * reader.d];
+            let force_sq_scan = filter.is_some() && query_filtered_counts[qi] <= scan_threshold;
+            if force_sq_scan {
+                scan_sq_list(
+                    query,
+                    &list.ids,
+                    &list.codes,
+                    &reader.sq,
+                    reader.metric,
+                    filter,
+                    &mut heaps[qi],
+                );
+            } else {
+                let local_results = list.graph.search(query, scan_threshold, scan_threshold);
+                for (local_id, dist) in local_results {
+                    let row_id = list.ids[local_id];
+                    if filter.map(|f| f.contains(row_id)).unwrap_or(true) {
+                        heaps[qi].push(dist, row_id);
+                    }
+                }
+            }
+        }
+    }
+
+    if filter.is_some() {
+        // A filtered graph walk can return fewer than `k` matches even though
+        // more than `scan_threshold` (>= k) matching rows exist. Such queries
+        // are redone as exhaustive SQ scans over all of their lists.
+        // The heap is reset first: scanning on top of the graph hits would push
+        // the same row ids twice. Every list is scanned, because stopping once
+        // the heap is full would ignore the remaining lists.
+        let needs_rescan: Vec<bool> = (0..nq)
+            .map(|qi| query_filtered_counts[qi] > scan_threshold && heaps[qi].len() < k)
+            .collect();
+        for (qi, &rescan) in needs_rescan.iter().enumerate() {
+            if rescan {
+                heaps[qi] = TopKHeap::new(k);
+            }
+        }
+        for list in &loaded_lists {
+            for &qi in list.query_ids.iter().filter(|&&qi| needs_rescan[qi]) {
+                let query = &processed[qi * reader.d..(qi + 1) * reader.d];
+                scan_sq_list(
+                    query,
+                    &list.ids,
+                    &list.codes,
+                    &reader.sq,
+                    reader.metric,
+                    filter,
+                    &mut heaps[qi],
+                );
+            }
+        }
+    }
+
+    let mut result_ids = vec![-1i64; nq * k];
+    let mut result_dists = vec![f32::MAX; nq * k];
+    for qi in 0..nq {
+        let sorted = std::mem::replace(&mut heaps[qi], TopKHeap::new(0)).into_sorted();
+        let base = qi * k;
+        for (i, &(dist, id)) in sorted.iter().enumerate() {
+            result_ids[base + i] = id;
+            result_dists[base + i] = dist;
+        }
+    }
+
+    Ok((result_ids, result_dists))
+}
+
+/// Batched IVF-HNSW-SQ search restricted to the row ids contained in a
+/// serialized Roaring bitmap.
+///
+/// `roaring_filter_bytes` is decoded with `decode_roaring_filter`, and the
+/// resulting bitmap is handed unchanged to
+/// `search_batch_ivfhnswsq_reader_filter` as the row-id filter.
+///
+/// # Errors
+/// Propagates any error from decoding the bitmap or from the filtered search.
+pub fn search_batch_ivfhnswsq_reader_roaring_filter<R: SeekRead>(
+    reader: &mut IVFHNSWSQIndexReader<R>,
+    queries: &[f32],
+    nq: usize,
+    k: usize,
+    nprobe: usize,
+    ef_search: usize,
+    roaring_filter_bytes: &[u8],
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+    let bitmap = decode_roaring_filter(roaring_filter_bytes)?;
+    search_batch_ivfhnswsq_reader_filter(
+        reader,
+        queries,
+        nq,
+        k,
+        nprobe,
+        ef_search,
+        Some(&bitmap),
+    )
+}
+
+/// One inverted list loaded from the index file.
+///
+/// Entry `i` of `ids`, bytes `i * code_size..(i + 1) * code_size` of `codes`,
+/// and local node `i` of `graph` all describe the same vector (local id `i`).
+struct GraphList {
+    /// HNSW graph over this list's vectors; its search results are local ids.
+    graph: HnswGraph,
+    /// Row id of each vector, indexed by local id.
+    ids: Vec<i64>,
+    /// Scalar-quantized codes, `code_size` bytes per vector, in `ids` order.
+    codes: Vec<u8>,
+}
+
+/// An inverted list probed by at least one query of a batch, held in memory
+/// together with the indices of the batch queries that probe it.
+struct LoadedBatchList {
+    /// Row id of each vector, indexed by local id.
+    ids: Vec<i64>,
+    /// Scalar-quantized codes, `code_size` bytes per vector, in `ids` order.
+    codes: Vec<u8>,
+    /// HNSW graph over this list's vectors; its search results are local ids.
+    graph: HnswGraph,
+    /// Positions (within the query batch) of the queries that probe this list.
+    query_ids: Vec<usize>,
+}
+
+/// Exhaustively scores one inverted list against `query` using its SQ codes.
+///
+/// For every row admitted by `filter` (every row when `filter` is `None`),
+/// computes the SQ distance from `query` to that row's code and pushes
+/// `(distance, row_id)` into `heap`. `codes` holds `sq.code_size()` bytes per
+/// row, laid out in the same order as `ids`.
+///
+/// # Panics
+/// If `codes` is shorter than `ids.len() * sq.code_size()`.
+fn scan_sq_list(
+    query: &[f32],
+    ids: &[i64],
+    codes: &[u8],
+    sq: &ScalarQuantizer,
+    metric: MetricType,
+    filter: Option<&dyn RowIdFilter>,
+    heap: &mut TopKHeap,
+) {
+    // Per-query state is prepared once and reused for every row of the list.
+    let context = sq.distance_context(query, metric);
+    let width = sq.code_size();
+    for (slot, &row_id) in ids.iter().enumerate() {
+        if let Some(allowed) = filter {
+            if !allowed.contains(row_id) {
+                continue;
+            }
+        }
+        let start = slot * width;
+        let row_code = &codes[start..start + width];
+        let dist = sq.distance_to_code_with_context(query, row_code, context);
+        heap.push(dist, row_id);
+    }
+}
+
+/// Checks that an in-memory `IVFHNSWSQIndex` is internally consistent.
+///
+/// Checked in this order:
+/// - `d` and `nlist` are non-zero;
+/// - the SQ dimension equals `d`;
+/// - the centroid buffer holds exactly `nlist * d` values;
+/// - `ids`, `codes` and `graphs` each hold `nlist` lists;
+/// - every list's code buffer is `count * code_size` bytes;
+/// - every non-empty list has a graph whose vectors equal its SQ-decoded
+///   codes, and empty lists have no graph.
+///
+/// # Errors
+/// Returns `io::ErrorKind::InvalidInput` describing the first violation, and
+/// propagates errors from the checked size helpers.
+fn validate_index_shape(index: &IVFHNSWSQIndex) -> io::Result<()> {
+    fn invalid<M>(msg: M) -> io::Error
+    where
+        M: Into<Box<dyn std::error::Error + Send + Sync>>,
+    {
+        io::Error::new(io::ErrorKind::InvalidInput, msg)
+    }
+
+    if index.d == 0 {
+        return Err(invalid("dimension must be greater than 0"));
+    }
+    if index.nlist == 0 {
+        return Err(invalid("nlist must be greater than 0"));
+    }
+    if index.sq.d != index.d {
+        return Err(invalid("SQ dimension does not match index dimension"));
+    }
+
+    let centroid_len = checked_section_size(index.nlist, index.d)?;
+    let actual_centroid_len = index.quantizer_centroids.len();
+    if actual_centroid_len != centroid_len {
+        return Err(invalid(format!(
+            "quantizer centroid length {} does not match nlist*d {}",
+            actual_centroid_len, centroid_len
+        )));
+    }
+
+    let list_counts = [index.ids.len(), index.codes.len(), index.graphs.len()];
+    if list_counts.iter().any(|&len| len != index.nlist) {
+        return Err(invalid("inverted list count does not match nlist"));
+    }
+
+    let code_size = index.sq.code_size();
+    let lists = index
+        .ids
+        .iter()
+        .zip(index.codes.iter())
+        .zip(index.graphs.iter())
+        .enumerate();
+    for (list_id, ((list_ids, list_codes), list_graph)) in lists {
+        let count = list_ids.len();
+        let expected_codes_len = checked_list_bytes(count, code_size)?;
+        if list_codes.len() != expected_codes_len {
+            return Err(invalid(format!(
+                "list {} SQ code length {} does not match count*d {}",
+                list_id,
+                list_codes.len(),
+                expected_codes_len
+            )));
+        }
+        match (list_graph, count) {
+            (None, 0) => {}
+            (None, _) => {
+                return Err(invalid(format!(
+                    "list {} is missing HNSW graph; call build_graphs first",
+                    list_id
+                )));
+            }
+            (Some(_), 0) => {
+                return Err(invalid(format!(
+                    "list {} has graph for an empty list",
+                    list_id
+                )));
+            }
+            (Some(graph), _) => {
+                // The graph must have been built from exactly the decoded codes.
+                let mut decoded = vec![0.0f32; count * index.d];
+                index.sq.decode_batch(list_codes, count, &mut decoded);
+                if graph.len() != count || graph.vectors() != decoded.as_slice() {
+                    return Err(invalid(format!(
+                        "list {} graph does not match SQ code storage",
+                        list_id
+                    )));
+                }
+            }
+        }
+    }
+    Ok(())
+}
+
+/// On-disk payload size of one inverted list.
+///
+/// The payload is `count` 8-byte row ids, `count * code_size` SQ code bytes
+/// and `graph_bytes_len` bytes of serialized graph.
+///
+/// # Errors
+/// Propagates errors from `checked_list_bytes`. Returns `InvalidInput` if the
+/// total overflows `usize`.
+fn list_payload_len(count: usize, code_size: usize, graph_bytes_len: usize) -> io::Result<usize> {
+    let id_bytes = checked_list_bytes(count, 8)?;
+    let code_bytes = checked_list_bytes(count, code_size)?;
+    [id_bytes, code_bytes, graph_bytes_len]
+        .iter()
+        .copied()
+        .try_fold(0usize, usize::checked_add)
+        .ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "IVF-HNSW-SQ list payload length overflow",
+            )
+        })
+}
+
+#[cfg(test)]
+mod tests {
+    // Tests for the IVF-HNSW-SQ writer and reader. Each test builds an index
+    // in memory and serializes it with `write_ivfhnswsq_index`. It then either
+    // reopens the bytes with `IVFHNSWSQIndexReader` or checks that the write
+    // was rejected.
+    use super::*;
+    use crate::hnsw::HnswBuildParams;
+    use crate::io::PosWriter;
+    use roaring::RoaringTreemap;
+    use std::io::Cursor;
+
+    /// L2 round trip over 128 vectors whose first component is offset by
+    /// `(i % nlist) * 100`. Querying with a stored vector must return that
+    /// vector's id first, with a finite distance.
+    #[test]
+    fn test_ivfhnswsq_write_read_search_roundtrip() {
+        let d = 4;
+        let nlist = 4;
+        let n = 128;
+        let data: Vec<f32> = (0..n)
+            .flat_map(|i| {
+                let cluster = (i % nlist) as f32 * 100.0;
+                [
+                    cluster + i as f32 * 2.0,
+                    10.0 + i as f32,
+                    20.0 + i as f32,
+                    30.0 + i as f32,
+                ]
+            })
+            .collect();
+        // Non-trivial ids so the test can tell row ids apart from local ids.
+        let ids: Vec<i64> = (10_000..10_000 + n as i64).collect();
+
+        let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_ivfhnswsq_index(&index, &mut writer).unwrap();
+
+        let mut reader = IVFHNSWSQIndexReader::open(Cursor::new(buf)).unwrap();
+        let query_id = 23;
+        let (labels, distances) = reader
+            .search(&data[query_id * d..(query_id + 1) * d], 5, nlist, 32)
+            .unwrap();
+
+        assert_eq!(labels[0], ids[query_id]);
+        assert!(distances[0].is_finite());
+    }
+
+    /// The bitmap admits only row 12. A k = 2 search must therefore return
+    /// 12 followed by the -1 padding for the unfilled slot.
+    #[test]
+    fn test_ivfhnswsq_reader_search_with_roaring_filter() {
+        let d = 2;
+        let nlist = 1;
+        let data = vec![0.0, 0.0, 0.1, 0.0, 10.0, 10.0];
+        let ids = vec![10, 11, 12];
+        let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, 3);
+        index.add(&data, &ids, 3);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut buf)).unwrap();
+        let mut reader = IVFHNSWSQIndexReader::open(Cursor::new(buf)).unwrap();
+
+        let mut filter = RoaringTreemap::new();
+        filter.insert(12);
+        let mut filter_bytes = Vec::new();
+        filter.serialize_into(&mut filter_bytes).unwrap();
+
+        let (labels, _) = reader
+            .search_with_roaring_filter(&[0.0, 0.0], 2, nlist, 8, &filter_bytes)
+            .unwrap();
+
+        assert_eq!(labels, vec![12, -1]);
+    }
+
+    /// Cosine round trip: the query [1, 0, 0] must rank the identical stored
+    /// vector (id 10) first.
+    #[test]
+    fn test_ivfhnswsq_write_read_search_roundtrip_cosine() {
+        let d = 3;
+        let nlist = 2;
+        let data = vec![1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 1.0, 0.0, 0.0, 0.9, 0.1];
+        let ids = vec![10, 11, 12, 13];
+        let mut index =
+            IVFHNSWSQIndex::new(d, nlist, MetricType::Cosine, HnswBuildParams::default());
+        index.train(&data, 4);
+        index.add(&data, &ids, 4);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut buf)).unwrap();
+
+        let mut reader = IVFHNSWSQIndexReader::open(Cursor::new(buf)).unwrap();
+        let (labels, distances) = reader.search(&[1.0, 0.0, 0.0], 2, nlist, 16).unwrap();
+
+        assert_eq!(labels[0], 10);
+        assert!(distances[0].is_finite());
+    }
+
+    /// Batch search for two stored vectors must reproduce exactly the labels
+    /// and distances that per-query `search` returns on a fresh reader.
+    #[test]
+    fn test_ivfhnswsq_batch_matches_single_search() {
+        let d = 4;
+        let nlist = 4;
+        let n = 128;
+        let data: Vec<f32> = (0..n)
+            .flat_map(|i| {
+                let cluster = (i % nlist) as f32 * 100.0;
+                [cluster + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect();
+        let ids: Vec<i64> = (0..n as i64).collect();
+
+        let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut buf)).unwrap();
+        // Two queries, taken from stored rows 7 and 23, laid out row-major.
+        let queries = [&data[7 * d..8 * d], &data[23 * d..24 * d]].concat();
+
+        let mut batch_reader = IVFHNSWSQIndexReader::open(Cursor::new(buf.clone())).unwrap();
+        let (batch_labels, batch_distances) =
+            search_batch_ivfhnswsq_reader(&mut batch_reader, &queries, 2, 3, nlist, 32).unwrap();
+
+        for qi in 0..2 {
+            // A fresh reader per query keeps the single-query path independent.
+            let mut single_reader = IVFHNSWSQIndexReader::open(Cursor::new(buf.clone())).unwrap();
+            let (single_labels, single_distances) = single_reader
+                .search(&queries[qi * d..(qi + 1) * d], 3, nlist, 32)
+                .unwrap();
+            assert_eq!(
+                &batch_labels[qi * 3..(qi + 1) * 3],
+                single_labels.as_slice()
+            );
+            assert_eq!(
+                &batch_distances[qi * 3..(qi + 1) * 3],
+                single_distances.as_slice()
+            );
+        }
+    }
+
+    /// A vector is added but `build_graphs` is never called. The writer must
+    /// refuse with an `InvalidInput` error that mentions `build_graphs`.
+    #[test]
+    fn test_ivfhnswsq_write_requires_graphs() {
+        let mut index = IVFHNSWSQIndex::new(2, 1, MetricType::L2, HnswBuildParams::default());
+        let data = vec![0.0, 0.0];
+        index.train(&data, 1);
+        index.add(&data, &[1], 1);
+
+        let mut buf = Vec::new();
+        let err = write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut buf)).unwrap_err();
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+        assert!(err.to_string().contains("build_graphs"));
+    }
+
+    /// After building, list 0's graph is replaced by one built over unrelated
+    /// vectors. The writer must detect that the graph no longer matches the
+    /// stored SQ codes.
+    #[test]
+    fn test_ivfhnswsq_writer_rejects_stale_graph() {
+        let mut index = IVFHNSWSQIndex::new(2, 1, MetricType::L2, HnswBuildParams::default());
+        let data = vec![0.0, 0.0, 1.0, 0.0];
+        index.train(&data, 2);
+        index.add(&data, &[10, 11], 2);
+        index.build_graphs().unwrap();
+        index.graphs[0] = Some(
+            HnswGraph::build(
+                &[10.0, 10.0, 11.0, 11.0],
+                2,
+                2,
+                MetricType::L2,
+                HnswBuildParams::default(),
+            )
+            .unwrap(),
+        );
+
+        let mut buf = Vec::new();
+        let err = write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut buf)).unwrap_err();
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+        assert!(err.to_string().contains("graph does not match"));
+    }
+}
diff --git a/core/src/lib.rs b/core/src/lib.rs
index 8f2f033..fd2b62d 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -22,14 +22,19 @@ pub mod blas;
 pub mod distance;
 pub mod fastscan;
 pub mod hnsw;
+pub(crate) mod hnsw_search;
+pub(crate) mod index_io_util;
 pub mod io;
 pub mod ivfflat;
 pub mod ivfflat_io;
 pub mod ivfhnswflat;
 pub mod ivfhnswflat_io;
+pub mod ivfhnswsq;
+pub mod ivfhnswsq_io;
 pub mod ivfpq;
 pub mod kmeans;
 pub mod opq;
 pub mod pq;
 pub mod shuffler;
+pub mod sq;
 pub mod topk;
diff --git a/core/src/sq.rs b/core/src/sq.rs
new file mode 100644
index 0000000..f1eb5da
--- /dev/null
+++ b/core/src/sq.rs
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{fvec_norm_l2sqr, MetricType};
+
+/// Uniform 8-bit scalar quantizer with a single `[min, max]` range shared by
+/// every dimension.
+///
+/// Each component is stored as one byte, so a `d`-dimensional vector encodes
+/// to `d` bytes. Code `c` decodes to `min + c * ((max - min) / 255)`.
+///
+/// The range is degenerate when `min >= max`, for example after training on
+/// constant data. In that case every component encodes to 0 and decodes to
+/// `min`.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct ScalarQuantizer {
+    /// Vector dimension; also the encoded size in bytes.
+    pub d: usize,
+    /// Value represented by code 0.
+    pub min: f32,
+    /// Value represented by code 255.
+    pub max: f32,
+}
+
+impl ScalarQuantizer {
+    /// Creates an untrained quantizer for `d`-dimensional vectors.
+    ///
+    /// The range starts as the degenerate `[0, 0]`, so every value encodes to
+    /// 0 until [`ScalarQuantizer::train`] sets real bounds.
+    pub fn new(d: usize) -> Self {
+        Self {
+            d,
+            min: 0.0,
+            max: 0.0,
+        }
+    }
+
+    /// Creates a quantizer with explicit bounds, skipping training.
+    pub fn with_bounds(d: usize, min: f32, max: f32) -> Self {
+        Self { d, min, max }
+    }
+
+    /// Sets `[min, max]` to the smallest and largest values among the first
+    /// `n` row-major vectors of `data`.
+    ///
+    /// When `n * d == 0` the range is reset to `[0, 0]`. NaN values are skipped
+    /// by the `f32::min` / `f32::max` reduction.
+    ///
+    /// # Panics
+    /// If `data` holds fewer than `n * d` values.
+    pub fn train(&mut self, data: &[f32], n: usize) {
+        let values = &data[..n * self.d];
+        if values.is_empty() {
+            self.min = 0.0;
+            self.max = 0.0;
+            return;
+        }
+
+        let mut min = f32::INFINITY;
+        let mut max = f32::NEG_INFINITY;
+        for &value in values {
+            min = min.min(value);
+            max = max.max(value);
+        }
+        self.min = min;
+        self.max = max;
+    }
+
+    /// Number of bytes per encoded vector (one byte per dimension).
+    pub fn code_size(&self) -> usize {
+        self.d
+    }
+
+    /// Encodes `n` row-major vectors from `data` into the first `n * d` bytes
+    /// of `codes`.
+    ///
+    /// Each component is mapped to the nearest of the 256 levels. Values
+    /// outside `[min, max]` saturate to 0 or 255, and NaN encodes to 0.
+    ///
+    /// # Panics
+    /// If `data` or `codes` is shorter than `n * d`.
+    pub fn encode_batch(&self, data: &[f32], n: usize, codes: &mut [u8]) {
+        let len = n * self.d;
+        assert!(data.len() >= len);
+        assert!(codes.len() >= len);
+
+        if self.min >= self.max {
+            codes[..len].fill(0);
+            return;
+        }
+
+        let scale = 255.0 / (self.max - self.min);
+        for (code, &value) in codes[..len].iter_mut().zip(&data[..len]) {
+            // Round rather than truncate. Truncation biases every
+            // reconstruction downward by up to a whole step. It also sends
+            // `max` itself to 254 whenever `(max - min) * scale` lands just
+            // below 255 in f32.
+            *code = ((value - self.min) * scale).round().clamp(0.0, 255.0) as u8;
+        }
+    }
+
+    /// Encodes a single vector; see [`ScalarQuantizer::encode_batch`].
+    pub fn encode(&self, vector: &[f32], code: &mut [u8]) {
+        self.encode_batch(vector, 1, code);
+    }
+
+    /// Decodes `n` vectors from `codes` into the first `n * d` floats of
+    /// `vectors` (row-major).
+    ///
+    /// # Panics
+    /// If `codes` or `vectors` is shorter than `n * d`.
+    pub fn decode_batch(&self, codes: &[u8], n: usize, vectors: &mut [f32]) {
+        let len = n * self.d;
+        assert!(codes.len() >= len);
+        assert!(vectors.len() >= len);
+
+        match self.level_step() {
+            None => vectors[..len].fill(self.min),
+            Some(step) => {
+                for (value, &code) in vectors[..len].iter_mut().zip(&codes[..len]) {
+                    *value = self.decode_with_step(code, Some(step));
+                }
+            }
+        }
+    }
+
+    /// Decodes a single vector; see [`ScalarQuantizer::decode_batch`].
+    pub fn decode(&self, code: &[u8], vector: &mut [f32]) {
+        self.decode_batch(code, 1, vector);
+    }
+
+    /// Distance between `query` and the vector stored in `code` under `metric`.
+    ///
+    /// This convenience wrapper builds a fresh [`DistanceContext`] on every
+    /// call. When scoring many codes against the same query, prefer
+    /// [`ScalarQuantizer::distance_context`] together with
+    /// [`ScalarQuantizer::distance_to_code_with_context`].
+    pub fn distance_to_code(&self, query: &[f32], code: &[u8], metric: MetricType) -> f32 {
+        self.distance_to_code_with_context(query, code, self.distance_context(query, metric))
+    }
+
+    /// Precomputes per-query state for
+    /// [`ScalarQuantizer::distance_to_code_with_context`].
+    ///
+    /// Only the first `d` components of `query` are used.
+    pub fn distance_context(&self, query: &[f32], metric: MetricType) -> DistanceContext {
+        debug_assert!(query.len() >= self.d);
+        DistanceContext::new(&query[..self.d], metric)
+    }
+
+    /// Distance between the first `d` components of `query` and the decoded
+    /// `code`, using the metric stored in `context`. Smaller means closer:
+    /// - `L2`: squared Euclidean distance.
+    /// - `InnerProduct`: the negated dot product.
+    /// - `Cosine`: `1 - cos`, or `1.0` when either norm is zero.
+    ///
+    /// # Panics
+    /// If `query` or `code` is shorter than `d`.
+    pub fn distance_to_code_with_context(
+        &self,
+        query: &[f32],
+        code: &[u8],
+        context: DistanceContext,
+    ) -> f32 {
+        debug_assert!(query.len() >= self.d);
+        debug_assert!(code.len() >= self.d);
+
+        // The level step is computed once per call rather than once per
+        // component. Decoding goes through the same helper as `decode_batch`,
+        // so distances are evaluated on exactly the values it produces.
+        let step = self.level_step();
+        let pairs = query[..self.d]
+            .iter()
+            .zip(&code[..self.d])
+            .map(|(&q, &c)| (q, self.decode_with_step(c, step)));
+
+        match context.metric {
+            MetricType::L2 => pairs.fold(0.0f32, |sum, (q, v)| {
+                let diff = q - v;
+                sum + diff * diff
+            }),
+            MetricType::InnerProduct => -pairs.fold(0.0f32, |dot, (q, v)| dot + q * v),
+            MetricType::Cosine => {
+                let (dot, vector_norm_sq) = pairs
+                    .fold((0.0f32, 0.0f32), |(dot, norm), (q, v)| {
+                        (dot + q * v, norm + v * v)
+                    });
+                let denom = context.query_norm * vector_norm_sq.sqrt();
+                if denom > 0.0 {
+                    1.0 - dot / denom
+                } else {
+                    1.0
+                }
+            }
+        }
+    }
+
+    /// Width of one quantization level. Returns `None` for a degenerate range
+    /// (`min >= max`), where every code decodes to `min`.
+    fn level_step(&self) -> Option<f32> {
+        if self.min >= self.max {
+            None
+        } else {
+            Some((self.max - self.min) / 255.0)
+        }
+    }
+
+    /// Reconstructs one component from its code, given a step obtained from
+    /// `level_step`. This is the single decoding formula in the type, so batch
+    /// decoding and distance computation agree bit-for-bit.
+    fn decode_with_step(&self, code: u8, step: Option<f32>) -> f32 {
+        match step {
+            Some(step) => self.min + code as f32 * step,
+            None => self.min,
+        }
+    }
+}
+
+/// Per-query state reused across many
+/// `ScalarQuantizer::distance_to_code_with_context` calls.
+///
+/// Only the cosine metric needs precomputation (the query's L2 norm). For
+/// the other metrics `query_norm` stays at `0.0` and is not read.
+#[derive(Debug, Clone, Copy)]
+pub struct DistanceContext {
+    metric: MetricType,
+    query_norm: f32,
+}
+
+impl DistanceContext {
+    /// Builds the context for `query` under `metric`.
+    pub fn new(query: &[f32], metric: MetricType) -> Self {
+        let query_norm = match metric {
+            MetricType::Cosine => fvec_norm_l2sqr(query).sqrt(),
+            MetricType::L2 | MetricType::InnerProduct => 0.0,
+        };
+        Self { metric, query_norm }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Training picks up the global bounds. The extremes encode to the first
+    /// and last level and decode back to the bounds.
+    #[test]
+    fn test_scalar_quantizer_round_trips_bounds() {
+        let input = [-1.0f32, 0.0, 1.0, 3.0];
+        let mut quantizer = ScalarQuantizer::new(2);
+        quantizer.train(&input, 2);
+
+        let mut encoded = [0u8; 4];
+        quantizer.encode_batch(&input, 2, &mut encoded);
+        let mut restored = [0.0f32; 4];
+        quantizer.decode_batch(&encoded, 2, &mut restored);
+
+        assert_eq!((quantizer.min, quantizer.max), (-1.0, 3.0));
+        assert_eq!(encoded[0], 0);
+        assert_eq!(encoded[3], 255);
+        assert!((restored[0] + 1.0).abs() < 1e-6);
+        assert!((restored[3] - 3.0).abs() < 1e-6);
+    }
+
+    /// A constant training set collapses the range. Every value then encodes
+    /// to 0, overwriting the prefilled bytes, and decodes back to the constant.
+    #[test]
+    fn test_scalar_quantizer_constant_input() {
+        let input = [5.0f32; 4];
+        let mut quantizer = ScalarQuantizer::new(2);
+        quantizer.train(&input, 2);
+
+        let mut encoded = [7u8; 4];
+        quantizer.encode_batch(&input, 2, &mut encoded);
+        let mut restored = [0.0f32; 4];
+        quantizer.decode_batch(&encoded, 2, &mut restored);
+
+        assert_eq!(encoded, [0u8; 4]);
+        assert_eq!(restored, input);
+    }
+
+    /// A point that lies exactly on representable levels has (near) zero L2
+    /// distance to its own code.
+    #[test]
+    fn test_scalar_quantizer_distance_to_code() {
+        let quantizer = ScalarQuantizer::with_bounds(2, 0.0, 1.0);
+        let point = [1.0f32, 0.0];
+        let mut code = [0u8; 2];
+        quantizer.encode(&point, &mut code);
+
+        let dist = quantizer.distance_to_code(&point, &code, MetricType::L2);
+
+        assert!(dist < 1e-6);
+    }
+}

Reply via email to