This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new 1de37a0 Improve IVF HNSW SQ quantization (#22)
1de37a0 is described below
commit 1de37a0c30e17d5d96a2654cbe51fbb6236a1246
Author: Jingsong Lee <[email protected]>
AuthorDate: Tue Jun 9 16:51:16 2026 +0800
Improve IVF HNSW SQ quantization (#22)
---
core/benches/recall_bench.rs | 33 ++++++++
core/src/ivfhnswsq.rs | 94 ++++++++++++++++++----
core/src/ivfhnswsq_io.rs | 182 ++++++++++++++++++++++++++++++++++++-------
core/src/sq.rs | 182 ++++++++++++++++++++++++++++++++++---------
4 files changed, 414 insertions(+), 77 deletions(-)
diff --git a/core/benches/recall_bench.rs b/core/benches/recall_bench.rs
index 5103446..aef434f 100644
--- a/core/benches/recall_bench.rs
+++ b/core/benches/recall_bench.rs
@@ -1,8 +1,12 @@
use paimon_vindex_core::distance::{fvec_distance, MetricType};
use paimon_vindex_core::hnsw::HnswBuildParams;
+use paimon_vindex_core::io::{write_index, PosWriter};
use paimon_vindex_core::ivfflat::IVFFlatIndex;
+use paimon_vindex_core::ivfflat_io::write_ivfflat_index;
use paimon_vindex_core::ivfhnswflat::IVFHNSWFlatIndex;
+use paimon_vindex_core::ivfhnswflat_io::write_ivfhnswflat_index;
use paimon_vindex_core::ivfhnswsq::IVFHNSWSQIndex;
+use paimon_vindex_core::ivfhnswsq_io::write_ivfhnswsq_index;
use paimon_vindex_core::ivfpq::IVFPQIndex;
use std::collections::HashSet;
use std::time::Instant;
@@ -115,6 +119,7 @@ fn run_scenario(s: Scenario<'_>) {
ivfhnswsq.add(&data, &ids, s.n);
ivfhnswsq.build_graphs().unwrap();
println!("build IVF-HNSW-SQ: {:.2}s", start.elapsed().as_secs_f64());
+ print_sizes(&ivfpq, &ivfflat, &ivfhnswflat, &ivfhnswsq);
println!();
println!(
@@ -200,6 +205,34 @@ fn run_scenario(s: Scenario<'_>) {
}
}
+fn print_sizes(
+ ivfpq: &IVFPQIndex,
+ ivfflat: &IVFFlatIndex,
+ ivfhnswflat: &IVFHNSWFlatIndex,
+ ivfhnswsq: &IVFHNSWSQIndex,
+) {
+ let mut pq = Vec::new();
+ write_index(ivfpq, &mut PosWriter::new(&mut pq)).unwrap();
+ let mut flat = Vec::new();
+ write_ivfflat_index(ivfflat, &mut PosWriter::new(&mut flat)).unwrap();
+ let mut hnswflat = Vec::new();
+ write_ivfhnswflat_index(ivfhnswflat, &mut PosWriter::new(&mut hnswflat)).unwrap();
+ let mut hnswsq = Vec::new();
+ write_ivfhnswsq_index(ivfhnswsq, &mut PosWriter::new(&mut hnswsq)).unwrap();
+
+ println!(
+ "serialized sizes: IVF-PQ={:.2} MiB, IVF-FLAT={:.2} MiB, IVF-HNSW-FLAT={:.2} MiB, IVF-HNSW-SQ={:.2} MiB",
+ bytes_to_mib(pq.len()),
+ bytes_to_mib(flat.len()),
+ bytes_to_mib(hnswflat.len()),
+ bytes_to_mib(hnswsq.len())
+ );
+}
+
+fn bytes_to_mib(bytes: usize) -> f64 {
+ bytes as f64 / 1024.0 / 1024.0
+}
+
fn print_row(
index: &str,
nprobe: usize,
diff --git a/core/src/ivfhnswsq.rs b/core/src/ivfhnswsq.rs
index a945388..4ce2a6e 100644
--- a/core/src/ivfhnswsq.rs
+++ b/core/src/ivfhnswsq.rs
@@ -30,6 +30,7 @@ pub struct IVFHNSWSQIndex {
pub metric: MetricType,
pub quantizer_centroids: Vec<f32>,
pub sq: ScalarQuantizer,
+ pub list_sqs: Vec<ScalarQuantizer>,
pub ids: Vec<Vec<i64>>,
pub codes: Vec<Vec<u8>>,
pub graphs: Vec<Option<HnswGraph>>,
@@ -44,6 +45,7 @@ impl IVFHNSWSQIndex {
metric,
quantizer_centroids: Vec::new(),
sq: ScalarQuantizer::new(d),
+ list_sqs: vec![ScalarQuantizer::new(d); nlist],
ids: vec![Vec::new(); nlist],
codes: vec![Vec::new(); nlist],
graphs: vec![None; nlist],
@@ -55,21 +57,22 @@ impl IVFHNSWSQIndex {
let processed = self.preprocess_vectors(data, n);
self.quantizer_centroids =
kmeans::kmeans_train(&KMeansConfig::default(), &processed, n, self.d, self.nlist);
- self.sq.train(&processed, n);
+ let (list_ids, residuals) = self.assign_residuals(&processed, n);
+ self.sq.train(&residuals, n);
+ self.train_list_sqs(&list_ids, &residuals);
}
pub fn add(&mut self, data: &[f32], ids: &[i64], n: usize) {
let processed = self.preprocess_vectors(data, n);
- let code_size = self.sq.code_size();
- let mut encoded = vec![0u8; n * code_size];
- self.sq.encode_batch(&processed, n, &mut encoded);
+ let code_size = self.code_size();
+ let (list_ids, residuals) = self.assign_residuals(&processed, n);
+ let mut code = vec![0u8; code_size];
for i in 0..n {
- let vector = &processed[i * self.d..(i + 1) * self.d];
- let list_id =
- kmeans::find_nearest(vector, &self.quantizer_centroids, self.nlist, self.d);
+ let list_id = list_ids[i];
+ self.list_sqs[list_id].encode(&residuals[i * self.d..(i + 1) * self.d], &mut code);
self.ids[list_id].push(ids[i]);
- self.codes[list_id].extend_from_slice(&encoded[i * code_size..(i + 1) * code_size]);
+ self.codes[list_id].extend_from_slice(&code);
}
self.graphs.fill(None);
}
@@ -84,9 +87,7 @@ impl IVFHNSWSQIndex {
self.graphs[list_id] = if count == 0 {
None
} else {
- let mut vectors = vec![0.0f32; count * self.d];
- self.sq
- .decode_batch(&self.codes[list_id], count, &mut vectors);
+ let vectors = self.decode_list_vectors(list_id, count);
Some(HnswGraph::build(
&vectors,
count,
@@ -181,18 +182,85 @@ impl IVFHNSWSQIndex {
heap: &mut TopKHeap,
) {
let context = self.sq.distance_context(query, self.metric);
- let code_size = self.sq.code_size();
+ let sq = self.list_sq(list_id);
+ let code_size = self.code_size();
+ let centroid = self.list_centroid(list_id);
for (local_id, &row_id) in self.ids[list_id].iter().enumerate() {
if filter.map(|f| !f.contains(row_id)).unwrap_or(false) {
continue;
}
let code = &self.codes[list_id][local_id * code_size..(local_id + 1) * code_size];
heap.push(
- self.sq.distance_to_code_with_context(query, code, context),
+ sq.distance_to_code_with_offset_with_context(query, code, centroid, context),
row_id,
);
}
}
+
+ pub(crate) fn decode_list_vectors(&self, list_id: usize, count: usize) -> Vec<f32> {
+ let mut vectors = vec![0.0f32; count * self.d];
+ self.list_sq(list_id)
+ .decode_batch(&self.codes[list_id], count, &mut vectors);
+ let centroid = self.list_centroid(list_id);
+ for vector in vectors.chunks_exact_mut(self.d) {
+ for i in 0..self.d {
+ vector[i] += centroid[i];
+ }
+ }
+ vectors
+ }
+
+ fn assign_residuals(&self, processed: &[f32], n: usize) -> (Vec<usize>, Vec<f32>) {
+ let mut list_ids = Vec::with_capacity(n);
+ let mut residuals = vec![0.0f32; n * self.d];
+ for i in 0..n {
+ let vector = &processed[i * self.d..(i + 1) * self.d];
+ let list_id =
+ kmeans::find_nearest(vector, &self.quantizer_centroids, self.nlist, self.d);
+ list_ids.push(list_id);
+ self.write_residual(
+ vector,
+ list_id,
+ &mut residuals[i * self.d..(i + 1) * self.d],
+ );
+ }
+ (list_ids, residuals)
+ }
+
+ fn train_list_sqs(&mut self, list_ids: &[usize], residuals: &[f32]) {
+ let mut list_residuals = vec![Vec::new(); self.nlist];
+ for (i, &list_id) in list_ids.iter().enumerate() {
+ let residual = &residuals[i * self.d..(i + 1) * self.d];
+ list_residuals[list_id].extend_from_slice(residual);
+ }
+ self.list_sqs = vec![self.sq.clone(); self.nlist];
+ for (list_id, values) in list_residuals.iter().enumerate() {
+ if !values.is_empty() {
+ let mut sq = ScalarQuantizer::new(self.d);
+ sq.train(values, values.len() / self.d);
+ self.list_sqs[list_id] = sq;
+ }
+ }
+ }
+
+ fn write_residual(&self, vector: &[f32], list_id: usize, out: &mut [f32]) {
+ let centroid = self.list_centroid(list_id);
+ for i in 0..self.d {
+ out[i] = vector[i] - centroid[i];
+ }
+ }
+
+ fn list_centroid(&self, list_id: usize) -> &[f32] {
+ &self.quantizer_centroids[list_id * self.d..(list_id + 1) * self.d]
+ }
+
+ pub(crate) fn list_sq(&self, list_id: usize) -> &ScalarQuantizer {
+ self.list_sqs.get(list_id).unwrap_or(&self.sq)
+ }
+
+ pub(crate) fn code_size(&self) -> usize {
+ self.sq.code_size()
+ }
}
#[cfg(test)]
diff --git a/core/src/ivfhnswsq_io.rs b/core/src/ivfhnswsq_io.rs
index cb06eef..c86a478 100644
--- a/core/src/ivfhnswsq_io.rs
+++ b/core/src/ivfhnswsq_io.rs
@@ -70,10 +70,17 @@ pub fn write_ivfhnswsq_index(index: &IVFHNSWSQIndex, out: &mut dyn SeekWrite) ->
usize_to_i32(params.ef_construction, "hnsw ef_construction")?,
)?;
write_i32_le(out, usize_to_i32(params.max_level, "hnsw max_level")?)?;
- out.write_all(&index.sq.min.to_le_bytes())?;
- out.write_all(&index.sq.max.to_le_bytes())?;
+ let (sq_min, sq_max) = sq_global_bounds(&index.sq.mins, &index.sq.maxs);
+ out.write_all(&sq_min.to_le_bytes())?;
+ out.write_all(&sq_max.to_le_bytes())?;
out.write_all(&[0u8; 16])?;
+ write_f32_slice(out, &index.sq.mins)?;
+ write_f32_slice(out, &index.sq.maxs)?;
+ for sq in &index.list_sqs {
+ write_f32_slice(out, &sq.mins)?;
+ write_f32_slice(out, &sq.maxs)?;
+ }
write_f32_slice(out, &index.quantizer_centroids)?;
let offset_table_size = index.nlist.checked_mul(24).ok_or_else(|| {
@@ -141,6 +148,7 @@ pub struct IVFHNSWSQIndexReader<R: SeekRead> {
pub total_vectors: i64,
pub hnsw_params: HnswBuildParams,
pub sq: ScalarQuantizer,
+ pub list_sqs: Vec<ScalarQuantizer>,
pub quantizer_centroids: Vec<f32>,
pub list_offsets: Vec<i64>,
pub list_counts: Vec<i32>,
@@ -186,21 +194,23 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
max_level: validate_positive_i32(read_i32_le(&mut reader)?, "hnsw max_level")? as usize,
}
.sanitized();
- let mut min_bytes = [0u8; 4];
- let mut max_bytes = [0u8; 4];
- reader.read_exact(&mut min_bytes)?;
- reader.read_exact(&mut max_bytes)?;
- let sq_min = f32::from_le_bytes(min_bytes);
- let sq_max = f32::from_le_bytes(max_bytes);
- if !sq_min.is_finite() || !sq_max.is_finite() {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- "SQ bounds must be finite",
- ));
- }
+ let mut bounds_summary = [0u8; 8];
+ reader.read_exact(&mut bounds_summary)?;
let mut reserved = [0u8; 16];
reader.read_exact(&mut reserved)?;
+ let mins = read_f32_vec(&mut reader, d)?;
+ let maxs = read_f32_vec(&mut reader, d)?;
+ validate_sq_bounds(d, &mins, &maxs)?;
+ let sq = ScalarQuantizer::with_dimension_bounds(d, mins, maxs);
+ let mut list_sqs = Vec::with_capacity(nlist);
+ for _ in 0..nlist {
+ let mins = read_f32_vec(&mut reader, d)?;
+ let maxs = read_f32_vec(&mut reader, d)?;
+ validate_sq_bounds(d, &mins, &maxs)?;
+ list_sqs.push(ScalarQuantizer::with_dimension_bounds(d, mins, maxs));
+ }
+
Ok(Self {
reader,
d,
@@ -208,7 +218,8 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
metric,
total_vectors,
hnsw_params,
- sq: ScalarQuantizer::with_bounds(d, sq_min, sq_max),
+ sq,
+ list_sqs,
quantizer_centroids: Vec::new(),
list_offsets: Vec::new(),
list_counts: Vec::new(),
@@ -222,7 +233,9 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
return Ok(());
}
- self.reader.seek(IVF_HNSW_SQ_HEADER_SIZE as u64)?;
+ let quantizer_centroids_offset =
+ IVF_HNSW_SQ_HEADER_SIZE as u64 + (self.d as u64) * 8 * (self.nlist as u64 + 1);
+ self.reader.seek(quantizer_centroids_offset)?;
self.quantizer_centroids =
read_f32_vec(&mut self.reader, checked_section_size(self.nlist, self.d)?)?;
self.list_offsets = vec![0; self.nlist];
@@ -300,7 +313,14 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
.collect();
let codes = payload[ids_bytes_len..ids_bytes_len + codes_bytes_len].to_vec();
let mut vectors = vec![0.0f32; count * self.d];
- self.sq.decode_batch(&codes, count, &mut vectors);
+ self.list_sq(list_id)
+ .decode_batch(&codes, count, &mut vectors);
+ let centroid = self.list_centroid(list_id).to_vec();
+ for vector in vectors.chunks_exact_mut(self.d) {
+ for i in 0..self.d {
+ vector[i] += centroid[i];
+ }
+ }
let graph = decode_graph(
&payload[ids_bytes_len + codes_bytes_len..],
vectors,
@@ -315,7 +335,21 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
format!("list {} is missing HNSW graph", list_id),
)
})?;
- Ok(Some(GraphList { ids, codes, graph }))
+ Ok(Some(GraphList {
+ ids,
+ codes,
+ graph,
+ centroid: Some(centroid),
+ sq: self.list_sq(list_id).clone(),
+ }))
+ }
+
+ fn list_centroid(&self, list_id: usize) -> &[f32] {
+ &self.quantizer_centroids[list_id * self.d..(list_id + 1) * self.d]
+ }
+
+ fn list_sq(&self, list_id: usize) -> &ScalarQuantizer {
+ self.list_sqs.get(list_id).unwrap_or(&self.sq)
}
pub fn search(
@@ -362,7 +396,8 @@ impl<R: SeekRead> IVFHNSWSQIndexReader<R> {
&q,
&list.ids,
&list.codes,
- &self.sq,
+ list.centroid.as_deref(),
+ &list.sq,
self.metric,
filter,
heap,
@@ -457,6 +492,8 @@ pub fn search_batch_ivfhnswsq_reader_filter<R: SeekRead>(
ids: list.ids,
codes: list.codes,
graph: list.graph,
+ centroid: list.centroid,
+ sq: list.sq,
});
}
@@ -471,7 +508,8 @@ pub fn search_batch_ivfhnswsq_reader_filter<R: SeekRead>(
query,
&list.ids,
&list.codes,
- &reader.sq,
+ list.centroid.as_deref(),
+ &list.sq,
reader.metric,
filter,
&mut heaps[qi],
@@ -501,7 +539,8 @@ pub fn search_batch_ivfhnswsq_reader_filter<R: SeekRead>(
query,
&list.ids,
&list.codes,
- &reader.sq,
+ list.centroid.as_deref(),
+ &list.sq,
reader.metric,
filter,
&mut heaps[qi],
@@ -541,6 +580,8 @@ struct GraphList {
ids: Vec<i64>,
codes: Vec<u8>,
graph: HnswGraph,
+ centroid: Option<Vec<f32>>,
+ sq: ScalarQuantizer,
}
struct LoadedBatchList {
@@ -548,12 +589,15 @@ struct LoadedBatchList {
ids: Vec<i64>,
codes: Vec<u8>,
graph: HnswGraph,
+ centroid: Option<Vec<f32>>,
+ sq: ScalarQuantizer,
}
fn scan_sq_list(
query: &[f32],
ids: &[i64],
codes: &[u8],
+ centroid: Option<&[f32]>,
sq: &ScalarQuantizer,
metric: MetricType,
filter: Option<&dyn RowIdFilter>,
@@ -566,10 +610,12 @@ fn scan_sq_list(
continue;
}
let code = &codes[local_id * code_size..(local_id + 1) * code_size];
- heap.push(
- sq.distance_to_code_with_context(query, code, context),
- row_id,
- );
+ let dist = if let Some(centroid) = centroid {
+ sq.distance_to_code_with_offset_with_context(query, code, centroid, context)
+ } else {
+ sq.distance_to_code_with_context(query, code, context)
+ };
+ heap.push(dist, row_id);
}
}
@@ -592,6 +638,25 @@ fn validate_index_shape(index: &IVFHNSWSQIndex) -> io::Result<()> {
"SQ dimension does not match index dimension",
));
}
+ validate_sq_bounds(index.d, &index.sq.mins, &index.sq.maxs)?;
+ if index.list_sqs.len() != index.nlist {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "SQ list bounds count does not match nlist",
+ ));
+ }
+ for (list_id, sq) in index.list_sqs.iter().enumerate() {
+ if sq.d != index.d {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!(
+ "SQ dimension for list {} does not match index dimension",
+ list_id
+ ),
+ ));
+ }
+ validate_sq_bounds(index.d, &sq.mins, &sq.maxs)?;
+ }
let centroid_len = checked_section_size(index.nlist, index.d)?;
if index.quantizer_centroids.len() != centroid_len {
return Err(io::Error::new(
@@ -628,10 +693,7 @@ fn validate_index_shape(index: &IVFHNSWSQIndex) -> io::Result<()> {
}
match &index.graphs[list_id] {
Some(graph) if count > 0 => {
- let mut decoded = vec![0.0f32; count * index.d];
- index
- .sq
- .decode_batch(&index.codes[list_id], count, &mut decoded);
+ let decoded = index.decode_list_vectors(list_id, count);
if graph.len() != count || graph.vectors() != decoded.as_slice() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
@@ -660,6 +722,39 @@ fn validate_index_shape(index: &IVFHNSWSQIndex) -> io::Result<()> {
Ok(())
}
+fn validate_sq_bounds(d: usize, mins: &[f32], maxs: &[f32]) -> io::Result<()> {
+ if mins.len() != d || maxs.len() != d {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!(
+ "SQ bounds length mismatch: d={}, mins={}, maxs={}",
+ d,
+ mins.len(),
+ maxs.len()
+ ),
+ ));
+ }
+ for (dim, (&min, &max)) in mins.iter().zip(maxs.iter()).enumerate() {
+ if !min.is_finite() || !max.is_finite() {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("SQ bounds at dimension {} must be finite", dim),
+ ));
+ }
+ }
+ Ok(())
+}
+
+fn sq_global_bounds(mins: &[f32], maxs: &[f32]) -> (f32, f32) {
+ let min = mins.iter().copied().fold(f32::INFINITY, f32::min);
+ let max = maxs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
+ if min.is_finite() && max.is_finite() {
+ (min, max)
+ } else {
+ (0.0, 0.0)
+ }
+}
+
fn list_payload_len(count: usize, code_size: usize, graph_bytes_len: usize) -> io::Result<usize> {
let id_bytes = checked_list_bytes(count, 8)?;
let code_bytes = checked_list_bytes(count, code_size)?;
@@ -719,6 +814,35 @@ mod tests {
assert!(distances[0].is_finite());
}
+ #[test]
+ fn test_ivfhnswsq_write_read_preserves_sq_dimension_bounds() {
+ let d = 2;
+ let nlist = 1;
+ let data = vec![0.0, -100.0, 1.0, 100.0];
+ let ids = vec![10, 11];
+ let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default());
+ index.train(&data, 2);
+ index.add(&data, &ids, 2);
+ index.build_graphs().unwrap();
+
+ let mut buf = Vec::new();
+ write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut buf)).unwrap();
+ assert_eq!(
+ u32::from_le_bytes(buf[4..8].try_into().unwrap()),
+ IVF_HNSW_SQ_VERSION
+ );
+
+ let reader = IVFHNSWSQIndexReader::open(Cursor::new(buf)).unwrap();
+
+ assert_eq!(reader.sq.mins, index.sq.mins);
+ assert_eq!(reader.sq.maxs, index.sq.maxs);
+ assert_eq!(reader.sq.min, index.sq.min);
+ assert_eq!(reader.sq.max, index.sq.max);
+ assert_eq!(reader.list_sqs.len(), nlist);
+ assert_eq!(reader.list_sqs[0].mins, index.list_sqs[0].mins);
+ assert_eq!(reader.list_sqs[0].maxs, index.list_sqs[0].maxs);
+ }
+
#[test]
fn test_ivfhnswsq_reader_search_with_roaring_filter() {
let d = 2;
diff --git a/core/src/sq.rs b/core/src/sq.rs
index f1eb5da..812f94f 100644
--- a/core/src/sq.rs
+++ b/core/src/sq.rs
@@ -17,11 +17,13 @@
use crate::distance::{fvec_norm_l2sqr, MetricType};
-#[derive(Debug, Clone, Copy, PartialEq)]
+#[derive(Debug, Clone, PartialEq)]
pub struct ScalarQuantizer {
pub d: usize,
pub min: f32,
pub max: f32,
+ pub mins: Vec<f32>,
+ pub maxs: Vec<f32>,
}
impl ScalarQuantizer {
@@ -30,29 +32,56 @@ impl ScalarQuantizer {
d,
min: 0.0,
max: 0.0,
+ mins: vec![0.0; d],
+ maxs: vec![0.0; d],
}
}
pub fn with_bounds(d: usize, min: f32, max: f32) -> Self {
- Self { d, min, max }
+ Self {
+ d,
+ min,
+ max,
+ mins: vec![min; d],
+ maxs: vec![max; d],
+ }
+ }
+
+ pub fn with_dimension_bounds(d: usize, mins: Vec<f32>, maxs: Vec<f32>) -> Self {
+ assert_eq!(mins.len(), d);
+ assert_eq!(maxs.len(), d);
+ let mut sq = Self {
+ d,
+ min: 0.0,
+ max: 0.0,
+ mins,
+ maxs,
+ };
+ sq.refresh_global_bounds();
+ sq
}
pub fn train(&mut self, data: &[f32], n: usize) {
- let values = &data[..n * self.d];
- if values.is_empty() {
+ let len = n * self.d;
+ let values = &data[..len];
+ if n == 0 || self.d == 0 {
self.min = 0.0;
self.max = 0.0;
+ self.mins.fill(0.0);
+ self.maxs.fill(0.0);
return;
}
- let mut min = f32::INFINITY;
- let mut max = f32::NEG_INFINITY;
- for &value in values {
- min = min.min(value);
- max = max.max(value);
+ self.ensure_bounds_len();
+ self.mins.fill(f32::INFINITY);
+ self.maxs.fill(f32::NEG_INFINITY);
+ for vector in values.chunks_exact(self.d) {
+ for i in 0..self.d {
+ self.mins[i] = self.mins[i].min(vector[i]);
+ self.maxs[i] = self.maxs[i].max(vector[i]);
+ }
}
- self.min = min;
- self.max = max;
+ self.refresh_global_bounds();
}
pub fn code_size(&self) -> usize {
@@ -64,15 +93,19 @@ impl ScalarQuantizer {
assert!(data.len() >= len);
assert!(codes.len() >= len);
- if self.min >= self.max {
- codes[..len].fill(0);
- return;
- }
-
- let scale = 255.0 / (self.max - self.min);
- for i in 0..len {
- let scaled = ((data[i] - self.min) * scale).clamp(0.0, 255.0);
- codes[i] = scaled as u8;
+ for row in 0..n {
+ let base = row * self.d;
+ for dim in 0..self.d {
+ let min = self.mins[dim];
+ let max = self.maxs[dim];
+ let out = base + dim;
+ codes[out] = if min >= max {
+ 0
+ } else {
+ let scaled = ((data[out] - min) * 255.0 / (max - min)).clamp(0.0, 255.0);
+ scaled.round() as u8
+ };
+ }
}
}
@@ -85,14 +118,11 @@ impl ScalarQuantizer {
assert!(codes.len() >= len);
assert!(vectors.len() >= len);
- if self.min >= self.max {
- vectors[..len].fill(self.min);
- return;
- }
-
- let scale = (self.max - self.min) / 255.0;
- for i in 0..len {
- vectors[i] = self.min + codes[i] as f32 * scale;
+ for row in 0..n {
+ let base = row * self.d;
+ for dim in 0..self.d {
+ vectors[base + dim] = self.decode_value(codes[base + dim], dim);
+ }
}
}
@@ -114,6 +144,31 @@ impl ScalarQuantizer {
query: &[f32],
code: &[u8],
context: DistanceContext,
+ ) -> f32 {
+ self.distance_to_code_impl(query, code, &[], false, context)
+ }
+
+ pub fn distance_to_code_with_offset_with_context(
+ &self,
+ query: &[f32],
+ code: &[u8],
+ offset: &[f32],
+ context: DistanceContext,
+ ) -> f32 {
+ debug_assert!(query.len() >= self.d);
+ debug_assert!(code.len() >= self.d);
+ debug_assert!(offset.len() >= self.d);
+
+ self.distance_to_code_impl(query, code, offset, true, context)
+ }
+
+ fn distance_to_code_impl(
+ &self,
+ query: &[f32],
+ code: &[u8],
+ offset: &[f32],
+ use_offset: bool,
+ context: DistanceContext,
) -> f32 {
debug_assert!(query.len() >= self.d);
debug_assert!(code.len() >= self.d);
@@ -122,7 +177,8 @@ impl ScalarQuantizer {
MetricType::L2 => {
let mut sum = 0.0f32;
for i in 0..self.d {
- let diff = query[i] - self.decode_value(code[i]);
+ let diff =
+ query[i] - self.decode_value_with_offset(code[i], i, offset, use_offset);
sum += diff * diff;
}
sum
@@ -130,7 +186,7 @@ impl ScalarQuantizer {
MetricType::InnerProduct => {
let mut dot = 0.0f32;
for i in 0..self.d {
- dot += query[i] * self.decode_value(code[i]);
+ dot += query[i] * self.decode_value_with_offset(code[i], i, offset, use_offset);
}
-dot
}
@@ -138,7 +194,7 @@ impl ScalarQuantizer {
let mut dot = 0.0f32;
let mut vector_norm = 0.0f32;
for i in 0..self.d {
- let value = self.decode_value(code[i]);
+ let value = self.decode_value_with_offset(code[i], i, offset, use_offset);
dot += query[i] * value;
vector_norm += value * value;
}
@@ -152,12 +208,48 @@ impl ScalarQuantizer {
}
}
- fn decode_value(&self, code: u8) -> f32 {
- if self.min >= self.max {
- self.min
+ fn decode_value_with_offset(
+ &self,
+ code: u8,
+ dim: usize,
+ offset: &[f32],
+ use_offset: bool,
+ ) -> f32 {
+ let value = self.decode_value(code, dim);
+ if use_offset {
+ value + offset[dim]
+ } else {
+ value
+ }
+ }
+
+ fn decode_value(&self, code: u8, dim: usize) -> f32 {
+ let min = self.mins[dim];
+ let max = self.maxs[dim];
+ if min >= max {
+ min
} else {
- self.min + code as f32 * (self.max - self.min) / 255.0
+ min + code as f32 * (max - min) / 255.0
+ }
+ }
+
+ fn ensure_bounds_len(&mut self) {
+ if self.mins.len() != self.d {
+ self.mins.resize(self.d, 0.0);
+ }
+ if self.maxs.len() != self.d {
+ self.maxs.resize(self.d, 0.0);
+ }
+ }
+
+ fn refresh_global_bounds(&mut self) {
+ if self.d == 0 {
+ self.min = 0.0;
+ self.max = 0.0;
+ return;
}
+ self.min = self.mins.iter().copied().fold(f32::INFINITY, f32::min);
+ self.max = self.maxs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
}
}
@@ -195,6 +287,8 @@ mod tests {
assert_eq!(sq.min, -1.0);
assert_eq!(sq.max, 3.0);
+ assert_eq!(sq.mins, vec![-1.0, 0.0]);
+ assert_eq!(sq.maxs, vec![1.0, 3.0]);
assert_eq!(codes[0], 0);
assert_eq!(codes[3], 255);
assert!((decoded[0] + 1.0).abs() < 1e-6);
@@ -216,6 +310,24 @@ mod tests {
assert_eq!(decoded, data);
}
+ #[test]
+ fn test_scalar_quantizer_uses_per_dimension_bounds() {
+ let data = vec![0.0, -100.0, 1.0, 100.0];
+ let mut sq = ScalarQuantizer::new(2);
+ sq.train(&data, 2);
+
+ let mut codes = vec![0u8; data.len()];
+ sq.encode_batch(&data, 2, &mut codes);
+ let mut decoded = vec![0.0f32; data.len()];
+ sq.decode_batch(&codes, 2, &mut decoded);
+
+ assert_eq!(codes, vec![0, 0, 255, 255]);
+ assert!((decoded[0] - 0.0).abs() < 1e-6);
+ assert!((decoded[1] + 100.0).abs() < 1e-6);
+ assert!((decoded[2] - 1.0).abs() < 1e-6);
+ assert!((decoded[3] - 100.0).abs() < 1e-6);
+ }
+
#[test]
fn test_scalar_quantizer_distance_to_code() {
let sq = ScalarQuantizer::with_bounds(2, 0.0, 1.0);