This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new 17273a8 Add opq, shuffler, and fastscan modules (#3)
17273a8 is described below
commit 17273a81bd15a1799b2dc5d570f2a25469bbc1fb
Author: Jingsong Lee <[email protected]>
AuthorDate: Sun Jun 7 21:20:09 2026 +0800
Add opq, shuffler, and fastscan modules (#3)
- opq: Optimized Product Quantization with rotation matrix (SVD-based)
- shuffler: Disk-based partition shuffler for large-scale index building
- fastscan: Block-layout 4-bit PQ scanning with SIMD acceleration
---
core/Cargo.toml | 1 +
core/src/fastscan.rs | 500 +++++++++++++++++++++++++++++++++++++++++++++++++++
core/src/lib.rs | 3 +
core/src/opq.rs | 256 ++++++++++++++++++++++++++
core/src/shuffler.rs | 276 ++++++++++++++++++++++++++++
5 files changed, 1036 insertions(+)
diff --git a/core/Cargo.toml b/core/Cargo.toml
index 69ba8a7..22ecb07 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -22,6 +22,7 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
+nalgebra = "0.33"
rand = "0.8"
rayon = "1.10"
matrixmultiply = "0.3"
diff --git a/core/src/fastscan.rs b/core/src/fastscan.rs
new file mode 100644
index 0000000..6865590
--- /dev/null
+++ b/core/src/fastscan.rs
@@ -0,0 +1,500 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! FastScan: Faiss-style block layout (bbs=32) + vpshufb 32-way parallel
+//! lookup for 4-bit PQ codes.
+//!
+//! Block layout: 32 vectors per block, codes interleaved by sub-quantizer
+//! pair. Each block stores M/2 groups of 32 bytes (one byte per vector per
+//! sub-quantizer pair).
+//!
+//! Layout: [block0_pair0(32B), block0_pair1(32B), ..., block0_pair(M/2-1)(32B),
+//!          block1_pair0(32B), ...]
+
+/// Block size: 32 vectors per block (matches AVX2 register width).
+///
+/// Each sub-quantizer pair contributes one code byte per vector, so the codes
+/// of one pair for a whole block occupy exactly `BBS` bytes.
+pub const BBS: usize = 32;
+
+/// Rearrange row-major 4-bit PQ codes `[n][cs]` into the blocked layout.
+///
+/// The result has shape `[num_blocks][cs][BBS]`, where `cs = M / 2` is the
+/// number of code bytes per vector. When `n` is not a multiple of `BBS`, the
+/// unused slots of the final block stay zero.
+pub fn pack_codes_block_layout(codes: &[u8], n: usize, cs: usize) -> Vec<u8> {
+    let bytes_per_block = cs * BBS;
+    let mut out = vec![0u8; n.div_ceil(BBS) * bytes_per_block];
+
+    // Walk the source vector by vector; slots that are never written keep
+    // their zero padding.
+    for vec_idx in 0..n {
+        let dst_base = (vec_idx / BBS) * bytes_per_block + vec_idx % BBS;
+        for pair in 0..cs {
+            out[dst_base + pair * BBS] = codes[vec_idx * cs + pair];
+        }
+    }
+
+    out
+}
+
+/// Convert blocked codes `[num_blocks][cs][BBS]` back to row-major `[n][cs]`.
+///
+/// Inverse of `pack_codes_block_layout`; padding slots of the last block are
+/// ignored.
+pub fn unpack_codes_block_layout(packed: &[u8], n: usize, cs: usize) -> Vec<u8> {
+    let bytes_per_block = cs * BBS;
+    let mut out = Vec::with_capacity(n * cs);
+
+    for vec_idx in 0..n {
+        // First byte of this vector inside its block; consecutive pairs are
+        // `BBS` bytes apart.
+        let src_base = (vec_idx / BBS) * bytes_per_block + vec_idx % BBS;
+        out.extend((0..cs).map(|pair| packed[src_base + pair * BBS]));
+    }
+
+    out
+}
+
+/// Quantize an f32 distance table to u8 using one scale for the whole table.
+///
+/// The lower bound is the table minimum; the upper bound is the larger of
+/// `qmax_hint` and the table maximum. Each entry maps linearly onto `0..=255`
+/// and is truncated towards zero.
+///
+/// Returns `(qmin, qmax_used, quantized_table)`.
+pub fn quantize_distance_table(table: &[f32], qmax_hint: f32) -> (f32, f32, Vec<u8>) {
+    let mut lo = f32::INFINITY;
+    let mut hi = f32::MIN;
+    for &d in table {
+        lo = lo.min(d);
+        hi = hi.max(d);
+    }
+    let hi = qmax_hint.max(hi);
+
+    // Guard against a zero-width range (constant table).
+    let scale = 255.0 / (hi - lo).max(1e-10);
+
+    let mut quantized = Vec::with_capacity(table.len());
+    for &d in table {
+        quantized.push(((d - lo) * scale).clamp(0.0, 255.0) as u8);
+    }
+
+    (lo, hi, quantized)
+}
+
+/// FastScan: score block-layout 4-bit PQ codes against a distance table.
+///
+/// * `sim_table` - f32 distance table, 16 entries per sub-quantizer (`[m * 16]`).
+/// * `codes` - block layout `[num_blocks][cs][BBS]` with `cs = m / 2`, as
+///   produced by `pack_codes_block_layout` (last block padded).
+/// * `n` - number of vectors to score.
+/// * `m` - number of sub-quantizers. Only `2 * (m / 2)` of them are used, so
+///   `m` is expected to be even.
+/// * `dists` - output; entries `0..n` are overwritten, entries at or beyond
+///   `n` are left untouched.
+///
+/// The first `min(200, n)` vectors, rounded up to a block boundary, are scored
+/// exactly in f32. The remaining blocks use a u8-quantized table with u16
+/// accumulators (AVX2 / NEON when available, scalar otherwise) and are
+/// dequantized back to f32. When `m` is so large that a u16 accumulator could
+/// overflow, every vector is scored exactly instead.
+///
+/// # Panics
+///
+/// Panics on out-of-range indexing, and before entering a SIMD kernel if
+/// `sim_table` holds fewer than `2 * cs * 16` entries or `codes` is not padded
+/// to `ceil(n / BBS)` whole blocks.
+pub fn fastscan_4bit(sim_table: &[f32], codes: &[u8], n: usize, m: usize, dists: &mut [f32]) {
+    // Number of leading vectors that are always scored exactly.
+    const FLAT_NUM: usize = 200;
+
+    let cs = m / 2;
+    let block_size = cs * BBS; // bytes per block
+    let num_blocks = n.div_ceil(BBS);
+
+    // Exact f32 score of vector `i`, read directly from the block layout.
+    let exact = |i: usize| -> f32 {
+        let base = (i / BBS) * block_size + i % BBS;
+        let mut d = 0.0f32;
+        for pair in 0..cs {
+            let byte = codes[base + pair * BBS];
+            d += sim_table[(pair * 2) * 16 + (byte & 0x0F) as usize];
+            d += sim_table[(pair * 2 + 1) * 16 + (byte >> 4) as usize];
+        }
+        d
+    };
+
+    // Each vector sums 2 * cs quantized entries of at most 255 into a u16.
+    // If that could overflow, stay on the exact path for every vector.
+    let fits_u16 = cs * 2 * 255 <= u16::MAX as usize;
+    let flat_end = if fits_u16 { n.min(FLAT_NUM) } else { n };
+
+    // The quantized scan works on whole blocks, so the exact region is
+    // extended to the next block boundary.
+    let start_block = flat_end.div_ceil(BBS);
+    let exact_end = (start_block * BBS).min(n);
+    for (i, slot) in dists[..exact_end].iter_mut().enumerate() {
+        *slot = exact(i);
+    }
+    if exact_end == n {
+        return;
+    }
+
+    // Quantize the distance table using the table's own range.
+    let (qmin, qmax, qtable) = quantize_distance_table(sim_table, f32::MIN);
+    let range = (qmax - qmin).max(1e-10);
+
+    // Every quantized entry is offset by `qmin`, and exactly 2 * cs entries
+    // are summed per vector.
+    let used_subq = cs * 2;
+    let inv_factor = range / 255.0;
+    let base_dist = qmin * used_subq as f32;
+
+    #[cfg(target_arch = "x86_64")]
+    {
+        if is_x86_feature_detected!("avx2") {
+            assert!(
+                qtable.len() >= cs * 32,
+                "sim_table must hold 16 entries per sub-quantizer"
+            );
+            assert!(
+                codes.len() >= num_blocks * block_size,
+                "codes must be padded to whole blocks"
+            );
+            // Restrict the output to `n` so padding lanes of the last block
+            // are never written into the caller's buffer.
+            let out_len = n.min(dists.len());
+            // SAFETY: AVX2 support was detected at runtime, and the two
+            // assertions above guarantee that every 16-byte table load and
+            // every 32-byte code load performed by the kernel is in bounds.
+            unsafe {
+                fastscan_blocks_avx2(
+                    &qtable,
+                    codes,
+                    cs,
+                    start_block,
+                    num_blocks,
+                    block_size,
+                    qmin,
+                    range,
+                    used_subq,
+                    &mut dists[..out_len],
+                );
+            }
+            return;
+        }
+    }
+
+    #[cfg(target_arch = "aarch64")]
+    {
+        assert!(
+            qtable.len() >= cs * 32,
+            "sim_table must hold 16 entries per sub-quantizer"
+        );
+        assert!(
+            codes.len() >= num_blocks * block_size,
+            "codes must be padded to whole blocks"
+        );
+        let out_len = n.min(dists.len());
+        // SAFETY: the assertions above guarantee that every 16-byte table
+        // load and every 16-byte code load performed by the kernel is in
+        // bounds.
+        unsafe {
+            fastscan_blocks_neon(
+                &qtable,
+                codes,
+                cs,
+                start_block,
+                num_blocks,
+                block_size,
+                qmin,
+                range,
+                used_subq,
+                &mut dists[..out_len],
+            );
+        }
+        return;
+    }
+
+    // Fallback: scalar scan on blocks (used when no SIMD is available).
+    // Unreachable on aarch64, where the NEON branch always returns.
+    #[allow(unreachable_code)]
+    for block in start_block..num_blocks {
+        let base_vec = block * BBS;
+        let vecs_in_block = BBS.min(n - base_vec);
+
+        let mut q_dists = [0u16; BBS];
+
+        for pair in 0..cs {
+            let qtab_lo = &qtable[(pair * 2) * 16..(pair * 2 + 1) * 16];
+            let qtab_hi = &qtable[(pair * 2 + 1) * 16..(pair * 2 + 2) * 16];
+            let block_codes = &codes[block * block_size + pair * BBS..];
+
+            for v in 0..vecs_in_block {
+                let byte = block_codes[v];
+                let lo = (byte & 0x0F) as usize;
+                let hi = (byte >> 4) as usize;
+                q_dists[v] += qtab_lo[lo] as u16 + qtab_hi[hi] as u16;
+            }
+        }
+
+        // Dequantize
+        for v in 0..vecs_in_block {
+            dists[base_vec + v] = q_dists[v] as f32 * inv_factor + base_dist;
+        }
+    }
+}
+
+/// AVX2 block scan using vpshufb for 32-way parallel 4-bit lookup.
+///
+/// For each block in `start_block..num_blocks`, sums the quantized table
+/// entries selected by every vector's two nibbles into u16 accumulators, then
+/// writes `sum * (range / 255) + qmin * m` to `dists[block * BBS + v]`.
+/// Output slots at or beyond `dists.len()` are skipped.
+///
+/// # Safety
+///
+/// The CPU must support AVX2. Loads go through raw pointers without bounds
+/// checks: `qtable` must hold at least `cs * 32` bytes, and `codes` must hold
+/// 32 readable bytes at `block * block_size + pair * BBS` for every scanned
+/// `block` and every `pair < cs`.
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn fastscan_blocks_avx2(
+ qtable: &[u8],
+ codes: &[u8],
+ cs: usize,
+ start_block: usize,
+ num_blocks: usize,
+ block_size: usize,
+ qmin: f32,
+ range: f32,
+ m: usize,
+ dists: &mut [f32],
+) {
+ use std::arch::x86_64::*;
+
+ // Dequantization: value = q * inv_factor + base_dist.
+ let inv_factor = range / 255.0;
+ let base_dist = qmin * m as f32;
+
+ for block in start_block..num_blocks {
+ let base_vec = block * BBS;
+
+ // u16 accumulators: 2 × __m256i = 32 × u16 values
+ let mut accu_lo = _mm256_setzero_si256(); // v0..v7 | v16..v23 (see reassembly below)
+ let mut accu_hi = _mm256_setzero_si256(); // v8..v15 | v24..v31
+
+ for pair in 0..cs {
+ // Load 32-byte quantized LUT for this pair (lo + hi
sub-quantizers)
+ // LUT layout: [16 bytes for sub_lo, 16 bytes for sub_hi]
+ let lut_lo_ptr = qtable.as_ptr().add((pair * 2) * 16);
+ let lut_hi_ptr = qtable.as_ptr().add((pair * 2 + 1) * 16);
+
+ // Broadcast 16-byte LUTs into 256-bit registers (same 16-byte
table in both lanes)
+ let lut_lo =
_mm256_broadcastsi128_si256(_mm_loadu_si128(lut_lo_ptr as *const __m128i));
+ let lut_hi =
_mm256_broadcastsi128_si256(_mm_loadu_si128(lut_hi_ptr as *const __m128i));
+
+ // Load 32 code bytes for this pair in this block
+ let code_ptr = codes.as_ptr().add(block * block_size + pair * BBS);
+ let code_vec = _mm256_loadu_si256(code_ptr as *const __m256i);
+
+ // Split nibbles
+ // The shift is 16-bit wide, so the mask also clears bits that
+ // crossed over from the neighbouring byte.
+ let mask = _mm256_set1_epi8(0x0F);
+ let lo_nibbles = _mm256_and_si256(code_vec, mask);
+ let hi_nibbles = _mm256_and_si256(_mm256_srli_epi16(code_vec, 4),
mask);
+
+ // vpshufb: 32-way parallel lookup
+ let dist_lo = _mm256_shuffle_epi8(lut_lo, lo_nibbles);
+ let dist_hi = _mm256_shuffle_epi8(lut_hi, hi_nibbles);
+
+ // Widen each to u16 separately then add (avoids u8 saturation
overflow)
+ let zero = _mm256_setzero_si256();
+ let dlo_lo = _mm256_unpacklo_epi8(dist_lo, zero);
+ let dlo_hi = _mm256_unpackhi_epi8(dist_lo, zero);
+ let dhi_lo = _mm256_unpacklo_epi8(dist_hi, zero);
+ let dhi_hi = _mm256_unpackhi_epi8(dist_hi, zero);
+
+ accu_lo = _mm256_add_epi16(accu_lo, _mm256_add_epi16(dlo_lo,
dhi_lo));
+ accu_hi = _mm256_add_epi16(accu_hi, _mm256_add_epi16(dlo_hi,
dhi_hi));
+ }
+
+ // Extract u16 values and dequantize to f32.
+ // _mm256_unpacklo/hi_epi8 operates per 128-bit lane, so:
+ // accu_lo = [v0..v7 | v16..v23], accu_hi = [v8..v15 | v24..v31]
+ // Reassemble correct order with cross-lane permute.
+ let result_lo = _mm256_permute2x128_si256(accu_lo, accu_hi, 0x20);
+ let result_hi = _mm256_permute2x128_si256(accu_lo, accu_hi, 0x31);
+ let mut q_vals = [0u16; BBS];
+ _mm256_storeu_si256(q_vals.as_mut_ptr() as *mut __m256i, result_lo);
+ _mm256_storeu_si256(q_vals.as_mut_ptr().add(16) as *mut __m256i,
result_hi);
+
+ // The bounds check skips lanes that fall beyond the output slice.
+ for v in 0..BBS {
+ let idx = base_vec + v;
+ if idx < dists.len() {
+ dists[idx] = q_vals[v] as f32 * inv_factor + base_dist;
+ }
+ }
+ }
+}
+
+/// ARM NEON block scan (16-way per instruction, 2 passes per block).
+///
+/// For each block in `start_block..num_blocks`, sums the quantized table
+/// entries selected by every vector's two nibbles into u16 accumulators, then
+/// writes `sum * (range / 255) + qmin * m` to `dists[block * BBS + v]`.
+/// Output slots at or beyond `dists.len()` are skipped.
+///
+/// # Safety
+///
+/// Loads go through raw pointers without bounds checks: `qtable` must hold at
+/// least `cs * 32` bytes, and `codes` must hold 32 readable bytes at
+/// `block * block_size + pair * BBS` for every scanned `block` and every
+/// `pair < cs`.
+#[cfg(target_arch = "aarch64")]
+#[target_feature(enable = "neon")]
+unsafe fn fastscan_blocks_neon(
+ qtable: &[u8],
+ codes: &[u8],
+ cs: usize,
+ start_block: usize,
+ num_blocks: usize,
+ block_size: usize,
+ qmin: f32,
+ range: f32,
+ m: usize,
+ dists: &mut [f32],
+) {
+ use std::arch::aarch64::*;
+
+ // Dequantization: value = q * inv_factor + base_dist.
+ let inv_factor = range / 255.0;
+ let base_dist = qmin * m as f32;
+
+ for block in start_block..num_blocks {
+ let base_vec = block * BBS;
+
+ // u16 accumulators: 4 × uint16x8_t = 32 × u16
+ // accu[0] = v0..v7, accu[1] = v8..v15, accu[2] = v16..v23, accu[3] = v24..v31
+ let mut accu = [vdupq_n_u16(0); 4];
+
+ for pair in 0..cs {
+ let lut_lo = vld1q_u8(qtable.as_ptr().add((pair * 2) * 16));
+ let lut_hi = vld1q_u8(qtable.as_ptr().add((pair * 2 + 1) * 16));
+
+ let code_ptr = codes.as_ptr().add(block * block_size + pair * BBS);
+
+ // Process 16 vectors at a time (NEON = 128-bit)
+ for half in 0..2 {
+ let code_vec = vld1q_u8(code_ptr.add(half * 16));
+
+ // The per-byte shift already leaves only the high nibble, so
+ // only the low nibble needs masking.
+ let mask = vdupq_n_u8(0x0F);
+ let lo_nib = vandq_u8(code_vec, mask);
+ let hi_nib = vshrq_n_u8(code_vec, 4);
+
+ // tbl: 16-way lookup
+ let dist_lo = vqtbl1q_u8(lut_lo, lo_nib);
+ let dist_hi = vqtbl1q_u8(lut_hi, hi_nib);
+
+ // Widen each to u16 separately then add (avoids u8 saturation
overflow)
+ accu[half * 2] = vaddq_u16(accu[half * 2],
vmovl_u8(vget_low_u8(dist_lo)));
+ accu[half * 2] = vaddq_u16(accu[half * 2],
vmovl_u8(vget_low_u8(dist_hi)));
+ accu[half * 2 + 1] = vaddq_u16(accu[half * 2 + 1],
vmovl_u8(vget_high_u8(dist_lo)));
+ accu[half * 2 + 1] = vaddq_u16(accu[half * 2 + 1],
vmovl_u8(vget_high_u8(dist_hi)));
+ }
+ }
+
+ // Extract and dequantize
+ let mut q_vals = [0u16; BBS];
+ for i in 0..4 {
+ vst1q_u16(q_vals.as_mut_ptr().add(i * 8), accu[i]);
+ }
+
+ // The bounds check skips lanes that fall beyond the output slice.
+ for v in 0..BBS {
+ let idx = base_vec + v;
+ if idx < dists.len() {
+ dists[idx] = q_vals[v] as f32 * inv_factor + base_dist;
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Exact f32 distances computed straight from row-major codes.
+    fn scalar_reference(sim_table: &[f32], codes_row: &[u8], n: usize, cs: usize) -> Vec<f32> {
+        let mut expected = vec![0.0f32; n];
+        for (i, slot) in expected.iter_mut().enumerate() {
+            for (pair, &byte) in codes_row[i * cs..(i + 1) * cs].iter().enumerate() {
+                *slot += sim_table[(pair * 2) * 16 + (byte & 0x0F) as usize];
+                *slot += sim_table[(pair * 2 + 1) * 16 + (byte >> 4) as usize];
+            }
+        }
+        expected
+    }
+
+    /// Packing followed by unpacking must reproduce the row-major codes.
+    #[test]
+    fn test_pack_unpack_roundtrip() {
+        let (n, cs) = (100, 4); // cs = m/2 = 4, i.e. m=8
+        let original: Vec<u8> = (0..n * cs).map(|i| (i % 256) as u8).collect();
+
+        let blocked = pack_codes_block_layout(&original, n, cs);
+        let restored = unpack_codes_block_layout(&blocked, n, cs);
+
+        assert_eq!(original, restored);
+    }
+
+    /// With n <= 200 every vector takes the exact f32 path.
+    #[test]
+    fn test_fastscan_correctness() {
+        let m = 8;
+        let cs = m / 2;
+        let n = 100;
+
+        // Deterministic pseudo-random codes and distance table [M * 16].
+        let codes_row: Vec<u8> = (0..n * cs).map(|i| ((i * 7 + 3) % 256) as u8).collect();
+        let sim_table: Vec<f32> = (0..m * 16).map(|i| (i as f32) * 0.1 + 0.5).collect();
+
+        let expected = scalar_reference(&sim_table, &codes_row, n, cs);
+
+        // Pack into block layout and scan.
+        let packed = pack_codes_block_layout(&codes_row, n, cs);
+        let mut result = vec![0.0f32; n];
+        fastscan_4bit(&sim_table, &packed, n, m, &mut result);
+
+        // Check first 200 (f32 exact) — should match perfectly.
+        for (i, (&got, &want)) in result.iter().zip(&expected).enumerate().take(n.min(200)) {
+            assert!(
+                (got - want).abs() < 1e-5,
+                "Mismatch at {}: {} vs {}",
+                i,
+                got,
+                want
+            );
+        }
+    }
+
+    /// n > 200 exercises the quantized path for the tail of the data.
+    #[test]
+    fn test_fastscan_large() {
+        let m = 16;
+        let cs = m / 2;
+        let n = 1000;
+
+        let codes_row: Vec<u8> = (0..n * cs).map(|i| ((i * 13 + 7) % 256) as u8).collect();
+        let sim_table: Vec<f32> = (0..m * 16).map(|i| (i as f32) * 0.05 + 1.0).collect();
+
+        let expected = scalar_reference(&sim_table, &codes_row, n, cs);
+
+        let packed = pack_codes_block_layout(&codes_row, n, cs);
+        let mut result = vec![0.0f32; n];
+        fastscan_4bit(&sim_table, &packed, n, m, &mut result);
+
+        // First 200 are computed with f32 exact — should match perfectly.
+        for i in 0..200 {
+            let (got, want) = (result[i], expected[i]);
+            assert!(
+                (got - want).abs() < 1e-5,
+                "Exact mismatch at {}: got {}, expected {}",
+                i,
+                got,
+                want
+            );
+        }
+
+        // Beyond 200: quantized path — allow 2% relative tolerance for the
+        // u8 quantization.
+        let max_expected = expected.iter().cloned().fold(f32::MIN, f32::max);
+        let tolerance = max_expected * 0.02;
+        for i in 200..n {
+            let diff = (result[i] - expected[i]).abs();
+            assert!(
+                diff <= tolerance,
+                "SIMD mismatch at {}: got {}, expected {}, diff {}",
+                i,
+                result[i],
+                expected[i],
+                diff
+            );
+        }
+    }
+}
diff --git a/core/src/lib.rs b/core/src/lib.rs
index a2df43f..9f03d17 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -20,5 +20,8 @@
pub mod blas;
pub mod distance;
+pub mod fastscan;
pub mod kmeans;
+pub mod opq;
pub mod pq;
+pub mod shuffler;
diff --git a/core/src/opq.rs b/core/src/opq.rs
new file mode 100644
index 0000000..8c0739b
--- /dev/null
+++ b/core/src/opq.rs
@@ -0,0 +1,256 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::kmeans::KMeansConfig;
+use crate::pq::ProductQuantizer;
+use nalgebra::{DMatrix, SVD};
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
+
+/// OPQ (Optimized Product Quantization) rotation matrix.
+/// Aligned with Faiss's OPQMatrix from VectorTransform.cpp.
+///
+/// Learns an orthogonal rotation R that minimizes PQ reconstruction error
+/// via alternating Procrustes optimization.
+pub struct OPQMatrix {
+ /// Vector dimension; `rotation` holds `d * d` entries.
+ pub d: usize,
+ /// Number of PQ sub-quantizers. Stored here, but the visible training code
+ /// reads the count from the `ProductQuantizer` it is given (`pq.m`).
+ pub m: usize,
+ /// Number of outer (rotation update) iterations run by `train`.
+ pub niter: usize,
+ /// K-means iterations per PQ training on every outer iteration after the first.
+ pub niter_pq: usize,
+ /// K-means iterations for the PQ training of the first outer iteration.
+ pub niter_pq_0: usize,
+ /// Upper bound on training vectors; larger inputs are randomly subsampled.
+ pub max_train_points: usize,
+ /// Rotation matrix [d * d], row-major. y = R * x.
+ pub rotation: Vec<f32>,
+ /// Set to `true` at the end of `train`.
+ pub is_trained: bool,
+}
+
+impl OPQMatrix {
+    /// Build an untrained transform for `d`-dimensional vectors and `m`
+    /// sub-quantizers, with default iteration counts and an all-zero
+    /// rotation buffer.
+    pub fn new(d: usize, m: usize) -> Self {
+        let rotation = vec![0.0f32; d * d];
+        Self {
+            d,
+            m,
+            niter: 50,
+            niter_pq: 4,
+            niter_pq_0: 40,
+            max_train_points: 65536,
+            rotation,
+            is_trained: false,
+        }
+    }
+
+    /// Train the OPQ rotation matrix together with the given product quantizer.
+    ///
+    /// `data` is a flat, row-major `[n * d]` buffer. At most
+    /// `max_train_points` vectors are used (randomly subsampled with a fixed
+    /// seed). The training copy is mean-centred; the mean is not stored.
+    ///
+    /// Each of the `niter` outer iterations rotates the training data, trains
+    /// `pq` on it (hot-started after the first iteration), reconstructs the
+    /// data through `pq`, and re-solves the orthogonal Procrustes problem for
+    /// the rotation. Afterwards `pq` is trained once more on the finally
+    /// rotated data and `is_trained` is set.
+    pub fn train(&mut self, data: &[f32], n: usize, pq: &mut ProductQuantizer) {
+        let d = self.d;
+        let mut rng = StdRng::seed_from_u64(12345);
+
+        // Subsample if needed (partial Fisher-Yates shuffle of the indices).
+        let train_n = n.min(self.max_train_points);
+        let mut train_data = if n > self.max_train_points {
+            let mut sub = vec![0.0f32; train_n * d];
+            let mut indices: Vec<usize> = (0..n).collect();
+            for i in 0..train_n {
+                let j = rng.gen_range(i..n);
+                indices.swap(i, j);
+            }
+            for (out_i, &src_i) in indices[..train_n].iter().enumerate() {
+                sub[out_i * d..(out_i + 1) * d].copy_from_slice(&data[src_i * d..(src_i + 1) * d]);
+            }
+            sub
+        } else {
+            data[..n * d].to_vec()
+        };
+
+        // Center data (subtract mean) — aligned with Faiss OPQMatrix
+        let mut mean = vec![0.0f32; d];
+        for i in 0..train_n {
+            for j in 0..d {
+                mean[j] += train_data[i * d + j];
+            }
+        }
+        let inv_n = 1.0 / train_n as f32;
+        for j in 0..d {
+            mean[j] *= inv_n;
+        }
+        for i in 0..train_n {
+            for j in 0..d {
+                train_data[i * d + j] -= mean[j];
+            }
+        }
+
+        // Initialize with random orthogonal matrix via QR decomposition
+        let random_mat: Vec<f32> = (0..d * d).map(|_| rng.gen::<f32>() - 0.5).collect();
+        let mat = DMatrix::from_row_slice(d, d, &random_mat);
+        let qr = mat.qr();
+        let q = qr.q();
+        for i in 0..d {
+            for j in 0..d {
+                self.rotation[i * d + j] = q[(i, j)];
+            }
+        }
+
+        let mut projected = vec![0.0f32; train_n * d];
+        let mut reconstructed = vec![0.0f32; train_n * d];
+        let mut codes = vec![0u8; train_n * pq.m];
+
+        for iter in 0..self.niter {
+            // 1. Project: projected = train_data * R^T
+            self.apply_batch(&train_data, &mut projected, train_n);
+
+            // 2. Train PQ on projected data (hot-start on iter >= 1)
+            let pq_niter = if iter == 0 {
+                self.niter_pq_0
+            } else {
+                self.niter_pq
+            };
+            let km_config = KMeansConfig {
+                niter: pq_niter,
+                ..KMeansConfig::default()
+            };
+            let hot_start = iter > 0;
+            pq.train_hot_start(&projected, train_n, &km_config, hot_start);
+
+            // 3. Encode and decode to get reconstructions
+            pq.encode_batch(&projected, train_n, &mut codes);
+            for i in 0..train_n {
+                pq.decode(
+                    &codes[i * pq.m..(i + 1) * pq.m],
+                    &mut reconstructed[i * d..(i + 1) * d],
+                );
+            }
+
+            // 4. Orthogonal Procrustes. With X = centred inputs (rows) and
+            //    Y = reconstructions in rotated space (rows), find the
+            //    orthogonal W minimising ||X * W - Y||_F. For
+            //    X^T * Y = U * S * V^T the minimiser is W = U * V^T.
+            let x_mat = DMatrix::from_row_slice(train_n, d, &train_data);
+            let y_mat = DMatrix::from_row_slice(train_n, d, &reconstructed);
+            let cross_cov = x_mat.transpose() * &y_mat; // [d x d]
+
+            let svd = SVD::new(cross_cov, true, true);
+            if let (Some(u), Some(v_t)) = (svd.u, svd.v_t) {
+                // `apply` computes y = R * x, i.e. projected = X * R^T, so
+                // W is R^T and the rotation to store is R = W^T = V * U^T.
+                let w = &u * &v_t;
+                for i in 0..d {
+                    for j in 0..d {
+                        self.rotation[i * d + j] = w[(j, i)];
+                    }
+                }
+            }
+        }
+
+        // Final PQ training with the learned rotation
+        self.apply_batch(&train_data, &mut projected, train_n);
+        pq.train_with_config(&projected, train_n, &KMeansConfig::default());
+
+        self.is_trained = true;
+    }
+
+    /// Rotate one vector: `y = R * x`, with `R` stored row-major.
+    pub fn apply(&self, x: &[f32], y: &mut [f32]) {
+        let d = self.d;
+        for (i, out) in y[..d].iter_mut().enumerate() {
+            // Dot product of row `i` of the rotation with `x`.
+            let row = &self.rotation[i * d..(i + 1) * d];
+            *out = row
+                .iter()
+                .enumerate()
+                .fold(0.0f32, |acc, (j, &r)| acc + r * x[j]);
+        }
+    }
+
+    /// Rotate `n` consecutive vectors of `data` into `out`, both flat
+    /// `[n * d]` buffers.
+    pub fn apply_batch(&self, data: &[f32], out: &mut [f32], n: usize) {
+        let d = self.d;
+        let mut start = 0;
+        for _ in 0..n {
+            let end = start + d;
+            self.apply(&data[start..end], &mut out[start..end]);
+            start = end;
+        }
+    }
+
+    /// Undo the rotation: `x = R^T * y`.
+    pub fn apply_reverse(&self, y: &[f32], x: &mut [f32]) {
+        let d = self.d;
+        for (i, out) in x[..d].iter_mut().enumerate() {
+            // Dot product of column `i` of the rotation with `y`.
+            *out = (0..d).fold(0.0f32, |acc, j| acc + self.rotation[j * d + i] * y[j]);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// After training, the learned rotation must still be orthogonal.
+    #[test]
+    fn test_rotation_orthogonality() {
+        let (d, m, n) = (8, 2, 500);
+
+        let mut rng = StdRng::seed_from_u64(42);
+        let data: Vec<f32> = (0..n * d).map(|_| rng.gen::<f32>()).collect();
+
+        let mut opq = OPQMatrix::new(d, m);
+        opq.niter = 5; // Reduce for test speed
+        let mut pq = ProductQuantizer::new(d, m);
+        opq.train(&data, n, &mut pq);
+
+        assert!(opq.is_trained);
+
+        // Test that R * R^T ≈ I
+        let row = |r: usize| &opq.rotation[r * d..(r + 1) * d];
+        for i in 0..d {
+            for j in 0..d {
+                let dot = row(i)
+                    .iter()
+                    .zip(row(j))
+                    .fold(0.0f32, |acc, (a, b)| acc + a * b);
+                let expected = if i == j { 1.0 } else { 0.0 };
+                assert!(
+                    (dot - expected).abs() < 1e-4,
+                    "R*R^T[{},{}] = {}, expected {}",
+                    i,
+                    j,
+                    dot,
+                    expected
+                );
+            }
+        }
+    }
+
+    /// `apply_reverse` must invert `apply` for an orthogonal rotation.
+    #[test]
+    fn test_apply_reverse() {
+        let d = 4;
+        let mut opq = OPQMatrix::new(d, 2);
+        // Set rotation to a simple permutation matrix
+        #[rustfmt::skip]
+        let permutation = vec![
+            0.0, 1.0, 0.0, 0.0,
+            1.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 1.0,
+            0.0, 0.0, 1.0, 0.0,
+        ];
+        opq.rotation = permutation;
+
+        let x = [1.0, 2.0, 3.0, 4.0];
+        let mut y = [0.0f32; 4];
+        opq.apply(&x, &mut y);
+
+        let mut x_back = [0.0f32; 4];
+        opq.apply_reverse(&y, &mut x_back);
+
+        for (orig, back) in x.iter().zip(&x_back) {
+            assert!((orig - back).abs() < 1e-6);
+        }
+    }
+}
diff --git a/core/src/shuffler.rs b/core/src/shuffler.rs
new file mode 100644
index 0000000..ac96d35
--- /dev/null
+++ b/core/src/shuffler.rs
@@ -0,0 +1,276 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Disk-based shuffler for large-scale IVF-PQ index building.
+//! Inspired by Lance's shuffler: write vectors sequentially with partition
+//! IDs, then read back grouped by partition for PQ encoding.
+
+use std::fs::File;
+use std::io::{self, BufReader, BufWriter, Read, Write};
+use std::path::PathBuf;
+use std::sync::atomic::{AtomicU64, Ordering};
+
+/// Result of reading every partition at once:
+/// `(row_ids_per_partition, flat_vectors_per_partition)`, both indexed by
+/// partition id.
+type PartitionData = (Vec<Vec<i64>>, Vec<Vec<f32>>);
+
+/// Record format: [partition_id: u32][row_id: i64][vector: f32 * dim]
+/// All fields are written little-endian.
+const RECORD_OVERHEAD: usize = 4 + 8; // partition_id + row_id
+
+/// Disk-based shuffler that accumulates vectors with partition assignments,
+/// then reads them back grouped by partition.
+pub struct DiskShuffler {
+ /// Location of the backing temp file; removed when the shuffler is dropped.
+ path: PathBuf,
+ /// Buffered writer; `Some` while writing, taken by `finish_write`.
+ writer: Option<BufWriter<File>>,
+ /// Number of f32 components per vector.
+ dim: usize,
+ /// Size of one on-disk record in bytes: `RECORD_OVERHEAD + dim * 4`.
+ record_size: usize,
+ /// Total number of records written so far.
+ count: usize,
+ /// Number of records written per partition; length is `nlist`.
+ partition_counts: Vec<usize>,
+}
+
+impl DiskShuffler {
+    /// Create a shuffler for `dim`-dimensional vectors and `nlist`
+    /// partitions, backed by a freshly created file in the system temp
+    /// directory. The file name combines the process id with a per-process
+    /// counter.
+    pub fn new(dim: usize, nlist: usize) -> io::Result<Self> {
+        static COUNTER: AtomicU64 = AtomicU64::new(0);
+        let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
+
+        let file_name = format!("ivfpq-shuffle-{}-{}.bin", std::process::id(), seq);
+        let path = std::env::temp_dir().join(file_name);
+        let writer = File::create(&path).map(|f| BufWriter::with_capacity(8 * 1024 * 1024, f))?;
+
+        Ok(Self {
+            path,
+            writer: Some(writer),
+            dim,
+            record_size: RECORD_OVERHEAD + dim * 4,
+            count: 0,
+            partition_counts: vec![0; nlist],
+        })
+    }
+
+    /// Append one record: a vector with its partition assignment and row ID.
+    ///
+    /// # Errors
+    ///
+    /// Returns `InvalidInput` if `vector.len()` differs from the configured
+    /// dimension, if `partition_id` is not below `nlist`, or if the shuffler
+    /// has already been finalized with `finish_write`. I/O errors from the
+    /// underlying writer are propagated.
+    pub fn write_vector(
+        &mut self,
+        partition_id: u32,
+        row_id: i64,
+        vector: &[f32],
+    ) -> io::Result<()> {
+        if vector.len() != self.dim {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!(
+                    "vector length {} does not match expected dim {}",
+                    vector.len(),
+                    self.dim
+                ),
+            ));
+        }
+        if partition_id as usize >= self.partition_counts.len() {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!(
+                    "partition_id {} out of range (nlist={})",
+                    partition_id,
+                    self.partition_counts.len()
+                ),
+            ));
+        }
+        // The writer is gone once `finish_write` has run; report that as an
+        // error instead of panicking.
+        let Some(writer) = self.writer.as_mut() else {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "write_vector called after finish_write",
+            ));
+        };
+        writer.write_all(&partition_id.to_le_bytes())?;
+        writer.write_all(&row_id.to_le_bytes())?;
+        for &v in vector {
+            writer.write_all(&v.to_le_bytes())?;
+        }
+        self.partition_counts[partition_id as usize] += 1;
+        self.count += 1;
+        Ok(())
+    }
+
+    /// Finalize writing: flush all buffered records to the file and close the
+    /// writer. Calling it again is a no-op.
+    ///
+    /// # Errors
+    ///
+    /// Returns the I/O error if flushing the buffered data fails.
+    pub fn finish_write(&mut self) -> io::Result<()> {
+        if let Some(mut w) = self.writer.take() {
+            // Flush explicitly: dropping a `BufWriter` also flushes, but
+            // silently discards any error.
+            w.flush()?;
+        }
+        Ok(())
+    }
+
+    /// Read all vectors for a specific partition.
+    ///
+    /// Returns `(row_ids, vectors)` where `vectors` is flat `[count * dim]`,
+    /// in the order the records were written.
+    ///
+    /// # Errors
+    ///
+    /// Returns `InvalidInput` if `partition_id` is not below `nlist`, or if
+    /// the partition is non-empty and `finish_write` has not been called yet
+    /// (buffered records may not be on disk). I/O errors are propagated.
+    pub fn read_partition(&self, partition_id: u32) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        let Some(&count) = self.partition_counts.get(partition_id as usize) else {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!(
+                    "partition_id {} out of range (nlist={})",
+                    partition_id,
+                    self.partition_counts.len()
+                ),
+            ));
+        };
+        if count == 0 {
+            return Ok((Vec::new(), Vec::new()));
+        }
+        if self.writer.is_some() {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                "finish_write must be called before reading",
+            ));
+        }
+
+        let mut ids = Vec::with_capacity(count);
+        let mut vectors = Vec::with_capacity(count * self.dim);
+
+        let file = File::open(&self.path)?;
+        let mut reader = BufReader::with_capacity(8 * 1024 * 1024, file);
+        let mut record_buf = vec![0u8; self.record_size];
+
+        for _ in 0..self.count {
+            // Every record of this partition has been found; no need to scan
+            // the rest of the file.
+            if ids.len() == count {
+                break;
+            }
+            reader.read_exact(&mut record_buf)?;
+            let pid = u32::from_le_bytes([
+                record_buf[0],
+                record_buf[1],
+                record_buf[2],
+                record_buf[3],
+            ]);
+            if pid != partition_id {
+                continue;
+            }
+            let mut row_id_bytes = [0u8; 8];
+            row_id_bytes.copy_from_slice(&record_buf[4..RECORD_OVERHEAD]);
+            ids.push(i64::from_le_bytes(row_id_bytes));
+            vectors.extend(
+                record_buf[RECORD_OVERHEAD..]
+                    .chunks_exact(4)
+                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])),
+            );
+        }
+
+        Ok((ids, vectors))
+    }
+
+ /// Read all partitions at once (for moderate datasets that fit in memory
after PQ encoding).
+ /// Returns (ids_per_list, vectors_per_list).
+ pub fn read_all_partitions(&self) -> io::Result<PartitionData> {
+ let nlist = self.partition_counts.len();
+ let mut all_ids: Vec<Vec<i64>> = vec![Vec::new(); nlist];
+ let mut all_vectors: Vec<Vec<f32>> = vec![Vec::new(); nlist];
+
+ // Pre-allocate
+ for p in 0..nlist {
+ all_ids[p].reserve(self.partition_counts[p]);
+ all_vectors[p].reserve(self.partition_counts[p] * self.dim);
+ }
+
+ let file = File::open(&self.path)?;
+ let mut reader = BufReader::with_capacity(8 * 1024 * 1024, file);
+ let mut record_buf = vec![0u8; self.record_size];
+
+ for _ in 0..self.count {
+ reader.read_exact(&mut record_buf)?;
+ let pid =
+ u32::from_le_bytes([record_buf[0], record_buf[1],
record_buf[2], record_buf[3]])
+ as usize;
+ let row_id = i64::from_le_bytes([
+ record_buf[4],
+ record_buf[5],
+ record_buf[6],
+ record_buf[7],
+ record_buf[8],
+ record_buf[9],
+ record_buf[10],
+ record_buf[11],
+ ]);
+ all_ids[pid].push(row_id);
+ for i in 0..self.dim {
+ let off = RECORD_OVERHEAD + i * 4;
+ let v = f32::from_le_bytes([
+ record_buf[off],
+ record_buf[off + 1],
+ record_buf[off + 2],
+ record_buf[off + 3],
+ ]);
+ all_vectors[pid].push(v);
+ }
+ }
+
+ Ok((all_ids, all_vectors))
+ }
+
+ /// Total number of records written so far, across all partitions.
+ pub fn total_count(&self) -> usize {
+ self.count
+ }
+
+ /// Number of records written per partition, indexed by partition id.
+ pub fn partition_counts(&self) -> &[usize] {
+ &self.partition_counts
+ }
+}
+
+/// Removes the backing temp file when the shuffler goes away.
+impl Drop for DiskShuffler {
+    fn drop(&mut self) {
+        // Close the file handle before unlinking. `into_parts` returns the
+        // file without flushing, so data still sitting in the buffer is
+        // discarded rather than written to a file that is about to be
+        // deleted.
+        if let Some(writer) = self.writer.take() {
+            let (file, _unwritten) = writer.into_parts();
+            drop(file);
+        }
+        // Best-effort cleanup: a failure to delete is deliberately ignored.
+        let _ = std::fs::remove_file(&self.path);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// A vector whose length differs from `dim` is rejected.
+    #[test]
+    fn test_write_vector_validates_dim() {
+        let mut shuffler = DiskShuffler::new(4, 2).unwrap();
+        let result = shuffler.write_vector(0, 1, &[1.0, 2.0]);
+        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidInput);
+    }
+
+    /// A partition id at or beyond `nlist` is rejected.
+    #[test]
+    fn test_write_vector_validates_partition_id() {
+        let mut shuffler = DiskShuffler::new(4, 2).unwrap();
+        let result = shuffler.write_vector(5, 1, &[1.0, 2.0, 3.0, 4.0]);
+        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidInput);
+    }
+
+    /// Records written in interleaved order come back grouped by partition.
+    #[test]
+    fn test_shuffler_roundtrip() {
+        let dim = 4;
+        let nlist = 3;
+        let mut shuffler = DiskShuffler::new(dim, nlist).unwrap();
+
+        // (partition, row id, vector) in write order
+        let records: [(u32, i64, [f32; 4]); 4] = [
+            (0, 100, [1.0, 2.0, 3.0, 4.0]),
+            (1, 200, [5.0, 6.0, 7.0, 8.0]),
+            (0, 300, [9.0, 10.0, 11.0, 12.0]),
+            (2, 400, [13.0, 14.0, 15.0, 16.0]),
+        ];
+        for (pid, row_id, vector) in &records {
+            shuffler.write_vector(*pid, *row_id, vector).unwrap();
+        }
+        shuffler.finish_write().unwrap();
+
+        assert_eq!(shuffler.partition_counts(), &[2, 1, 1]);
+
+        // Read partition 0
+        let (ids, vecs) = shuffler.read_partition(0).unwrap();
+        assert_eq!(ids, vec![100, 300]);
+        assert_eq!(vecs.len(), 2 * dim);
+        assert_eq!(&vecs[0..4], &[1.0, 2.0, 3.0, 4.0]);
+        assert_eq!(&vecs[4..8], &[9.0, 10.0, 11.0, 12.0]);
+
+        // Read all
+        let (all_ids, all_vecs) = shuffler.read_all_partitions().unwrap();
+        assert_eq!(all_ids[0], vec![100, 300]);
+        assert_eq!(all_ids[1], vec![200]);
+        assert_eq!(all_ids[2], vec![400]);
+        assert_eq!(&all_vecs[1][..], &[5.0, 6.0, 7.0, 8.0]);
+    }
+}