This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new 54fbcfe Add IVF-PQ index module (#5)
54fbcfe is described below
commit 54fbcfe0ca51e1ca69e0b2e03f8628e1bff99227
Author: Jingsong Lee <[email protected]>
AuthorDate: Mon Jun 8 08:34:02 2026 +0800
Add IVF-PQ index module (#5)
---
core/src/ivfpq.rs | 1365 +++++++++++++++++++++++++++++++++++++++++++++++++++++
core/src/lib.rs | 1 +
core/src/opq.rs | 5 +-
3 files changed, 1369 insertions(+), 2 deletions(-)
diff --git a/core/src/ivfpq.rs b/core/src/ivfpq.rs
new file mode 100644
index 0000000..00826b9
--- /dev/null
+++ b/core/src/ivfpq.rs
@@ -0,0 +1,1365 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{
+ fvec_madd, fvec_normalize, pq_distance_four_codes, pq_distance_from_table,
MetricType,
+};
+use crate::kmeans::{self, KMeansConfig};
+use crate::opq::OPQMatrix;
+use crate::pq::ProductQuantizer;
+use rayon::prelude::*;
+use std::collections::HashSet;
+
+/// IVF-PQ index aligned with Faiss's IndexIVFPQ.
+pub struct IVFPQIndex {
+ pub d: usize,
+ pub nlist: usize,
+ pub metric: MetricType,
+ pub by_residual: bool,
+
+ pub quantizer_centroids: Vec<f32>,
+ pub pq: ProductQuantizer,
+ pub opq: Option<OPQMatrix>,
+
+ pub ids: Vec<Vec<i64>>,
+ pub codes: Vec<Vec<u8>>,
+
+ /// Precomputed table [nlist * M * ksub] for L2+by_residual mode.
+ /// Avoids recomputing distance table per list during search.
+ precomputed_table: Vec<f32>,
+ /// Block-layout packed codes for 4-bit FastScan. One per list.
+ fastscan_codes: Vec<Vec<u8>>,
+}
+
+impl IVFPQIndex {
+ pub fn new(d: usize, nlist: usize, m: usize, metric: MetricType, use_opq:
bool) -> Self {
+ Self::with_nbits(d, nlist, m, 8, metric, use_opq)
+ }
+
+ pub fn with_nbits(
+ d: usize,
+ nlist: usize,
+ m: usize,
+ nbits: usize,
+ metric: MetricType,
+ use_opq: bool,
+ ) -> Self {
+ let by_residual = metric == MetricType::L2;
+ IVFPQIndex {
+ d,
+ nlist,
+ metric,
+ by_residual,
+ quantizer_centroids: Vec::new(),
+ pq: ProductQuantizer::with_nbits(d, m, nbits),
+ opq: if use_opq {
+ Some(OPQMatrix::new(d, m))
+ } else {
+ None
+ },
+ ids: vec![Vec::new(); nlist],
+ codes: vec![Vec::new(); nlist],
+ precomputed_table: Vec::new(),
+ fastscan_codes: Vec::new(),
+ }
+ }
+
+ /// Create an index with automatic nlist based on target partition size.
+ /// nlist = max(1, n / target_partition_size), clamped to reasonable
bounds.
+ pub fn with_target_partition_size(
+ d: usize,
+ n: usize,
+ target_partition_size: usize,
+ m: usize,
+ metric: MetricType,
+ use_opq: bool,
+ ) -> Self {
+ let nlist = (n / target_partition_size.max(1)).clamp(1, 65536);
+ Self::new(d, nlist, m, metric, use_opq)
+ }
+
+ /// Create an index from an already-trained index, copying centroids,
codebooks, and OPQ.
+ /// The new index has empty inverted lists — call `add()` to populate.
+ /// Used for distributed build: train once globally, then each worker
creates from_trained.
+ pub fn from_trained(trained: &IVFPQIndex) -> Self {
+ IVFPQIndex {
+ d: trained.d,
+ nlist: trained.nlist,
+ metric: trained.metric,
+ by_residual: trained.by_residual,
+ quantizer_centroids: trained.quantizer_centroids.clone(),
+ pq: ProductQuantizer {
+ d: trained.pq.d,
+ m: trained.pq.m,
+ nbits: trained.pq.nbits,
+ dsub: trained.pq.dsub,
+ ksub: trained.pq.ksub,
+ centroids: trained.pq.centroids.clone(),
+ centroid_norms_cache: trained.pq.centroid_norms_cache.clone(),
+ },
+ opq: trained.opq.as_ref().map(|o| OPQMatrix {
+ d: o.d,
+ m: o.m,
+ niter: 0,
+ niter_pq: 0,
+ niter_pq_0: 0,
+ max_train_points: 0,
+ rotation: o.rotation.clone(),
+ is_trained: true,
+ }),
+ ids: vec![Vec::new(); trained.nlist],
+ codes: vec![Vec::new(); trained.nlist],
+ precomputed_table: Vec::new(),
+ fastscan_codes: Vec::new(),
+ }
+ }
+
+ pub fn train(&mut self, data: &[f32], n: usize) {
+ let d = self.d;
+
+ let train_data = if self.metric == MetricType::Cosine {
+ let mut normalized = data[..n * d].to_vec();
+ for i in 0..n {
+ fvec_normalize(&mut normalized[i * d..(i + 1) * d]);
+ }
+ normalized
+ } else {
+ data[..n * d].to_vec()
+ };
+
+ // When OPQ is enabled, jointly train rotation + PQ, then project data.
+ // IVF centroids must be trained on projected (rotated) data since
+ // add() and search() assign rotated vectors via preprocess_queries().
+ let effective_data = if let Some(ref mut opq) = self.opq {
+ opq.train(&train_data, n, &mut self.pq);
+ let mut projected = vec![0.0f32; n * d];
+ opq.apply_batch(&train_data, &mut projected, n);
+ projected
+ } else {
+ train_data
+ };
+
+ let km_config = KMeansConfig::default();
+ self.quantizer_centroids =
+ kmeans::kmeans_train(&km_config, &effective_data, n, d,
self.nlist);
+
+ // Retrain PQ on the exact distribution that add/search will encode.
+ // For OPQ: opq.train() trained PQ on centered data, but add/search
+ // encode uncentered vectors, so we must retrain here for all metrics.
+ let pq_train_data = if self.by_residual {
+ compute_residuals(&effective_data, n, d,
&self.quantizer_centroids, self.nlist)
+ } else {
+ effective_data
+ };
+ self.pq.train(&pq_train_data, n);
+ }
+
+ /// Add vectors in batches (Faiss-style: batch assign → batch residual →
batch encode).
+ pub fn add(&mut self, data: &[f32], ids: &[i64], n: usize) {
+ const BATCH_SIZE: usize = 32768;
+ let mut offset = 0;
+ while offset < n {
+ let batch_n = (n - offset).min(BATCH_SIZE);
+ self.add_batch(
+ &data[offset * self.d..(offset + batch_n) * self.d],
+ &ids[offset..offset + batch_n],
+ batch_n,
+ );
+ offset += batch_n;
+ }
+ }
+
+ fn add_batch(&mut self, data: &[f32], ids: &[i64], n: usize) {
+ let d = self.d;
+
+ // Step 1: Preprocess (normalize + OPQ rotate)
+ let processed = self.preprocess_queries(data, n);
+
+ // Step 2: Batch assign to coarse centroids (uses sgemm)
+ let assignments: Vec<usize> = (0..n)
+ .into_par_iter()
+ .map(|i| {
+ kmeans::find_nearest(
+ &processed[i * d..(i + 1) * d],
+ &self.quantizer_centroids,
+ self.nlist,
+ d,
+ )
+ })
+ .collect();
+
+ // Step 3: Batch compute residuals (parallel)
+ let to_encode = if self.by_residual {
+ let mut residuals = vec![0.0f32; n * d];
+ residuals
+ .par_chunks_mut(d)
+ .enumerate()
+ .for_each(|(i, res)| {
+ let list_id = assignments[i];
+ for j in 0..d {
+ res[j] = processed[i * d + j] -
self.quantizer_centroids[list_id * d + j];
+ }
+ });
+ residuals
+ } else {
+ processed
+ };
+
+ // Step 4: Batch PQ encode (parallel)
+ let cs = self.pq.code_size();
+ let mut codes = vec![0u8; n * cs];
+ self.pq.encode_batch(&to_encode, n, &mut codes);
+
+ // Step 5: Distribute to inverted lists
+ for i in 0..n {
+ let list_id = assignments[i];
+ self.ids[list_id].push(ids[i]);
+ self.codes[list_id].extend_from_slice(&codes[i * cs..(i + 1) *
cs]);
+ }
+
+ // Invalidate stale precomputed structures (must rebuild after all
adds)
+ if !self.fastscan_codes.is_empty() {
+ self.fastscan_codes.clear();
+ }
+ if !self.precomputed_table.is_empty() {
+ self.precomputed_table.clear();
+ }
+ }
+
+ /// Build fastscan block codes for 4-bit search acceleration.
+ /// Call after all vectors are added. Lightweight — only reorganizes
existing codes.
+ pub fn build_search_structures(&mut self) {
+ if self.pq.nbits == 4 {
+ let cs = self.pq.code_size();
+ self.fastscan_codes = self
+ .codes
+ .iter()
+ .enumerate()
+ .map(|(list_id, codes)| {
+ let count = self.ids[list_id].len();
+ if count == 0 {
+ Vec::new()
+ } else {
+ crate::fastscan::pack_codes_block_layout(codes, count,
cs)
+ }
+ })
+ .collect();
+ }
+ }
+
+ /// Build precomputed distance tables for faster repeated queries.
+ /// Only useful for long-running services with many queries on the same
index.
+ /// Costs ~10ms to build and uses nlist * M * ksub * 4 bytes of memory.
+ pub fn build_precomputed_table(&mut self) {
+ let d = self.d;
+ let m = self.pq.m;
+ let ksub = self.pq.ksub;
+ let nlist = self.nlist;
+
+ if self.metric != MetricType::L2 || !self.by_residual {
+ return;
+ }
+ {
+ let pq_norms = self.pq.compute_centroid_norms();
+ let mut table = vec![0.0f32; nlist * m * ksub];
+
+ for i in 0..nlist {
+ let centroid = &self.quantizer_centroids[i * d..(i + 1) * d];
+ let tab_base = i * m * ksub;
+
+ for sub in 0..m {
+ let sub_centroid = ¢roid[sub * self.pq.dsub..(sub + 1)
* self.pq.dsub];
+ let pq_base = sub * ksub * self.pq.dsub;
+
+ for j in 0..ksub {
+ let pq_off = pq_base + j * self.pq.dsub;
+ let mut ip = 0.0f32;
+ for dd in 0..self.pq.dsub {
+ ip += sub_centroid[dd] * self.pq.centroids[pq_off
+ dd];
+ }
+ table[tab_base + sub * ksub + j] = pq_norms[sub * ksub
+ j] + 2.0 * ip;
+ }
+ }
+ }
+ self.precomputed_table = table;
+ }
+ }
+
+ /// Search for top-k nearest neighbors.
+ /// Uses rayon to parallelize across queries.
+ pub fn search(
+ &self,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+ result_distances: &mut [f32],
+ result_labels: &mut [i64],
+ ) {
+ self.search_with_filter(
+ queries,
+ nq,
+ k,
+ nprobe,
+ None,
+ result_distances,
+ result_labels,
+ );
+ }
+
+ /// Search with optional ID filter.
+ pub fn search_with_filter(
+ &self,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+ filter: Option<&HashSet<i64>>,
+ result_distances: &mut [f32],
+ result_labels: &mut [i64],
+ ) {
+ let d = self.d;
+ let m = self.pq.m;
+ let ksub = self.pq.ksub;
+
+ let processed_queries = self.preprocess_queries(queries, nq);
+
+ let (all_probe_indices, all_coarse_dists) = kmeans::find_topk_batch(
+ &processed_queries,
+ nq,
+ &self.quantizer_centroids,
+ self.nlist,
+ d,
+ nprobe,
+ );
+
+ let use_precomputed = !self.precomputed_table.is_empty();
+ let use_fastscan = !self.fastscan_codes.is_empty() && self.pq.nbits ==
4;
+
+ let results: Vec<Vec<(f32, i64)>> = (0..nq)
+ .into_par_iter()
+ .map(|qi| {
+ let query = &processed_queries[qi * d..(qi + 1) * d];
+ let probe_indices = &all_probe_indices[qi];
+ let coarse_dists = &all_coarse_dists[qi];
+
+ let mut heap = TopKHeap::new(k);
+ let mut sim_table = vec![0.0f32; m * ksub];
+
+ let ip_table = if use_precomputed {
+ let mut t = vec![0.0f32; m * ksub];
+ self.pq.compute_inner_product_table(query, &mut t);
+ t
+ } else {
+ Vec::new()
+ };
+
+ for (probe_rank, &list_id) in probe_indices.iter().enumerate()
{
+ let count = self.ids[list_id].len();
+ if count == 0 {
+ continue;
+ }
+
+ // Precomputed sim_table omits ||q-c||²; add it as dis0.
+ // Non-precomputed path computes from residual_query,
already full distance.
+ let dis0 = if use_precomputed {
+ coarse_dists[probe_rank]
+ } else {
+ 0.0
+ };
+
+ if use_precomputed {
+ let tab_base = list_id * m * ksub;
+ fvec_madd(
+ &self.precomputed_table[tab_base..tab_base + m *
ksub],
+ &ip_table,
+ -2.0,
+ &mut sim_table,
+ );
+ } else {
+ self.compute_list_table(query, list_id, &mut
sim_table);
+ }
+
+ if use_fastscan {
+ let mut dists = vec![0.0f32; count];
+ crate::fastscan::fastscan_4bit(
+ &sim_table,
+ &self.fastscan_codes[list_id],
+ count,
+ m,
+ &mut dists,
+ );
+ for i in 0..count {
+ if let Some(f) = filter {
+ if !f.contains(&self.ids[list_id][i]) {
+ continue;
+ }
+ }
+ heap.push(dis0 + dists[i], self.ids[list_id][i]);
+ }
+ } else if self.pq.nbits == 4 {
+ scan_codes_4bit(
+ &sim_table,
+ &self.codes[list_id],
+ &self.ids[list_id],
+ count,
+ m,
+ ksub,
+ dis0,
+ filter,
+ &mut heap,
+ );
+ } else {
+ scan_codes_batched(
+ &sim_table,
+ &self.codes[list_id],
+ &self.ids[list_id],
+ count,
+ m,
+ ksub,
+ dis0,
+ filter,
+ &mut heap,
+ );
+ }
+ }
+
+ heap.into_sorted()
+ })
+ .collect();
+
+ for (qi, result) in results.into_iter().enumerate() {
+ let out_base = qi * k;
+ for (i, &(dist, id)) in result.iter().enumerate() {
+ result_distances[out_base + i] = dist;
+ result_labels[out_base + i] = id;
+ }
+ for i in result.len()..k {
+ result_distances[out_base + i] = f32::MAX;
+ result_labels[out_base + i] = -1;
+ }
+ }
+ }
+
+ fn preprocess_queries(&self, queries: &[f32], nq: usize) -> Vec<f32> {
+ let d = self.d;
+ let mut processed = queries[..nq * d].to_vec();
+
+ if self.metric == MetricType::Cosine {
+ for i in 0..nq {
+ fvec_normalize(&mut processed[i * d..(i + 1) * d]);
+ }
+ }
+
+ if let Some(ref opq) = self.opq {
+ let mut rotated = vec![0.0f32; nq * d];
+ opq.apply_batch(&processed, &mut rotated, nq);
+ return rotated;
+ }
+
+ processed
+ }
+
+ fn compute_list_table(&self, query: &[f32], list_id: usize, sim_table:
&mut [f32]) {
+ let d = self.d;
+ if self.by_residual {
+ let mut residual_query = vec![0.0f32; d];
+ for j in 0..d {
+ residual_query[j] = query[j] -
self.quantizer_centroids[list_id * d + j];
+ }
+ self.pq
+ .compute_distance_table(&residual_query, self.metric,
sim_table);
+ } else {
+ self.pq
+ .compute_distance_table(query, self.metric, sim_table);
+ }
+ }
+
+ /// Search with max_codes budget: stop scanning when total scanned codes
exceeds limit.
+ /// Useful for bounding worst-case latency when some inverted lists are
very large.
+ pub fn search_with_max_codes(
+ &self,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+ max_codes: usize,
+ result_distances: &mut [f32],
+ result_labels: &mut [i64],
+ ) {
+ let d = self.d;
+ let m = self.pq.m;
+ let ksub = self.pq.ksub;
+
+ let processed_queries = self.preprocess_queries(queries, nq);
+ let (all_probe_indices, all_coarse_dists) = kmeans::find_topk_batch(
+ &processed_queries,
+ nq,
+ &self.quantizer_centroids,
+ self.nlist,
+ d,
+ nprobe,
+ );
+
+ let use_precomputed = !self.precomputed_table.is_empty();
+ let use_fastscan = !self.fastscan_codes.is_empty() && self.pq.nbits ==
4;
+
+ let results: Vec<Vec<(f32, i64)>> = (0..nq)
+ .into_par_iter()
+ .map(|qi| {
+ let query = &processed_queries[qi * d..(qi + 1) * d];
+ let probe_indices = &all_probe_indices[qi];
+ let coarse_dists = &all_coarse_dists[qi];
+
+ let mut heap = TopKHeap::new(k);
+ let mut sim_table = vec![0.0f32; m * ksub];
+ let mut total_scanned = 0usize;
+
+ let ip_table = if use_precomputed {
+ let mut t = vec![0.0f32; m * ksub];
+ self.pq.compute_inner_product_table(query, &mut t);
+ t
+ } else {
+ Vec::new()
+ };
+
+ for (probe_rank, &list_id) in probe_indices.iter().enumerate()
{
+ let count = self.ids[list_id].len();
+ if count == 0 {
+ continue;
+ }
+
+ if total_scanned >= max_codes {
+ break;
+ }
+ let scan_count = count.min(max_codes - total_scanned);
+
+ let dis0 = if use_precomputed {
+ coarse_dists[probe_rank]
+ } else {
+ 0.0
+ };
+
+ if use_precomputed {
+ let tab_base = list_id * m * ksub;
+ fvec_madd(
+ &self.precomputed_table[tab_base..tab_base + m *
ksub],
+ &ip_table,
+ -2.0,
+ &mut sim_table,
+ );
+ } else {
+ self.compute_list_table(query, list_id, &mut
sim_table);
+ }
+
+ if use_fastscan {
+ let mut dists = vec![0.0f32; scan_count];
+ crate::fastscan::fastscan_4bit(
+ &sim_table,
+ &self.fastscan_codes[list_id],
+ scan_count,
+ m,
+ &mut dists,
+ );
+ for i in 0..scan_count {
+ heap.push(dis0 + dists[i], self.ids[list_id][i]);
+ }
+ } else if self.pq.nbits == 4 {
+ scan_codes_4bit(
+ &sim_table,
+ &self.codes[list_id],
+ &self.ids[list_id],
+ scan_count,
+ m,
+ ksub,
+ dis0,
+ None,
+ &mut heap,
+ );
+ } else {
+ scan_codes_batched(
+ &sim_table,
+ &self.codes[list_id],
+ &self.ids[list_id],
+ scan_count,
+ m,
+ ksub,
+ dis0,
+ None,
+ &mut heap,
+ );
+ }
+
+ total_scanned += scan_count;
+ }
+
+ heap.into_sorted()
+ })
+ .collect();
+
+ for (qi, result) in results.into_iter().enumerate() {
+ let out_base = qi * k;
+ for (i, &(dist, id)) in result.iter().enumerate() {
+ result_distances[out_base + i] = dist;
+ result_labels[out_base + i] = id;
+ }
+ for i in result.len()..k {
+ result_distances[out_base + i] = f32::MAX;
+ result_labels[out_base + i] = -1;
+ }
+ }
+ }
+
+ /// Merge another index's inverted lists into this one.
+ /// Both indexes must have the same centroids and codebooks (trained from
the same data).
+ /// Used for compaction: merging multiple small index files into one.
+ pub fn merge_from(&mut self, other: &IVFPQIndex) {
+ assert_eq!(self.d, other.d, "Dimension mismatch");
+ assert_eq!(self.nlist, other.nlist, "nlist mismatch");
+ assert_eq!(self.pq.m, other.pq.m, "PQ M mismatch");
+ assert_eq!(self.pq.nbits, other.pq.nbits, "PQ nbits mismatch");
+
+ for list_id in 0..self.nlist {
+ self.ids[list_id].extend_from_slice(&other.ids[list_id]);
+ self.codes[list_id].extend_from_slice(&other.codes[list_id]);
+ }
+
+ // Invalidate precomputed structures (need to rebuild after merge)
+ self.fastscan_codes.clear();
+ self.precomputed_table.clear();
+ }
+}
+
+/// Scan 4-bit packed codes using u8-domain accumulation.
+fn scan_codes_4bit(
+ sim_table: &[f32],
+ codes: &[u8],
+ ids: &[i64],
+ count: usize,
+ m: usize,
+ _ksub: usize,
+ dis0: f32,
+ filter: Option<&HashSet<i64>>,
+ heap: &mut TopKHeap,
+) {
+ let mut dists = vec![0.0f32; count];
+ crate::distance::scan_4bit_simd(sim_table, codes, count, m, &mut dists);
+
+ for i in 0..count {
+ if let Some(f) = filter {
+ if !f.contains(&ids[i]) {
+ continue;
+ }
+ }
+ heap.push(dis0 + dists[i], ids[i]);
+ }
+}
+
+/// Scan 4-bit transposed codes: layout [M/2][n].
+/// Each sub-quantizer pair's codes are contiguous — ideal for SIMD.
+#[allow(dead_code)]
+fn scan_codes_4bit_transposed(
+ sim_table: &[f32],
+ codes: &[u8],
+ ids: &[i64],
+ count: usize,
+ m: usize,
+ dis0: f32,
+ filter: Option<&HashSet<i64>>,
+ heap: &mut TopKHeap,
+) {
+ let cs = m / 2;
+
+ const FLAT_NUM: usize = 200;
+ let flat_end = count.min(FLAT_NUM);
+
+ let mut dists = vec![0.0f32; count];
+
+ for i in 0..flat_end {
+ let mut d = 0.0f32;
+ for pair in 0..cs {
+ let byte = codes[pair * count + i];
+ let lo = (byte & 0x0F) as usize;
+ let hi = ((byte >> 4) & 0x0F) as usize;
+ d += sim_table[(pair * 2) * 16 + lo];
+ d += sim_table[(pair * 2 + 1) * 16 + hi];
+ }
+ dists[i] = d;
+ }
+
+ if count > FLAT_NUM {
+ let qmin = sim_table.iter().cloned().fold(f32::INFINITY, f32::min);
+ let qmax = dists[..flat_end].iter().cloned().fold(f32::MIN, f32::max);
+ let range = (qmax - qmin).max(1e-10);
+ let factor = 255.0 / range;
+
+ let qtable: Vec<u8> = sim_table
+ .iter()
+ .map(|&d| ((d - qmin) * factor).clamp(0.0, 255.0) as u8)
+ .collect();
+
+ let mut q_dists = vec![0u16; count];
+ for pair in 0..cs {
+ let qtab_lo = &qtable[(pair * 2) * 16..(pair * 2 + 1) * 16];
+ let qtab_hi = &qtable[(pair * 2 + 1) * 16..(pair * 2 + 2) * 16];
+ let col = &codes[pair * count..];
+
+ for i in flat_end..count {
+ let byte = col[i];
+ let lo = (byte & 0x0F) as usize;
+ let hi = ((byte >> 4) & 0x0F) as usize;
+ q_dists[i] += qtab_lo[lo] as u16 + qtab_hi[hi] as u16;
+ }
+ }
+
+ let inv_factor = range / 255.0;
+ let base_dist = qmin * m as f32;
+ for i in flat_end..count {
+ dists[i] = q_dists[i] as f32 * inv_factor + base_dist;
+ }
+ }
+
+ for i in 0..count {
+ if let Some(f) = filter {
+ if !f.contains(&ids[i]) {
+ continue;
+ }
+ }
+ heap.push(dis0 + dists[i], ids[i]);
+ }
+}
+
+/// Scan transposed (column-major) codes: layout is [M][n].
+/// The distance table sub-slice stays in L1 cache for the entire inner loop.
+#[allow(dead_code)]
+fn scan_codes_transposed(
+ sim_table: &[f32],
+ codes: &[u8],
+ ids: &[i64],
+ count: usize,
+ m: usize,
+ ksub: usize,
+ dis0: f32,
+ filter: Option<&HashSet<i64>>,
+ heap: &mut TopKHeap,
+) {
+ let mut dists = vec![dis0; count];
+ for sub in 0..m {
+ let tab_base = sub * ksub;
+ let col_base = sub * count;
+ for i in 0..count {
+ dists[i] += sim_table[tab_base + codes[col_base + i] as usize];
+ }
+ }
+
+ for i in 0..count {
+ if let Some(f) = filter {
+ if !f.contains(&ids[i]) {
+ continue;
+ }
+ }
+ heap.push(dists[i], ids[i]);
+ }
+}
+
+/// Scan inverted list codes with 4-code batching for ILP (row-major layout).
+fn scan_codes_batched(
+ sim_table: &[f32],
+ codes: &[u8],
+ ids: &[i64],
+ count: usize,
+ m: usize,
+ ksub: usize,
+ dis0: f32,
+ filter: Option<&HashSet<i64>>,
+ heap: &mut TopKHeap,
+) {
+ let mut i = 0;
+
+ while i + 4 <= count {
+ let dists = pq_distance_four_codes(
+ sim_table,
+ codes,
+ m,
+ ksub,
+ [i * m, (i + 1) * m, (i + 2) * m, (i + 3) * m],
+ );
+
+ for j in 0..4 {
+ let idx = i + j;
+ let id = ids[idx];
+ if let Some(f) = filter {
+ if !f.contains(&id) {
+ continue;
+ }
+ }
+ heap.push(dis0 + dists[j], id);
+ }
+ i += 4;
+ }
+
+ while i < count {
+ let code = &codes[i * m..(i + 1) * m];
+ let dist = dis0 + pq_distance_from_table(sim_table, code, m, ksub);
+ let id = ids[i];
+ if let Some(f) = filter {
+ if !f.contains(&id) {
+ i += 1;
+ continue;
+ }
+ }
+ heap.push(dist, id);
+ i += 1;
+ }
+}
+
+// --- Top-K Heap ---
+
+struct TopKHeap {
+ k: usize,
+ data: Vec<(f32, i64)>,
+ built: bool,
+}
+
+impl TopKHeap {
+ fn new(k: usize) -> Self {
+ TopKHeap {
+ k,
+ data: Vec::with_capacity(k),
+ built: false,
+ }
+ }
+
+ #[inline]
+ fn push(&mut self, dist: f32, id: i64) {
+ if self.k == 0 {
+ return;
+ }
+ if self.data.len() < self.k {
+ self.data.push((dist, id));
+ if self.data.len() == self.k {
+ build_max_heap(&mut self.data);
+ self.built = true;
+ }
+ } else if dist < self.data[0].0 {
+ self.data[0] = (dist, id);
+ sift_down(&mut self.data, 0);
+ }
+ }
+
+ fn into_sorted(mut self) -> Vec<(f32, i64)> {
+ self.data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
+ self.data
+ }
+}
+
+// --- Utilities ---
+
+fn compute_residuals(
+ data: &[f32],
+ n: usize,
+ d: usize,
+ centroids: &[f32],
+ nlist: usize,
+) -> Vec<f32> {
+ let mut residuals = vec![0.0f32; n * d];
+ for i in 0..n {
+ let point = &data[i * d..(i + 1) * d];
+ let list_id = kmeans::find_nearest(point, centroids, nlist, d);
+ for j in 0..d {
+ residuals[i * d + j] = point[j] - centroids[list_id * d + j];
+ }
+ }
+ residuals
+}
+
/// Rearranges `heap` in place into a binary max-heap keyed on distance by
/// sifting down every internal node, last node first.
fn build_max_heap(heap: &mut [(f32, i64)]) {
    let internal_nodes = heap.len() / 2;
    (0..internal_nodes)
        .rev()
        .for_each(|node| sift_down(heap, node));
}

