This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git


The following commit(s) were added to refs/heads/main by this push:
     new 1de37a0  Improve IVF HNSW SQ quantization (#22)
1de37a0 is described below

commit 1de37a0c30e17d5d96a2654cbe51fbb6236a1246
Author: Jingsong Lee <[email protected]>
AuthorDate: Tue Jun 9 16:51:16 2026 +0800

    Improve IVF HNSW SQ quantization (#22)
---
 core/benches/recall_bench.rs |  33 ++++++++
 core/src/ivfhnswsq.rs        |  94 ++++++++++++++++++----
 core/src/ivfhnswsq_io.rs     | 182 ++++++++++++++++++++++++++++++++++++-------
 core/src/sq.rs               | 182 ++++++++++++++++++++++++++++++++++---------
 4 files changed, 414 insertions(+), 77 deletions(-)

diff --git a/core/benches/recall_bench.rs b/core/benches/recall_bench.rs
index 5103446..aef434f 100644
--- a/core/benches/recall_bench.rs
+++ b/core/benches/recall_bench.rs
@@ -1,8 +1,12 @@
 use paimon_vindex_core::distance::{fvec_distance, MetricType};
 use paimon_vindex_core::hnsw::HnswBuildParams;
+use paimon_vindex_core::io::{write_index, PosWriter};
 use paimon_vindex_core::ivfflat::IVFFlatIndex;
+use paimon_vindex_core::ivfflat_io::write_ivfflat_index;
 use paimon_vindex_core::ivfhnswflat::IVFHNSWFlatIndex;
+use paimon_vindex_core::ivfhnswflat_io::write_ivfhnswflat_index;
 use paimon_vindex_core::ivfhnswsq::IVFHNSWSQIndex;
+use paimon_vindex_core::ivfhnswsq_io::write_ivfhnswsq_index;
 use paimon_vindex_core::ivfpq::IVFPQIndex;
 use std::collections::HashSet;
 use std::time::Instant;
@@ -115,6 +119,7 @@ fn run_scenario(s: Scenario<'_>) {
     ivfhnswsq.add(&data, &ids, s.n);
     ivfhnswsq.build_graphs().unwrap();
     println!("build IVF-HNSW-SQ: {:.2}s", start.elapsed().as_secs_f64());
+    print_sizes(&ivfpq, &ivfflat, &ivfhnswflat, &ivfhnswsq);
 
     println!();
     println!(
@@ -200,6 +205,34 @@ fn run_scenario(s: Scenario<'_>) {
     }
 }
 
+fn print_sizes(
+    ivfpq: &IVFPQIndex,
+    ivfflat: &IVFFlatIndex,
+    ivfhnswflat: &IVFHNSWFlatIndex,
+    ivfhnswsq: &IVFHNSWSQIndex,
+) {
+    let mut pq = Vec::new();
+    write_index(ivfpq, &mut PosWriter::new(&mut pq)).unwrap();
+    let mut flat = Vec::new();
+    write_ivfflat_index(ivfflat, &mut PosWriter::new(&mut flat)).unwrap();
+    let mut hnswflat = Vec::new();
+    write_ivfhnswflat_index(ivfhnswflat, &mut PosWriter::new(&mut hnswflat)).unwrap();
+    let mut hnswsq = Vec::new();
+    write_ivfhnswsq_index(ivfhnswsq, &mut PosWriter::new(&mut hnswsq)).unwrap();
+
+    println!(
+        "serialized sizes: IVF-PQ={:.2} MiB, IVF-FLAT={:.2} MiB, IVF-HNSW-FLAT={:.2} MiB, IVF-HNSW-SQ={:.2} MiB",
+        bytes_to_mib(pq.len()),
+        bytes_to_mib(flat.len()),
+        bytes_to_mib(hnswflat.len()),
+        bytes_to_mib(hnswsq.len())
+    );
+}
+
+fn bytes_to_mib(bytes: usize) -> f64 {
+    bytes as f64 / 1024.0 / 1024.0
+}
+
 fn print_row(
     index: &str,
     nprobe: usize,
diff --git a/core/src/ivfhnswsq.rs b/core/src/ivfhnswsq.rs
index a945388..4ce2a6e 100644
--- a/core/src/ivfhnswsq.rs
+++ b/core/src/ivfhnswsq.rs
@@ -30,6 +30,7 @@ pub struct IVFHNSWSQIndex {
     pub metric: MetricType,
     pub quantizer_centroids: Vec<f32>,
     pub sq: ScalarQuantizer,
+    pub list_sqs: Vec<ScalarQuantizer>,
     pub ids: Vec<Vec<i64>>,
     pub codes: Vec<Vec<u8>>,
     pub graphs: Vec<Option<HnswGraph>>,
@@ -44,6 +45,7 @@ impl IVFHNSWSQIndex {
             metric,
             quantizer_centroids: Vec::new(),
             sq: ScalarQuantizer::new(d),
+            list_sqs: vec![ScalarQuantizer::new(d); nlist],
             ids: vec![Vec::new(); nlist],
             codes: vec![Vec::new(); nlist],
             graphs: vec![None; nlist],
@@ -55,21 +57,22 @@ impl IVFHNSWSQIndex {
         let processed = self.preprocess_vectors(data, n);
         self.quantizer_centroids =
             kmeans::kmeans_train(&KMeansConfig::default(), &processed, n, self.d, self.nlist);
-        self.sq.train(&processed, n);
+        let (list_ids, residuals) = self.assign_residuals(&processed, n);
+        self.sq.train(&residuals, n);
+        self.train_list_sqs(&list_ids, &residuals);
     }
 
     pub fn add(&mut self, data: &[f32], ids: &[i64], n: usize) {
         let processed = self.preprocess_vectors(data, n);
-        let code_size = self.sq.code_size();
-        let mut encoded = vec![0u8; n * code_size];
-        self.sq.encode_batch(&processed, n, &mut encoded);
+        let code_size = self.code_size();
+        let (list_ids, residuals) = self.assign_residuals(&processed, n);
 
+        let mut code = vec![0u8; code_size];
         for i in 0..n {
-            let vector = &processed[i * self.d..(i + 1) * self.d];
-            let list_id =
-                kmeans::find_nearest(vector, &self.quantizer_centroids, self.nlist, self.d);
+            let list_id = list_ids[i];
+            self.list_sqs[list_id].encode(&residuals[i * self.d..(i + 1) * self.d], &mut code);
             self.ids[list_id].push(ids[i]);
-            self.codes[list_id].extend_from_slice(&encoded[i * code_size..(i + 1) * code_size]);
+            self.codes[list_id].extend_from_slice(&code);
         }
         self.graphs.fill(None);
     }
@@ -84,9 +87,7 @@ impl IVFHNSWSQIndex {
             self.graphs[list_id] = if count == 0 {
                 None
             } else {
-                let mut vectors = vec![0.0f32; count * self.d];
-                self.sq
-                    .decode_batch(&self.codes[list_id], count, &mut vectors);
+                let vectors = self.decode_list_vectors(list_id, count);
                 Some(HnswGraph::build(
                     &vectors,
                     count,
@@ -181,18 +182,85 @@ impl IVFHNSWSQIndex {
         heap: &mut TopKHeap,
     ) {
         let context = self.sq.distance_context(query, self.metric);
-        let code_size = self.sq.code_size();
+        let sq = self.list_sq(list_id);
+        let code_size = self.code_size();
+        let centroid = self.list_centroid(list_id);
         for (local_id, &row_id) in self.ids[list_id].iter().enumerate() {
             if filter.map(|f| !f.contains(row_id)).unwrap_or(false) {
                 continue;
             }
             let code = &self.codes[list_id][local_id * code_size..(local_id + 1) * code_size];
             heap.push(
-                self.sq.distance_to_code_with_context(query, code, context),
+                sq.distance_to_code_with_offset_with_context(query, code, centroid, context),
                 row_id,
             );
         }
     }
+
+    pub(crate) fn decode_list_vectors(&self, list_id: usize, count: usize) -> Vec<f32> {
+        let mut vectors = vec![0.0f32; count * self.d];
+        self.list_sq(list_id)
+            .decode_batch(&self.codes[list_id], count, &mut vectors);
+        let centroid = self.list_centroid(list_id);
+        for vector in vectors.chunks_exact_mut(self.d) {
+            for i in 0..self.d {
+                vector[i] += centroid[i];
+            }
+        }
+        vectors
+    }
+
+    fn assign_residuals(&self, processed: &[f32], n: usize) -> (Vec<usize>, Vec<f32>) {
+        let mut list_ids = Vec::with_capacity(n);
+        let mut residuals = vec![0.0f32; n * self.d];
+        for i in 0..n {
+            let vector = &processed[i * self.d..(i + 1) * self.d];
+            let list_id =
+                kmeans::find_nearest(vector, &self.quantizer_centroids, self.nlist, self.d);
+            list_ids.push(list_id);
+            self.write_residual(
+                vector,
+                list_id,
+                &mut residuals[i * self.d..(i + 1) * self.d],
+            );
+        }
+        (list_ids, residuals)
+    }
+
+    fn train_list_sqs(&mut self, list_ids: &[usize], residuals: &[f32]) {
+        let mut list_residuals = vec![Vec::new(); self.nlist];
+        for (i, &list_id) in list_ids.iter().enumerate() {
+            let residual = &residuals[i * self.d..(i + 1) * self.d];
+            list_residuals[list_id].extend_from_slice(residual);
+        }
+        self.list_sqs = vec![self.sq.clone(); self.nlist];
+        for (list_id, values) in list_residuals.iter().enumerate() {
+            if !values.is_empty() {
+                let mut sq = ScalarQuantizer::new(self.d);
+                sq.train(values, values.len() / self.d);
+                self.list_sqs[list_id] = sq;
+            }
+        }
+    }
+
+    fn write_residual(&self, vector: &[f32], list_id: usize, out: &mut [f32]) {
+        let centroid = self.list_centroid(list_id);
+        for i in 0..self.d {
+            out[i] = vector[i] - centroid[i];
+        }
+    }
+
+    fn list_centroid(&self, list_id: usize) -> &[f32] {
+        &self.quantizer_centroids[list_id * self.d..(list_id + 1) * self.d]
+    }
+
+    pub(crate) fn list_sq(&self, list_id: usize) -> &ScalarQuantizer {
+        self.list_sqs.get(list_id).unwrap_or(&self.sq)
+    }
+
+    pub(crate) fn code_size(&self) -> usize {
+        self.sq.code_size()
+    }
 }
 
 #[cfg(test)]
diff --git a/core/src/ivfhnswsq_io.rs b/core/src/ivfhnswsq_io.rs
index cb06eef..c86a478 100644
--- a/core/src/ivfhnswsq_io.rs
+++ b/core/src/ivfhnswsq_io.rs
@@ -70,10 +70,17 @@ pub fn write_ivfhnswsq_index(index: &IVFHNSWSQIndex, out: &mut dyn SeekWrite) ->
         usize_to_i32(params.ef_construction, "hnsw ef_construction")?,
     )?;
     write_i32_le(out, usize_to_i32(params.max_level, "hnsw max_level")?)?;
-    out.write_all(&index.sq.min.to_le_bytes())?;
-    out.write_all(&index.sq.max.to_le_bytes())?;
+    let (sq_min, sq_max) = sq_global_bounds(&index.sq.mins, &index.sq.maxs);
+    out.write_all(&sq_min.to_le_bytes())?;
+    out.write_all(&sq_max.to_le_bytes())?;
     out.write_all(&[0u8; 16])?;
 
+    write_f32_slice(out, &index.sq.mins)?;
+    write_f32_slice(out, &index.sq.maxs)?;
+    for sq in &index.list_sqs {
+        write_f32_slice(out, &sq.mins)?;
+        write_f32_slice(out, &sq.maxs)?;
+    }
     write_f32_slice(out, &index.quantizer_centroids)?;
 
     let offset_table_size = index.nlist.checked_mul(24).ok_or_else(|| {
@@ -141,6 +148,7 @@ pub struct IVFHNSWSQIndexReader<R: SeekRead> {
     pub total_vectors: i64,
     pub hnsw_params: HnswBuildParams,
     pub sq: ScalarQuantizer,
+    pub list_sqs: Vec<ScalarQuantizer>,
     pub quantizer_centroids: Vec<f32>,
     pub list_offsets: Vec<i64>,
     pub list_counts: Vec<i32>,
@@ -186,21 +194,23 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
             max_level: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw max_level")? as usize,
         }
         .sanitized();
-        let mut min_bytes = [0u8; 4];
-        let mut max_bytes = [0u8; 4];
-        reader.read_exact(&mut min_bytes)?;
-        reader.read_exact(&mut max_bytes)?;
-        let sq_min = f32::from_le_bytes(min_bytes);
-        let sq_max = f32::from_le_bytes(max_bytes);
-        if !sq_min.is_finite() || !sq_max.is_finite() {
-            return Err(io::Error::new(
-                io::ErrorKind::InvalidData,
-                "SQ bounds must be finite",
-            ));
-        }
+        let mut bounds_summary = [0u8; 8];
+        reader.read_exact(&mut bounds_summary)?;
         let mut reserved = [0u8; 16];
         reader.read_exact(&mut reserved)?;
 
+        let mins = read_f32_vec(&mut reader, d)?;
+        let maxs = read_f32_vec(&mut reader, d)?;
+        validate_sq_bounds(d, &mins, &maxs)?;
+        let sq = ScalarQuantizer::with_dimension_bounds(d, mins, maxs);
+        let mut list_sqs = Vec::with_capacity(nlist);
+        for _ in 0..nlist {
+            let mins = read_f32_vec(&mut reader, d)?;
+            let maxs = read_f32_vec(&mut reader, d)?;
+            validate_sq_bounds(d, &mins, &maxs)?;
+            list_sqs.push(ScalarQuantizer::with_dimension_bounds(d, mins, maxs));
+        }
+
         Ok(Self {
             reader,
             d,
@@ -208,7 +218,8 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
             metric,
             total_vectors,
             hnsw_params,
-            sq: ScalarQuantizer::with_bounds(d, sq_min, sq_max),
+            sq,
+            list_sqs,
             quantizer_centroids: Vec::new(),
             list_offsets: Vec::new(),
             list_counts: Vec::new(),
@@ -222,7 +233,9 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
             return Ok(());
         }
 
-        self.reader.seek(IVF_HNSW_SQ_HEADER_SIZE as u64)?;
+        let quantizer_centroids_offset =
+            IVF_HNSW_SQ_HEADER_SIZE as u64 + (self.d as u64) * 8 * (self.nlist as u64 + 1);
+        self.reader.seek(quantizer_centroids_offset)?;
         self.quantizer_centroids =
             read_f32_vec(&mut self.reader, checked_section_size(self.nlist, self.d)?)?;
         self.list_offsets = vec![0; self.nlist];
@@ -300,7 +313,14 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
             .collect();
         let codes = payload[ids_bytes_len..ids_bytes_len + codes_bytes_len].to_vec();
         let mut vectors = vec![0.0f32; count * self.d];
-        self.sq.decode_batch(&codes, count, &mut vectors);
+        self.list_sq(list_id)
+            .decode_batch(&codes, count, &mut vectors);
+        let centroid = self.list_centroid(list_id).to_vec();
+        for vector in vectors.chunks_exact_mut(self.d) {
+            for i in 0..self.d {
+                vector[i] += centroid[i];
+            }
+        }
         let graph = decode_graph(
             &payload[ids_bytes_len + codes_bytes_len..],
             vectors,
@@ -315,7 +335,21 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
                 format!("list {} is missing HNSW graph", list_id),
             )
         })?;
-        Ok(Some(GraphList { ids, codes, graph }))
+        Ok(Some(GraphList {
+            ids,
+            codes,
+            graph,
+            centroid: Some(centroid),
+            sq: self.list_sq(list_id).clone(),
+        }))
+    }
+
+    fn list_centroid(&self, list_id: usize) -> &[f32] {
+        &self.quantizer_centroids[list_id * self.d..(list_id + 1) * self.d]
+    }
+
+    fn list_sq(&self, list_id: usize) -> &ScalarQuantizer {
+        self.list_sqs.get(list_id).unwrap_or(&self.sq)
     }
 
     pub fn search(
@@ -362,7 +396,8 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
                 &q,
                 &list.ids,
                 &list.codes,
-                &self.sq,
+                list.centroid.as_deref(),
+                &list.sq,
                 self.metric,
                 filter,
                 heap,
@@ -457,6 +492,8 @@ pub fn search_batch_ivfhnswsq_reader_filter<R: SeekRead>(
             ids: list.ids,
             codes: list.codes,
             graph: list.graph,
+            centroid: list.centroid,
+            sq: list.sq,
         });
     }
 
@@ -471,7 +508,8 @@ pub fn search_batch_ivfhnswsq_reader_filter<R: SeekRead>(
                     query,
                     &list.ids,
                     &list.codes,
-                    &reader.sq,
+                    list.centroid.as_deref(),
+                    &list.sq,
                     reader.metric,
                     filter,
                     &mut heaps[qi],
@@ -501,7 +539,8 @@ pub fn search_batch_ivfhnswsq_reader_filter<R: SeekRead>(
                     query,
                     &list.ids,
                     &list.codes,
-                    &reader.sq,
+                    list.centroid.as_deref(),
+                    &list.sq,
                     reader.metric,
                     filter,
                     &mut heaps[qi],
@@ -541,6 +580,8 @@ struct GraphList {
     ids: Vec<i64>,
     codes: Vec<u8>,
     graph: HnswGraph,
+    centroid: Option<Vec<f32>>,
+    sq: ScalarQuantizer,
 }
 
 struct LoadedBatchList {
@@ -548,12 +589,15 @@ struct LoadedBatchList {
     ids: Vec<i64>,
     codes: Vec<u8>,
     graph: HnswGraph,
+    centroid: Option<Vec<f32>>,
+    sq: ScalarQuantizer,
 }
 
 fn scan_sq_list(
     query: &[f32],
     ids: &[i64],
     codes: &[u8],
+    centroid: Option<&[f32]>,
     sq: &ScalarQuantizer,
     metric: MetricType,
     filter: Option<&dyn RowIdFilter>,
@@ -566,10 +610,12 @@ fn scan_sq_list(
             continue;
         }
         let code = &codes[local_id * code_size..(local_id + 1) * code_size];
-        heap.push(
-            sq.distance_to_code_with_context(query, code, context),
-            row_id,
-        );
+        let dist = if let Some(centroid) = centroid {
+            sq.distance_to_code_with_offset_with_context(query, code, centroid, context)
+        } else {
+            sq.distance_to_code_with_context(query, code, context)
+        };
+        heap.push(dist, row_id);
     }
 }
 
@@ -592,6 +638,25 @@ fn validate_index_shape(index: &IVFHNSWSQIndex) -> io::Result<()> {
             "SQ dimension does not match index dimension",
         ));
     }
+    validate_sq_bounds(index.d, &index.sq.mins, &index.sq.maxs)?;
+    if index.list_sqs.len() != index.nlist {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "SQ list bounds count does not match nlist",
+        ));
+    }
+    for (list_id, sq) in index.list_sqs.iter().enumerate() {
+        if sq.d != index.d {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!(
+                    "SQ dimension for list {} does not match index dimension",
+                    list_id
+                ),
+            ));
+        }
+        validate_sq_bounds(index.d, &sq.mins, &sq.maxs)?;
+    }
     let centroid_len = checked_section_size(index.nlist, index.d)?;
     if index.quantizer_centroids.len() != centroid_len {
         return Err(io::Error::new(
@@ -628,10 +693,7 @@ fn validate_index_shape(index: &IVFHNSWSQIndex) -> io::Result<()> {
         }
         match &index.graphs[list_id] {
             Some(graph) if count > 0 => {
-                let mut decoded = vec![0.0f32; count * index.d];
-                index
-                    .sq
-                    .decode_batch(&index.codes[list_id], count, &mut decoded);
+                let decoded = index.decode_list_vectors(list_id, count);
                 if graph.len() != count || graph.vectors() != decoded.as_slice() {
                     return Err(io::Error::new(
                         io::ErrorKind::InvalidInput,
@@ -660,6 +722,39 @@ fn validate_index_shape(index: &IVFHNSWSQIndex) -> io::Result<()> {
     Ok(())
 }
 
+fn validate_sq_bounds(d: usize, mins: &[f32], maxs: &[f32]) -> io::Result<()> {
+    if mins.len() != d || maxs.len() != d {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!(
+                "SQ bounds length mismatch: d={}, mins={}, maxs={}",
+                d,
+                mins.len(),
+                maxs.len()
+            ),
+        ));
+    }
+    for (dim, (&min, &max)) in mins.iter().zip(maxs.iter()).enumerate() {
+        if !min.is_finite() || !max.is_finite() {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!("SQ bounds at dimension {} must be finite", dim),
+            ));
+        }
+    }
+    Ok(())
+}
+
+fn sq_global_bounds(mins: &[f32], maxs: &[f32]) -> (f32, f32) {
+    let min = mins.iter().copied().fold(f32::INFINITY, f32::min);
+    let max = maxs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
+    if min.is_finite() && max.is_finite() {
+        (min, max)
+    } else {
+        (0.0, 0.0)
+    }
+}
+
 fn list_payload_len(count: usize, code_size: usize, graph_bytes_len: usize) -> io::Result<usize> {
     let id_bytes = checked_list_bytes(count, 8)?;
     let code_bytes = checked_list_bytes(count, code_size)?;
@@ -719,6 +814,35 @@ mod tests {
         assert!(distances[0].is_finite());
     }
 
+    #[test]
+    fn test_ivfhnswsq_write_read_preserves_sq_dimension_bounds() {
+        let d = 2;
+        let nlist = 1;
+        let data = vec![0.0, -100.0, 1.0, 100.0];
+        let ids = vec![10, 11];
+        let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+        index.train(&data, 2);
+        index.add(&data, &ids, 2);
+        index.build_graphs().unwrap();
+
+        let mut buf = Vec::new();
+        write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut buf)).unwrap();
+        assert_eq!(
+            u32::from_le_bytes(buf[4..8].try_into().unwrap()),
+            IVF_HNSW_SQ_VERSION
+        );
+
+        let reader = IVFHNSWSQIndexReader::open(Cursor::new(buf)).unwrap();
+
+        assert_eq!(reader.sq.mins, index.sq.mins);
+        assert_eq!(reader.sq.maxs, index.sq.maxs);
+        assert_eq!(reader.sq.min, index.sq.min);
+        assert_eq!(reader.sq.max, index.sq.max);
+        assert_eq!(reader.list_sqs.len(), nlist);
+        assert_eq!(reader.list_sqs[0].mins, index.list_sqs[0].mins);
+        assert_eq!(reader.list_sqs[0].maxs, index.list_sqs[0].maxs);
+    }
+
     #[test]
     fn test_ivfhnswsq_reader_search_with_roaring_filter() {
         let d = 2;
diff --git a/core/src/sq.rs b/core/src/sq.rs
index f1eb5da..812f94f 100644
--- a/core/src/sq.rs
+++ b/core/src/sq.rs
@@ -17,11 +17,13 @@
 
 use crate::distance::{fvec_norm_l2sqr, MetricType};
 
-#[derive(Debug, Clone, Copy, PartialEq)]
+#[derive(Debug, Clone, PartialEq)]
 pub struct ScalarQuantizer {
     pub d: usize,
     pub min: f32,
     pub max: f32,
+    pub mins: Vec<f32>,
+    pub maxs: Vec<f32>,
 }
 
 impl ScalarQuantizer {
@@ -30,29 +32,56 @@ impl ScalarQuantizer {
             d,
             min: 0.0,
             max: 0.0,
+            mins: vec![0.0; d],
+            maxs: vec![0.0; d],
         }
     }
 
     pub fn with_bounds(d: usize, min: f32, max: f32) -> Self {
-        Self { d, min, max }
+        Self {
+            d,
+            min,
+            max,
+            mins: vec![min; d],
+            maxs: vec![max; d],
+        }
+    }
+
+    pub fn with_dimension_bounds(d: usize, mins: Vec<f32>, maxs: Vec<f32>) -> Self {
+        assert_eq!(mins.len(), d);
+        assert_eq!(maxs.len(), d);
+        let mut sq = Self {
+            d,
+            min: 0.0,
+            max: 0.0,
+            mins,
+            maxs,
+        };
+        sq.refresh_global_bounds();
+        sq
     }
 
     pub fn train(&mut self, data: &[f32], n: usize) {
-        let values = &data[..n * self.d];
-        if values.is_empty() {
+        let len = n * self.d;
+        let values = &data[..len];
+        if n == 0 || self.d == 0 {
             self.min = 0.0;
             self.max = 0.0;
+            self.mins.fill(0.0);
+            self.maxs.fill(0.0);
             return;
         }
 
-        let mut min = f32::INFINITY;
-        let mut max = f32::NEG_INFINITY;
-        for &value in values {
-            min = min.min(value);
-            max = max.max(value);
+        self.ensure_bounds_len();
+        self.mins.fill(f32::INFINITY);
+        self.maxs.fill(f32::NEG_INFINITY);
+        for vector in values.chunks_exact(self.d) {
+            for i in 0..self.d {
+                self.mins[i] = self.mins[i].min(vector[i]);
+                self.maxs[i] = self.maxs[i].max(vector[i]);
+            }
         }
-        self.min = min;
-        self.max = max;
+        self.refresh_global_bounds();
     }
 
     pub fn code_size(&self) -> usize {
@@ -64,15 +93,19 @@ impl ScalarQuantizer {
         assert!(data.len() >= len);
         assert!(codes.len() >= len);
 
-        if self.min >= self.max {
-            codes[..len].fill(0);
-            return;
-        }
-
-        let scale = 255.0 / (self.max - self.min);
-        for i in 0..len {
-            let scaled = ((data[i] - self.min) * scale).clamp(0.0, 255.0);
-            codes[i] = scaled as u8;
+        for row in 0..n {
+            let base = row * self.d;
+            for dim in 0..self.d {
+                let min = self.mins[dim];
+                let max = self.maxs[dim];
+                let out = base + dim;
+                codes[out] = if min >= max {
+                    0
+                } else {
+                    let scaled = ((data[out] - min) * 255.0 / (max - min)).clamp(0.0, 255.0);
+                    scaled.round() as u8
+                };
+            }
         }
     }
 
@@ -85,14 +118,11 @@ impl ScalarQuantizer {
         assert!(codes.len() >= len);
         assert!(vectors.len() >= len);
 
-        if self.min >= self.max {
-            vectors[..len].fill(self.min);
-            return;
-        }
-
-        let scale = (self.max - self.min) / 255.0;
-        for i in 0..len {
-            vectors[i] = self.min + codes[i] as f32 * scale;
+        for row in 0..n {
+            let base = row * self.d;
+            for dim in 0..self.d {
+                vectors[base + dim] = self.decode_value(codes[base + dim], dim);
+            }
         }
     }
 
@@ -114,6 +144,31 @@ impl ScalarQuantizer {
         query: &[f32],
         code: &[u8],
         context: DistanceContext,
+    ) -> f32 {
+        self.distance_to_code_impl(query, code, &[], false, context)
+    }
+
+    pub fn distance_to_code_with_offset_with_context(
+        &self,
+        query: &[f32],
+        code: &[u8],
+        offset: &[f32],
+        context: DistanceContext,
+    ) -> f32 {
+        debug_assert!(query.len() >= self.d);
+        debug_assert!(code.len() >= self.d);
+        debug_assert!(offset.len() >= self.d);
+
+        self.distance_to_code_impl(query, code, offset, true, context)
+    }
+
+    fn distance_to_code_impl(
+        &self,
+        query: &[f32],
+        code: &[u8],
+        offset: &[f32],
+        use_offset: bool,
+        context: DistanceContext,
     ) -> f32 {
         debug_assert!(query.len() >= self.d);
         debug_assert!(code.len() >= self.d);
@@ -122,7 +177,8 @@ impl ScalarQuantizer {
             MetricType::L2 => {
                 let mut sum = 0.0f32;
                 for i in 0..self.d {
-                    let diff = query[i] - self.decode_value(code[i]);
+                    let diff =
+                        query[i] - self.decode_value_with_offset(code[i], i, offset, use_offset);
                     sum += diff * diff;
                 }
                 sum
@@ -130,7 +186,7 @@ impl ScalarQuantizer {
             MetricType::InnerProduct => {
                 let mut dot = 0.0f32;
                 for i in 0..self.d {
-                    dot += query[i] * self.decode_value(code[i]);
+                    dot += query[i] * self.decode_value_with_offset(code[i], i, offset, use_offset);
                 }
                 -dot
             }
@@ -138,7 +194,7 @@ impl ScalarQuantizer {
                 let mut dot = 0.0f32;
                 let mut vector_norm = 0.0f32;
                 for i in 0..self.d {
-                    let value = self.decode_value(code[i]);
+                    let value = self.decode_value_with_offset(code[i], i, offset, use_offset);
                     dot += query[i] * value;
                     vector_norm += value * value;
                 }
@@ -152,12 +208,48 @@ impl ScalarQuantizer {
         }
     }
 
-    fn decode_value(&self, code: u8) -> f32 {
-        if self.min >= self.max {
-            self.min
+    fn decode_value_with_offset(
+        &self,
+        code: u8,
+        dim: usize,
+        offset: &[f32],
+        use_offset: bool,
+    ) -> f32 {
+        let value = self.decode_value(code, dim);
+        if use_offset {
+            value + offset[dim]
+        } else {
+            value
+        }
+    }
+
+    fn decode_value(&self, code: u8, dim: usize) -> f32 {
+        let min = self.mins[dim];
+        let max = self.maxs[dim];
+        if min >= max {
+            min
         } else {
-            self.min + code as f32 * (self.max - self.min) / 255.0
+            min + code as f32 * (max - min) / 255.0
+        }
+    }
+
+    fn ensure_bounds_len(&mut self) {
+        if self.mins.len() != self.d {
+            self.mins.resize(self.d, 0.0);
+        }
+        if self.maxs.len() != self.d {
+            self.maxs.resize(self.d, 0.0);
+        }
+    }
+
+    fn refresh_global_bounds(&mut self) {
+        if self.d == 0 {
+            self.min = 0.0;
+            self.max = 0.0;
+            return;
         }
+        self.min = self.mins.iter().copied().fold(f32::INFINITY, f32::min);
+        self.max = self.maxs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
     }
 }
 
@@ -195,6 +287,8 @@ mod tests {
 
         assert_eq!(sq.min, -1.0);
         assert_eq!(sq.max, 3.0);
+        assert_eq!(sq.mins, vec![-1.0, 0.0]);
+        assert_eq!(sq.maxs, vec![1.0, 3.0]);
         assert_eq!(codes[0], 0);
         assert_eq!(codes[3], 255);
         assert!((decoded[0] + 1.0).abs() < 1e-6);
@@ -216,6 +310,24 @@ mod tests {
         assert_eq!(decoded, data);
     }
 
+    #[test]
+    fn test_scalar_quantizer_uses_per_dimension_bounds() {
+        let data = vec![0.0, -100.0, 1.0, 100.0];
+        let mut sq = ScalarQuantizer::new(2);
+        sq.train(&data, 2);
+
+        let mut codes = vec![0u8; data.len()];
+        sq.encode_batch(&data, 2, &mut codes);
+        let mut decoded = vec![0.0f32; data.len()];
+        sq.decode_batch(&codes, 2, &mut decoded);
+
+        assert_eq!(codes, vec![0, 0, 255, 255]);
+        assert!((decoded[0] - 0.0).abs() < 1e-6);
+        assert!((decoded[1] + 100.0).abs() < 1e-6);
+        assert!((decoded[2] - 1.0).abs() < 1e-6);
+        assert!((decoded[3] - 100.0).abs() < 1e-6);
+    }
+
     #[test]
     fn test_scalar_quantizer_distance_to_code() {
         let sq = ScalarQuantizer::with_bounds(2, 0.0, 1.0);

Reply via email to