/// Moves the entry at `i` down a max-heap until neither child has a strictly
/// larger distance. Ties keep the parent in place, and the left child wins a
/// tie between children. NaN never compares as larger.
fn sift_down(heap: &mut [(f32, i64)], mut i: usize) {
    let len = heap.len();
    // Returns `b` if it exists and is strictly larger than `a`, else `a`.
    let larger = |heap: &[(f32, i64)], a: usize, b: usize| {
        if b < len && heap[b].0 > heap[a].0 {
            b
        } else {
            a
        }
    };
    loop {
        let left = 2 * i + 1;
        let candidate = larger(heap, larger(heap, i, left), left + 1);
        if candidate == i {
            break;
        }
        heap.swap(i, candidate);
        i = candidate;
    }
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use rand::rngs::StdRng;
+ use rand::{Rng, SeedableRng};
+
+ fn generate_clustered_data(n: usize, d: usize, num_clusters: usize, seed:
u64) -> Vec<f32> {
+ let mut rng = StdRng::seed_from_u64(seed);
+ let mut centers = vec![0.0f32; num_clusters * d];
+ for i in 0..num_clusters * d {
+ centers[i] = rng.gen::<f32>() * 100.0;
+ }
+
+ let mut data = vec![0.0f32; n * d];
+ for i in 0..n {
+ let cluster = i % num_clusters;
+ for j in 0..d {
+ data[i * d + j] = centers[cluster * d + j] + rng.gen::<f32>()
* 2.0 - 1.0;
+ }
+ }
+ data
+ }
+
+ #[test]
+ fn test_build_and_search_l2() {
+ let d = 16;
+ let nlist = 4;
+ let m = 4;
+ let n = 1000;
+ let k = 5;
+ let nprobe = 2;
+
+ let data = generate_clustered_data(n, d, 4, 42);
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+
+ let query = &data[0..d];
+ let mut dists = vec![0.0f32; k];
+ let mut labels = vec![0i64; k];
+ index.search(query, 1, k, nprobe, &mut dists, &mut labels);
+
+ assert_eq!(labels[0], 0);
+ for i in 1..k {
+ assert!(dists[i] >= dists[i - 1]);
+ }
+ }
+
+ #[test]
+ fn test_build_and_search_ip() {
+ let d = 16;
+ let nlist = 4;
+ let m = 4;
+ let n = 1000;
+
+ let data = generate_clustered_data(n, d, 4, 123);
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFPQIndex::new(d, nlist, m, MetricType::InnerProduct,
false);
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+
+ let mut dists = vec![0.0f32; 5];
+ let mut labels = vec![0i64; 5];
+ index.search(&data[0..d], 1, 5, 2, &mut dists, &mut labels);
+
+ for i in 1..5 {
+ assert!(dists[i] >= dists[i - 1]);
+ }
+ }
+
+ #[test]
+ fn test_search_with_filter() {
+ let d = 16;
+ let nlist = 4;
+ let m = 4;
+ let n = 1000;
+ let k = 5;
+
+ let data = generate_clustered_data(n, d, 4, 42);
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+
+ let filter: HashSet<i64> = (0..n as i64).filter(|id| id % 2 ==
0).collect();
+ let mut dists = vec![0.0f32; k];
+ let mut labels = vec![0i64; k];
+ index.search_with_filter(&data[0..d], 1, k, 4, Some(&filter), &mut
dists, &mut labels);
+
+ for &label in &labels[..k] {
+ if label >= 0 {
+ assert!(label % 2 == 0, "Filter violated: got odd ID {}",
label);
+ }
+ }
+ }
+
+ #[test]
+ fn test_batch_search() {
+ let d = 16;
+ let nlist = 4;
+ let m = 4;
+ let n = 1000;
+ let k = 5;
+ let nq = 10;
+
+ let data = generate_clustered_data(n, d, 4, 42);
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+
+ let queries: Vec<f32> = data[..nq * d].to_vec();
+ let mut dists = vec![0.0f32; nq * k];
+ let mut labels = vec![0i64; nq * k];
+ index.search(&queries, nq, k, 2, &mut dists, &mut labels);
+
+ for qi in 0..nq {
+ assert_eq!(labels[qi * k], qi as i64);
+ }
+ }
+
+ #[test]
+ fn test_4bit_ivfpq() {
+ let d = 16;
+ let nlist = 4;
+ let m = 8;
+ let n = 1000;
+ let k = 5;
+ let nprobe = 2;
+
+ let data = generate_clustered_data(n, d, 4, 42);
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFPQIndex::with_nbits(d, nlist, m, 4, MetricType::L2,
false);
+ assert_eq!(index.pq.ksub, 16);
+ assert_eq!(index.pq.code_size(), 4);
+
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+
+ let mut dists = vec![0.0f32; k];
+ let mut labels = vec![0i64; k];
+ index.search(&data[0..d], 1, k, nprobe, &mut dists, &mut labels);
+
+ assert_eq!(labels[0], 0);
+ for i in 1..k {
+ assert!(dists[i] >= dists[i - 1]);
+ }
+
+ let codes_8bit_size = n * m;
+ let codes_4bit_size: usize = index.codes.iter().map(|c| c.len()).sum();
+ assert!(
+ codes_4bit_size < codes_8bit_size,
+ "4-bit ({}) should be smaller than 8-bit ({})",
+ codes_4bit_size,
+ codes_8bit_size,
+ );
+ }
+
+ #[test]
+ fn test_max_codes_early_termination() {
+ let d = 16;
+ let nlist = 4;
+ let m = 4;
+ let n = 1000;
+ let k = 5;
+
+ let data = generate_clustered_data(n, d, 4, 42);
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+
+ let mut dists_limited = vec![0.0f32; k];
+ let mut labels_limited = vec![0i64; k];
+ index.search_with_max_codes(
+ &data[0..d],
+ 1,
+ k,
+ 4,
+ 50,
+ &mut dists_limited,
+ &mut labels_limited,
+ );
+
+ let valid = labels_limited.iter().filter(|&&id| id >= 0).count();
+ assert!(valid > 0, "max_codes search returned no results");
+
+ let mut dists_full = vec![0.0f32; k];
+ let mut labels_full = vec![0i64; k];
+ index.search(&data[0..d], 1, k, 4, &mut dists_full, &mut labels_full);
+
+ assert!(dists_full[0] <= dists_limited[0] + 1e-6);
+ }
+
+ #[test]
+ fn test_from_trained_and_merge() {
+ let d = 16;
+ let nlist = 4;
+ let m = 4;
+ let n = 500;
+
+ let data = generate_clustered_data(n * 2, d, 4, 42);
+ let ids_a: Vec<i64> = (0..n as i64).collect();
+ let ids_b: Vec<i64> = (n as i64..2 * n as i64).collect();
+
+ let mut trainer = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+ trainer.train(&data[..n * d], n);
+
+ let mut worker_a = IVFPQIndex::from_trained(&trainer);
+ worker_a.add(&data[..n * d], &ids_a, n);
+
+ let mut worker_b = IVFPQIndex::from_trained(&trainer);
+ worker_b.add(&data[n * d..], &ids_b, n);
+
+ let total_a: usize = worker_a.ids.iter().map(|l| l.len()).sum();
+ let total_b: usize = worker_b.ids.iter().map(|l| l.len()).sum();
+ assert_eq!(total_a + total_b, n * 2);
+
+ let mut merged = IVFPQIndex::from_trained(&trainer);
+ merged.merge_from(&worker_a);
+ merged.merge_from(&worker_b);
+
+ let total_merged: usize = merged.ids.iter().map(|l| l.len()).sum();
+ assert_eq!(total_merged, n * 2);
+
+ let mut dists = vec![0.0f32; 5];
+ let mut labels = vec![0i64; 5];
+ merged.search(&data[0..d], 1, 5, 4, &mut dists, &mut labels);
+ assert_eq!(labels[0], 0);
+
+ merged.search(&data[n * d..(n + 1) * d], 1, 5, 4, &mut dists, &mut
labels);
+ assert_eq!(labels[0], n as i64);
+ }
+
+ #[test]
+ fn test_opq_ip() {
+ let d = 16;
+ let nlist = 4;
+ let m = 4;
+ let n = 1000;
+ let k = 5;
+
+ let data = generate_clustered_data(n, d, 4, 55);
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFPQIndex::new(d, nlist, m, MetricType::InnerProduct,
true);
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+
+ let mut dists = vec![0.0f32; k];
+ let mut labels = vec![0i64; k];
+ index.search(&data[0..d], 1, k, 4, &mut dists, &mut labels);
+
+ let valid = labels.iter().filter(|&&id| id >= 0).count();
+ assert!(valid > 0, "OPQ+IP should return results");
+ for i in 1..valid {
+ assert!(dists[i] >= dists[i - 1]);
+ }
+ }
+
+ #[test]
+ fn test_opq_cosine() {
+ let d = 16;
+ let nlist = 4;
+ let m = 4;
+ let n = 1000;
+ let k = 5;
+
+ let data = generate_clustered_data(n, d, 4, 77);
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFPQIndex::new(d, nlist, m, MetricType::Cosine, true);
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+
+ let mut dists = vec![0.0f32; k];
+ let mut labels = vec![0i64; k];
+ index.search(&data[0..d], 1, k, 4, &mut dists, &mut labels);
+
+ let valid = labels.iter().filter(|&&id| id >= 0).count();
+ assert!(valid > 0, "OPQ+Cosine should return results");
+ for i in 1..valid {
+ assert!(dists[i] >= dists[i - 1]);
+ }
+ }
+
+ #[test]
+ fn test_opq_4bit() {
+ let d = 16;
+ let nlist = 4;
+ let m = 8;
+ let n = 1000;
+ let k = 5;
+
+ let data = generate_clustered_data(n, d, 4, 42);
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFPQIndex::with_nbits(d, nlist, m, 4, MetricType::L2,
true);
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+
+ let mut dists = vec![0.0f32; k];
+ let mut labels = vec![0i64; k];
+ index.search(&data[0..d], 1, k, 4, &mut dists, &mut labels);
+
+ assert_eq!(labels[0], 0, "OPQ+4bit should recall query vector itself");
+ for i in 1..k {
+ assert!(dists[i] >= dists[i - 1]);
+ }
+ }
+
+ #[test]
+ fn test_precomputed_table_matches_normal_search() {
+ let d = 16;
+ let nlist = 4;
+ let m = 4;
+ let n = 1000;
+ let k = 10;
+ let nprobe = 4;
+
+ let data = generate_clustered_data(n, d, 4, 42);
+ let ids: Vec<i64> = (0..n as i64).collect();
+
+ let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+ index.train(&data, n);
+ index.add(&data, &ids, n);
+
+ // Normal search
+ let mut dists_normal = vec![0.0f32; k];
+ let mut labels_normal = vec![0i64; k];
+ index.search(
+ &data[0..d],
+ 1,
+ k,
+ nprobe,
+ &mut dists_normal,
+ &mut labels_normal,
+ );
+
+ // Enable precomputed table and search again
+ index.build_precomputed_table();
+ let mut dists_precomp = vec![0.0f32; k];
+ let mut labels_precomp = vec![0i64; k];
+ index.search(
+ &data[0..d],
+ 1,
+ k,
+ nprobe,
+ &mut dists_precomp,
+ &mut labels_precomp,
+ );
+
+ // Same top-k ranking
+ assert_eq!(
+ labels_normal, labels_precomp,
+ "precomputed table should produce identical ranking"
+ );
+ for i in 0..k {
+ assert!(
+ (dists_normal[i] - dists_precomp[i]).abs() < 1e-2,
+ "distance mismatch at rank {}: normal={}, precomp={}",
+ i,
+ dists_normal[i],
+ dists_precomp[i]
+ );
+ }
+ }
+
+ #[test]
+ fn test_fastscan_invalidated_after_add() {
+ let d = 16;
+ let nlist = 4;
+ let m = 8;
+ let n = 500;
+ let k = 5;
+
+ let data = generate_clustered_data(n * 2, d, 4, 42);
+ let ids_a: Vec<i64> = (0..n as i64).collect();
+ let ids_b: Vec<i64> = (n as i64..2 * n as i64).collect();
+
+ let mut index = IVFPQIndex::with_nbits(d, nlist, m, 4, MetricType::L2,
false);
+ index.train(&data, n);
+ index.add(&data[..n * d], &ids_a, n);
+
+ // Build fastscan, then add more vectors
+ index.build_search_structures();
+ assert!(!index.fastscan_codes.is_empty());
+
+ index.add(&data[n * d..], &ids_b, n);
+ assert!(
+ index.fastscan_codes.is_empty(),
+ "fastscan_codes must be cleared after add()"
+ );
+
+ // Rebuild and search — should find vectors from both batches
+ index.build_search_structures();
+ let mut dists = vec![0.0f32; k];
+ let mut labels = vec![0i64; k];
+ index.search(&data[0..d], 1, k, 4, &mut dists, &mut labels);
+ assert_eq!(labels[0], 0);
+
+ index.search(&data[n * d..(n + 1) * d], 1, k, 4, &mut dists, &mut
labels);
+ assert_eq!(labels[0], n as i64);
+ }
+
+ #[test]
+ fn test_precomputed_table_invalidated_after_add() {
+ let d = 16;
+ let nlist = 4;
+ let m = 4;
+ let n = 500;
+
+ let data = generate_clustered_data(n * 2, d, 4, 42);
+ let ids_a: Vec<i64> = (0..n as i64).collect();
+ let ids_b: Vec<i64> = (n as i64..2 * n as i64).collect();
+
+ let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+ index.train(&data[..n * d], n);
+ index.add(&data[..n * d], &ids_a, n);
+
+ index.build_precomputed_table();
+ assert!(!index.precomputed_table.is_empty());
+
+ index.add(&data[n * d..], &ids_b, n);
+ assert!(
+ index.precomputed_table.is_empty(),
+ "precomputed_table must be cleared after add()"
+ );
+
+ // Rebuild and search — should find vectors from both batches
+ index.build_precomputed_table();
+ let k = 5;
+ let mut dists = vec![0.0f32; k];
+ let mut labels = vec![0i64; k];
+ index.search(&data[0..d], 1, k, 4, &mut dists, &mut labels);
+ assert_eq!(labels[0], 0);
+
+ index.search(&data[n * d..(n + 1) * d], 1, k, 4, &mut dists, &mut
labels);
+ assert_eq!(labels[0], n as i64);
+ }
+}
diff --git a/core/src/lib.rs b/core/src/lib.rs
index 9f03d17..86595a1 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -21,6 +21,7 @@
pub mod blas;
pub mod distance;
pub mod fastscan;
+pub mod ivfpq;
pub mod kmeans;
pub mod opq;
pub mod pq;
diff --git a/core/src/opq.rs b/core/src/opq.rs
index 8c0739b..fc00b7a 100644
--- a/core/src/opq.rs
+++ b/core/src/opq.rs
@@ -105,7 +105,8 @@ impl OPQMatrix {
let mut projected = vec![0.0f32; train_n * d];
let mut reconstructed = vec![0.0f32; train_n * d];
- let mut codes = vec![0u8; train_n * pq.m];
+ let cs = pq.code_size();
+ let mut codes = vec![0u8; train_n * cs];
for iter in 0..self.niter {
// 1. Project: projected = train_data * R^T
@@ -128,7 +129,7 @@ impl OPQMatrix {
pq.encode_batch(&projected, train_n, &mut codes);
for i in 0..train_n {
pq.decode(
- &codes[i * pq.m..(i + 1) * pq.m],
+ &codes[i * cs..(i + 1) * cs],
&mut reconstructed[i * d..(i + 1) * d],
);
}