This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git


The following commit(s) were added to refs/heads/main by this push:
     new 17273a8  Add opq, shuffler, and fastscan modules (#3)
17273a8 is described below

commit 17273a81bd15a1799b2dc5d570f2a25469bbc1fb
Author: Jingsong Lee <[email protected]>
AuthorDate: Sun Jun 7 21:20:09 2026 +0800

    Add opq, shuffler, and fastscan modules (#3)
    
    - opq: Optimized Product Quantization with rotation matrix (SVD-based)
    - shuffler: Disk-based partition shuffler for large-scale index building
    - fastscan: Block-layout 4-bit PQ scanning with SIMD acceleration
---
 core/Cargo.toml      |   1 +
 core/src/fastscan.rs | 500 +++++++++++++++++++++++++++++++++++++++++++++++++++
 core/src/lib.rs      |   3 +
 core/src/opq.rs      | 256 ++++++++++++++++++++++++++
 core/src/shuffler.rs | 276 ++++++++++++++++++++++++++++
 5 files changed, 1036 insertions(+)

diff --git a/core/Cargo.toml b/core/Cargo.toml
index 69ba8a7..22ecb07 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -22,6 +22,7 @@ edition = "2021"
 license = "Apache-2.0"
 
 [dependencies]
+nalgebra = "0.33"
 rand = "0.8"
 rayon = "1.10"
 matrixmultiply = "0.3"
diff --git a/core/src/fastscan.rs b/core/src/fastscan.rs
new file mode 100644
index 0000000..6865590
--- /dev/null
+++ b/core/src/fastscan.rs
@@ -0,0 +1,500 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! FastScan: Faiss-style block layout (bbs=32) with vpshufb 32-way parallel
+//! lookup for 4-bit PQ codes.
+//!
+//! Block layout: 32 vectors per block, codes interleaved by sub-quantizer
+//! pair. Each block stores M/2 groups of 32 bytes (one byte per vector per
+//! sub-quantizer pair; the low nibble indexes sub-quantizer `2 * pair`, the
+//! high nibble indexes sub-quantizer `2 * pair + 1`).
+//!
+//! Layout: [block0_pair0(32B), block0_pair1(32B), ..., block0_pair(M/2-1)(32B),
+//!          block1_pair0(32B), ...]
+
+/// Block size: 32 vectors per block. One code byte per vector per sub-quantizer
+/// pair makes each group 32 bytes, i.e. exactly one 256-bit AVX2 load.
+pub const BBS: usize = 32;
+
+/// Rearrange row-major 4-bit code bytes `[n][cs]` into the blocked layout
+/// `[num_blocks][cs][BBS]`, where `cs = M/2` is the number of code bytes per
+/// vector. Lanes of the final block past vector `n - 1` are left as zero.
+pub fn pack_codes_block_layout(codes: &[u8], n: usize, cs: usize) -> Vec<u8> {
+    let bytes_per_block = cs * BBS;
+    let mut blocked = vec![0u8; n.div_ceil(BBS) * bytes_per_block];
+
+    // Vector `v` occupies lane `v % BBS` of block `v / BBS`; its `pair`-th
+    // byte goes into the `pair`-th 32-byte group of that block.
+    for v in 0..n {
+        let lane_base = (v / BBS) * bytes_per_block + v % BBS;
+        let row = &codes[v * cs..(v + 1) * cs];
+        for (pair, &byte) in row.iter().enumerate() {
+            blocked[lane_base + pair * BBS] = byte;
+        }
+    }
+
+    blocked
+}
+
+/// Inverse of the block packing: gather the blocked layout
+/// `[num_blocks][cs][BBS]` back into row-major `[n][cs]` code bytes.
+/// Padding lanes of the last block are ignored.
+pub fn unpack_codes_block_layout(packed: &[u8], n: usize, cs: usize) -> Vec<u8> {
+    let bytes_per_block = cs * BBS;
+    let mut rows = vec![0u8; n * cs];
+
+    // Vector `v` is read from lane `v % BBS` of block `v / BBS`.
+    for v in 0..n {
+        let lane_base = (v / BBS) * bytes_per_block + v % BBS;
+        for pair in 0..cs {
+            rows[v * cs + pair] = packed[lane_base + pair * BBS];
+        }
+    }
+
+    rows
+}
+
+/// Quantize an f32 distance table (`M * 16` entries) to u8.
+///
+/// Entries are mapped affinely from `[qmin, qmax]` onto `[0, 255]`, where
+/// `qmin` is the table minimum and `qmax` is the larger of `qmax_hint` and
+/// the table maximum. Each value is rounded to the nearest step: plain
+/// truncation would bias every entry downward by up to one step, and that
+/// bias accumulates over the `M` lookups summed per vector.
+///
+/// Returns `(qmin, qmax, quantized_table)`.
+pub fn quantize_distance_table(table: &[f32], qmax_hint: f32) -> (f32, f32, Vec<u8>) {
+    let qmin = table.iter().copied().fold(f32::INFINITY, f32::min);
+    let qmax = qmax_hint.max(table.iter().copied().fold(f32::MIN, f32::max));
+    // Avoid dividing by zero when the table is constant.
+    let range = (qmax - qmin).max(1e-10);
+    let factor = 255.0 / range;
+
+    let qtable: Vec<u8> = table
+        .iter()
+        .map(|&d| ((d - qmin) * factor).round().clamp(0.0, 255.0) as u8)
+        .collect();
+
+    (qmin, qmax, qtable)
+}
+
+/// Exact f32 distance of vector `i`, read from block-layout codes.
+/// Adds the low-nibble lookup, then the high-nibble lookup, for each code byte.
+fn exact_distance_4bit(
+    sim_table: &[f32],
+    codes: &[u8],
+    cs: usize,
+    block_size: usize,
+    i: usize,
+) -> f32 {
+    let lane_base = (i / BBS) * block_size + i % BBS;
+    let mut d = 0.0f32;
+    for pair in 0..cs {
+        let byte = codes[lane_base + pair * BBS];
+        d += sim_table[(pair * 2) * 16 + (byte & 0x0F) as usize];
+        d += sim_table[(pair * 2 + 1) * 16 + (byte >> 4) as usize];
+    }
+    d
+}
+
+/// FastScan: compute PQ distances for block-layout 4-bit codes.
+///
+/// * `sim_table` - f32 lookup table, `m` sub-tables of 16 entries each.
+/// * `codes` - block layout `[num_blocks][m / 2][BBS]`; on the quantized path
+///   the last block must be padded to full size.
+/// * `n` - number of vectors to score.
+/// * `m` - number of 4-bit sub-quantizers.
+/// * `dists` - output; `dists[..n]` is overwritten, later entries are untouched.
+///
+/// The first `min(200, n)` vectors are scored exactly in f32. The remaining
+/// vectors are scored from a u8-quantized copy of the table using u16
+/// accumulators (AVX2 or NEON when available, scalar otherwise) and then
+/// dequantized. If `m` is so large that a u16 accumulator could overflow,
+/// every vector is scored exactly instead.
+///
+/// # Panics
+/// On the quantized path, panics if `m` is odd or if `sim_table`, `codes` or
+/// `dists` are too short for `n` vectors. The exact path panics through
+/// ordinary slice indexing if an accessed element is missing.
+pub fn fastscan_4bit(sim_table: &[f32], codes: &[u8], n: usize, m: usize, dists: &mut [f32]) {
+    const FLAT_NUM: usize = 200;
+    // Largest `m` for which a sum of `m` u8 table entries still fits in a u16.
+    const MAX_M_U16: usize = u16::MAX as usize / 255;
+
+    let cs = m / 2;
+    let block_size = cs * BBS;
+
+    // Step 1: exact f32 for the leading vectors (all of them when the
+    // quantized accumulators could overflow).
+    let flat_end = if m > MAX_M_U16 { n } else { n.min(FLAT_NUM) };
+    for (i, d) in dists[..flat_end].iter_mut().enumerate() {
+        *d = exact_distance_4bit(sim_table, codes, cs, block_size, i);
+    }
+    if flat_end == n {
+        return;
+    }
+
+    let num_blocks = n.div_ceil(BBS);
+    let start_block = flat_end.div_ceil(BBS);
+
+    // The SIMD kernels below read through raw pointers, so every length they
+    // rely on must be validated here to keep this safe function sound.
+    assert!(m % 2 == 0, "fastscan_4bit: m must be even, got {m}");
+    assert!(
+        sim_table.len() >= m * 16,
+        "fastscan_4bit: sim_table has {} entries, need {}",
+        sim_table.len(),
+        m * 16
+    );
+    assert!(
+        codes.len() >= num_blocks * block_size,
+        "fastscan_4bit: codes has {} bytes, need {}",
+        codes.len(),
+        num_blocks * block_size
+    );
+    assert!(
+        dists.len() >= n,
+        "fastscan_4bit: dists has {} entries, need {}",
+        dists.len(),
+        n
+    );
+    // Restrict the output to `n` entries so padded lanes are never written.
+    let out = &mut dists[..n];
+
+    // Vectors between the exact prefix and the first whole block handled by
+    // the block scan are also computed exactly.
+    let gap_end = (start_block * BBS).min(n);
+    for (off, d) in out[flat_end..gap_end].iter_mut().enumerate() {
+        *d = exact_distance_4bit(sim_table, codes, cs, block_size, flat_end + off);
+    }
+
+    // Step 2: quantize the distance table over its own [min, max] range.
+    let (qmin, qmax, qtable) = quantize_distance_table(sim_table, f32::NEG_INFINITY);
+    let range = (qmax - qmin).max(1e-10);
+
+    // Step 3: scan the remaining blocks, preferring a SIMD kernel.
+    #[cfg(target_arch = "x86_64")]
+    let simd_done = if is_x86_feature_detected!("avx2") {
+        // SAFETY: AVX2 support was just verified at runtime, and the
+        // assertions above guarantee that every 16-byte table load and every
+        // 32-byte code load stays inside `qtable` / `codes`.
+        unsafe {
+            fastscan_blocks_avx2(
+                &qtable,
+                codes,
+                cs,
+                start_block,
+                num_blocks,
+                block_size,
+                qmin,
+                range,
+                m,
+                out,
+            );
+        }
+        true
+    } else {
+        false
+    };
+
+    #[cfg(target_arch = "aarch64")]
+    let simd_done = {
+        // SAFETY: the assertions above bound every table and code load. As
+        // before, NEON is assumed to be available on aarch64 targets.
+        unsafe {
+            fastscan_blocks_neon(
+                &qtable,
+                codes,
+                cs,
+                start_block,
+                num_blocks,
+                block_size,
+                qmin,
+                range,
+                m,
+                out,
+            );
+        }
+        true
+    };
+
+    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
+    let simd_done = false;
+
+    if simd_done {
+        return;
+    }
+
+    // Scalar fallback: same quantized arithmetic as the SIMD kernels.
+    let inv_factor = range / 255.0;
+    let base_dist = qmin * m as f32;
+    for block in start_block..num_blocks {
+        let base_vec = block * BBS;
+        let vecs_in_block = BBS.min(n - base_vec);
+        let mut q_dists = [0u16; BBS];
+
+        for pair in 0..cs {
+            let qtab_lo = &qtable[(pair * 2) * 16..(pair * 2 + 1) * 16];
+            let qtab_hi = &qtable[(pair * 2 + 1) * 16..(pair * 2 + 2) * 16];
+            let group = &codes[block * block_size + pair * BBS..][..vecs_in_block];
+
+            for (acc, &byte) in q_dists.iter_mut().zip(group) {
+                let lo = (byte & 0x0F) as usize;
+                let hi = (byte >> 4) as usize;
+                *acc += u16::from(qtab_lo[lo]) + u16::from(qtab_hi[hi]);
+            }
+        }
+
+        // Dequantize: each of the `m` lookups was shifted by `qmin` and
+        // scaled by `255 / range`.
+        let dst = &mut out[base_vec..base_vec + vecs_in_block];
+        for (d, &q) in dst.iter_mut().zip(&q_dists) {
+            *d = f32::from(q) * inv_factor + base_dist;
+        }
+    }
+}
+
+/// AVX2 block scan using vpshufb for 32-way parallel 4-bit lookup.
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn fastscan_blocks_avx2(
+    qtable: &[u8],
+    codes: &[u8],
+    cs: usize,
+    start_block: usize,
+    num_blocks: usize,
+    block_size: usize,
+    qmin: f32,
+    range: f32,
+    m: usize,
+    dists: &mut [f32],
+) {
+    use std::arch::x86_64::*;
+
+    let inv_factor = range / 255.0;
+    let base_dist = qmin * m as f32;
+
+    for block in start_block..num_blocks {
+        let base_vec = block * BBS;
+
+        // u16 accumulators: 2 × __m256i = 32 × u16 values
+        let mut accu_lo = _mm256_setzero_si256(); // vecs 0-15
+        let mut accu_hi = _mm256_setzero_si256(); // vecs 16-31
+
+        for pair in 0..cs {
+            // Load 32-byte quantized LUT for this pair (lo + hi 
sub-quantizers)
+            // LUT layout: [16 bytes for sub_lo, 16 bytes for sub_hi]
+            let lut_lo_ptr = qtable.as_ptr().add((pair * 2) * 16);
+            let lut_hi_ptr = qtable.as_ptr().add((pair * 2 + 1) * 16);
+
+            // Broadcast 16-byte LUTs into 256-bit registers (same 16-byte 
table in both lanes)
+            let lut_lo = 
_mm256_broadcastsi128_si256(_mm_loadu_si128(lut_lo_ptr as *const __m128i));
+            let lut_hi = 
_mm256_broadcastsi128_si256(_mm_loadu_si128(lut_hi_ptr as *const __m128i));
+
+            // Load 32 code bytes for this pair in this block
+            let code_ptr = codes.as_ptr().add(block * block_size + pair * BBS);
+            let code_vec = _mm256_loadu_si256(code_ptr as *const __m256i);
+
+            // Split nibbles
+            let mask = _mm256_set1_epi8(0x0F);
+            let lo_nibbles = _mm256_and_si256(code_vec, mask);
+            let hi_nibbles = _mm256_and_si256(_mm256_srli_epi16(code_vec, 4), 
mask);
+
+            // vpshufb: 32-way parallel lookup
+            let dist_lo = _mm256_shuffle_epi8(lut_lo, lo_nibbles);
+            let dist_hi = _mm256_shuffle_epi8(lut_hi, hi_nibbles);
+
+            // Widen each to u16 separately then add (avoids u8 saturation 
overflow)
+            let zero = _mm256_setzero_si256();
+            let dlo_lo = _mm256_unpacklo_epi8(dist_lo, zero);
+            let dlo_hi = _mm256_unpackhi_epi8(dist_lo, zero);
+            let dhi_lo = _mm256_unpacklo_epi8(dist_hi, zero);
+            let dhi_hi = _mm256_unpackhi_epi8(dist_hi, zero);
+
+            accu_lo = _mm256_add_epi16(accu_lo, _mm256_add_epi16(dlo_lo, 
dhi_lo));
+            accu_hi = _mm256_add_epi16(accu_hi, _mm256_add_epi16(dlo_hi, 
dhi_hi));
+        }
+
+        // Extract u16 values and dequantize to f32.
+        // _mm256_unpacklo/hi_epi8 operates per 128-bit lane, so:
+        //   accu_lo = [v0..v7 | v16..v23], accu_hi = [v8..v15 | v24..v31]
+        // Reassemble correct order with cross-lane permute.
+        let result_lo = _mm256_permute2x128_si256(accu_lo, accu_hi, 0x20);
+        let result_hi = _mm256_permute2x128_si256(accu_lo, accu_hi, 0x31);
+        let mut q_vals = [0u16; BBS];
+        _mm256_storeu_si256(q_vals.as_mut_ptr() as *mut __m256i, result_lo);
+        _mm256_storeu_si256(q_vals.as_mut_ptr().add(16) as *mut __m256i, 
result_hi);
+
+        for v in 0..BBS {
+            let idx = base_vec + v;
+            if idx < dists.len() {
+                dists[idx] = q_vals[v] as f32 * inv_factor + base_dist;
+            }
+        }
+    }
+}
+
+/// ARM NEON block scan (16-way per instruction, 2 passes per block).
+#[cfg(target_arch = "aarch64")]
+#[target_feature(enable = "neon")]
+unsafe fn fastscan_blocks_neon(
+    qtable: &[u8],
+    codes: &[u8],
+    cs: usize,
+    start_block: usize,
+    num_blocks: usize,
+    block_size: usize,
+    qmin: f32,
+    range: f32,
+    m: usize,
+    dists: &mut [f32],
+) {
+    use std::arch::aarch64::*;
+
+    let inv_factor = range / 255.0;
+    let base_dist = qmin * m as f32;
+
+    for block in start_block..num_blocks {
+        let base_vec = block * BBS;
+
+        // u16 accumulators: 4 × uint16x8_t = 32 × u16
+        let mut accu = [vdupq_n_u16(0); 4];
+
+        for pair in 0..cs {
+            let lut_lo = vld1q_u8(qtable.as_ptr().add((pair * 2) * 16));
+            let lut_hi = vld1q_u8(qtable.as_ptr().add((pair * 2 + 1) * 16));
+
+            let code_ptr = codes.as_ptr().add(block * block_size + pair * BBS);
+
+            // Process 16 vectors at a time (NEON = 128-bit)
+            for half in 0..2 {
+                let code_vec = vld1q_u8(code_ptr.add(half * 16));
+
+                let mask = vdupq_n_u8(0x0F);
+                let lo_nib = vandq_u8(code_vec, mask);
+                let hi_nib = vshrq_n_u8(code_vec, 4);
+
+                // tbl: 16-way lookup
+                let dist_lo = vqtbl1q_u8(lut_lo, lo_nib);
+                let dist_hi = vqtbl1q_u8(lut_hi, hi_nib);
+
+                // Widen each to u16 separately then add (avoids u8 saturation 
overflow)
+                accu[half * 2] = vaddq_u16(accu[half * 2], 
vmovl_u8(vget_low_u8(dist_lo)));
+                accu[half * 2] = vaddq_u16(accu[half * 2], 
vmovl_u8(vget_low_u8(dist_hi)));
+                accu[half * 2 + 1] = vaddq_u16(accu[half * 2 + 1], 
vmovl_u8(vget_high_u8(dist_lo)));
+                accu[half * 2 + 1] = vaddq_u16(accu[half * 2 + 1], 
vmovl_u8(vget_high_u8(dist_hi)));
+            }
+        }
+
+        // Extract and dequantize
+        let mut q_vals = [0u16; BBS];
+        for i in 0..4 {
+            vst1q_u16(q_vals.as_mut_ptr().add(i * 8), accu[i]);
+        }
+
+        for v in 0..BBS {
+            let idx = base_vec + v;
+            if idx < dists.len() {
+                dists[idx] = q_vals[v] as f32 * inv_factor + base_dist;
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Scalar f32 reference over row-major codes: for every code byte, add the
+    /// low-nibble lookup and then the high-nibble lookup.
+    fn scalar_reference(sim_table: &[f32], codes_row: &[u8], n: usize, cs: usize) -> Vec<f32> {
+        let mut expected = vec![0.0f32; n];
+        for (i, slot) in expected.iter_mut().enumerate() {
+            for pair in 0..cs {
+                let byte = codes_row[i * cs + pair];
+                let lo = (byte & 0x0F) as usize;
+                let hi = ((byte >> 4) & 0x0F) as usize;
+                *slot += sim_table[(pair * 2) * 16 + lo];
+                *slot += sim_table[(pair * 2 + 1) * 16 + hi];
+            }
+        }
+        expected
+    }
+
+    #[test]
+    fn test_pack_unpack_roundtrip() {
+        let n = 100;
+        let cs = 4; // m/2 = 4, i.e. m=8
+        let codes: Vec<u8> = (0..n * cs).map(|i| (i % 256) as u8).collect();
+
+        let roundtrip =
+            unpack_codes_block_layout(&pack_codes_block_layout(&codes, n, cs), n, cs);
+
+        assert_eq!(codes, roundtrip);
+    }
+
+    #[test]
+    fn test_fastscan_correctness() {
+        let m = 8;
+        let cs = m / 2;
+        let n = 100;
+
+        // Deterministic pseudo-random codes and distance table [M * 16].
+        let codes_row: Vec<u8> = (0..n * cs).map(|i| ((i * 7 + 3) % 256) as u8).collect();
+        let sim_table: Vec<f32> = (0..m * 16).map(|i| (i as f32) * 0.1 + 0.5).collect();
+
+        let expected = scalar_reference(&sim_table, &codes_row, n, cs);
+
+        // Pack into block layout and scan.
+        let packed = pack_codes_block_layout(&codes_row, n, cs);
+        let mut result = vec![0.0f32; n];
+        fastscan_4bit(&sim_table, &packed, n, m, &mut result);
+
+        // The first 200 vectors use the exact f32 path and must match.
+        for i in 0..n.min(200) {
+            assert!(
+                (result[i] - expected[i]).abs() < 1e-5,
+                "Mismatch at {}: {} vs {}",
+                i,
+                result[i],
+                expected[i]
+            );
+        }
+    }
+
+    #[test]
+    fn test_fastscan_large() {
+        let m = 16;
+        let cs = m / 2;
+        let n = 1000; // > 200, exercises quantized path
+
+        let codes_row: Vec<u8> = (0..n * cs).map(|i| ((i * 13 + 7) % 256) as u8).collect();
+        let sim_table: Vec<f32> = (0..m * 16).map(|i| (i as f32) * 0.05 + 1.0).collect();
+
+        let expected = scalar_reference(&sim_table, &codes_row, n, cs);
+
+        let packed = pack_codes_block_layout(&codes_row, n, cs);
+        let mut result = vec![0.0f32; n];
+        fastscan_4bit(&sim_table, &packed, n, m, &mut result);
+
+        // The first 200 vectors use the exact f32 path and must match.
+        for i in 0..200 {
+            assert!(
+                (result[i] - expected[i]).abs() < 1e-5,
+                "Exact mismatch at {}: got {}, expected {}",
+                i,
+                result[i],
+                expected[i]
+            );
+        }
+
+        // Beyond 200 the quantized path is used: allow 2% of the largest
+        // expected distance as u8 quantization tolerance.
+        let max_expected = expected.iter().cloned().fold(f32::MIN, f32::max);
+        let tolerance = max_expected * 0.02;
+        for i in 200..n {
+            let diff = (result[i] - expected[i]).abs();
+            assert!(
+                diff <= tolerance,
+                "SIMD mismatch at {}: got {}, expected {}, diff {}",
+                i,
+                result[i],
+                expected[i],
+                diff
+            );
+        }
+    }
+}
diff --git a/core/src/lib.rs b/core/src/lib.rs
index a2df43f..9f03d17 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -20,5 +20,8 @@
 
 pub mod blas;
 pub mod distance;
+pub mod fastscan;
 pub mod kmeans;
+pub mod opq;
 pub mod pq;
+pub mod shuffler;
diff --git a/core/src/opq.rs b/core/src/opq.rs
new file mode 100644
index 0000000..8c0739b
--- /dev/null
+++ b/core/src/opq.rs
@@ -0,0 +1,256 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::kmeans::KMeansConfig;
+use crate::pq::ProductQuantizer;
+use nalgebra::{DMatrix, SVD};
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
+
+/// OPQ (Optimized Product Quantization) rotation matrix.
+/// Aligned with Faiss's OPQMatrix from VectorTransform.cpp.
+///
+/// Learns an orthogonal rotation R that minimizes PQ reconstruction error
+/// via alternating Procrustes optimization.
+pub struct OPQMatrix {
+    /// Vector dimensionality; `rotation` holds `d * d` entries.
+    pub d: usize,
+    /// Second constructor argument, stored as given.
+    /// NOTE(review): presumably the number of PQ sub-quantizers; it is not
+    /// read by the methods of this type — confirm against callers.
+    pub m: usize,
+    /// Number of outer alternating iterations run by `train` (default 50).
+    pub niter: usize,
+    /// K-means iterations for PQ retraining on iterations after the first
+    /// (default 4).
+    pub niter_pq: usize,
+    /// K-means iterations for PQ training on the first iteration (default 40).
+    pub niter_pq_0: usize,
+    /// Upper bound on training vectors; larger inputs are randomly
+    /// subsampled down to this many (default 65536).
+    pub max_train_points: usize,
+    /// Rotation matrix [d * d], row-major. y = R * x.
+    /// All zeros until `train` has run.
+    pub rotation: Vec<f32>,
+    /// Set to `true` at the end of `train`.
+    pub is_trained: bool,
+}
+
+impl OPQMatrix {
+    /// Create an untrained OPQ transform for `d`-dimensional vectors.
+    ///
+    /// The rotation starts as all zeros and is only meaningful after
+    /// [`OPQMatrix::train`]. Iteration counts and the training-set cap get
+    /// their defaults here and can be changed through the public fields.
+    pub fn new(d: usize, m: usize) -> Self {
+        OPQMatrix {
+            d,
+            m,
+            niter: 50,
+            niter_pq: 4,
+            niter_pq_0: 40,
+            max_train_points: 65536,
+            rotation: vec![0.0f32; d * d],
+            is_trained: false,
+        }
+    }
+
+    /// Train the OPQ rotation matrix and the given product quantizer.
+    ///
+    /// `data` is a flat `[n * d]` buffer of row vectors. At most
+    /// `max_train_points` of them are used (random subsample with a fixed
+    /// seed), and the training copy is mean-centered. Training alternates
+    /// between (re)training `pq` on the rotated data and solving an
+    /// orthogonal Procrustes problem for the rotation. Afterwards `pq` is
+    /// trained once more on the data rotated by the final matrix.
+    ///
+    /// # Panics
+    /// Panics if `data` holds fewer than `n * d` values.
+    pub fn train(&mut self, data: &[f32], n: usize, pq: &mut ProductQuantizer) {
+        let d = self.d;
+        let mut rng = StdRng::seed_from_u64(12345);
+
+        // Subsample if needed (partial Fisher-Yates over the row indices).
+        let train_n = n.min(self.max_train_points);
+        let mut train_data = if n > self.max_train_points {
+            let mut sub = vec![0.0f32; train_n * d];
+            let mut indices: Vec<usize> = (0..n).collect();
+            for i in 0..train_n {
+                let j = rng.gen_range(i..n);
+                indices.swap(i, j);
+            }
+            for (out_i, &src_i) in indices[..train_n].iter().enumerate() {
+                sub[out_i * d..(out_i + 1) * d].copy_from_slice(&data[src_i * d..(src_i + 1) * d]);
+            }
+            sub
+        } else {
+            data[..n * d].to_vec()
+        };
+
+        // Center data (subtract mean) — aligned with Faiss OPQMatrix
+        let mut mean = vec![0.0f32; d];
+        for i in 0..train_n {
+            for j in 0..d {
+                mean[j] += train_data[i * d + j];
+            }
+        }
+        let inv_n = 1.0 / train_n as f32;
+        for j in 0..d {
+            mean[j] *= inv_n;
+        }
+        for i in 0..train_n {
+            for j in 0..d {
+                train_data[i * d + j] -= mean[j];
+            }
+        }
+
+        // Initialize with random orthogonal matrix via QR decomposition
+        let random_mat: Vec<f32> = (0..d * d).map(|_| rng.gen::<f32>() - 0.5).collect();
+        let mat = DMatrix::from_row_slice(d, d, &random_mat);
+        let qr = mat.qr();
+        let q = qr.q();
+        for i in 0..d {
+            for j in 0..d {
+                self.rotation[i * d + j] = q[(i, j)];
+            }
+        }
+
+        let mut projected = vec![0.0f32; train_n * d];
+        let mut reconstructed = vec![0.0f32; train_n * d];
+        let mut codes = vec![0u8; train_n * pq.m];
+
+        for iter in 0..self.niter {
+            // 1. Project: projected = train_data * R^T
+            self.apply_batch(&train_data, &mut projected, train_n);
+
+            // 2. Train PQ on projected data (hot-start on iter >= 1)
+            let pq_niter = if iter == 0 {
+                self.niter_pq_0
+            } else {
+                self.niter_pq
+            };
+            let km_config = KMeansConfig {
+                niter: pq_niter,
+                ..KMeansConfig::default()
+            };
+            let hot_start = iter > 0;
+            pq.train_hot_start(&projected, train_n, &km_config, hot_start);
+
+            // 3. Encode and decode to get reconstructions
+            pq.encode_batch(&projected, train_n, &mut codes);
+            for i in 0..train_n {
+                pq.decode(
+                    &codes[i * pq.m..(i + 1) * pq.m],
+                    &mut reconstructed[i * d..(i + 1) * d],
+                );
+            }
+
+            // 4. Orthogonal Procrustes: find the rotation R minimizing
+            //    ||X * R^T - Y||_F, where X = train_data and Y = reconstructed
+            //    (rows are vectors, matching y = R * x in `apply`).
+            //    With X^T * Y = U * S * V^T the minimizer is R = V * U^T.
+            let x_mat = DMatrix::from_row_slice(train_n, d, &train_data);
+            let y_mat = DMatrix::from_row_slice(train_n, d, &reconstructed);
+            let cross_cov = x_mat.transpose() * &y_mat; // [d x d]
+
+            let svd = SVD::new(cross_cov, true, true);
+            // If either factor is missing, the previous rotation is kept.
+            if let (Some(u), Some(vt)) = (svd.u, svd.v_t) {
+                // R = V * U^T = (U * V^T)^T. Using U * V^T directly would
+                // store the transpose of the optimal rotation.
+                let r = (&u * &vt).transpose();
+                for i in 0..d {
+                    for j in 0..d {
+                        self.rotation[i * d + j] = r[(i, j)];
+                    }
+                }
+            }
+        }
+
+        // Final PQ training with the learned rotation
+        self.apply_batch(&train_data, &mut projected, train_n);
+        pq.train_with_config(&projected, train_n, &KMeansConfig::default());
+
+        self.is_trained = true;
+    }
+
+    /// Apply rotation to a single vector: y = R * x.
+    ///
+    /// Reads the first `d` entries of `x` and writes the first `d` entries
+    /// of `y`; panics if either slice is shorter than `d`.
+    pub fn apply(&self, x: &[f32], y: &mut [f32]) {
+        let d = self.d;
+        for i in 0..d {
+            let mut sum = 0.0f32;
+            for j in 0..d {
+                sum += self.rotation[i * d + j] * x[j];
+            }
+            y[i] = sum;
+        }
+    }
+
+    /// Apply rotation to a batch of `n` vectors stored back to back:
+    /// row `i` of `out` becomes `R *` row `i` of `data`.
+    pub fn apply_batch(&self, data: &[f32], out: &mut [f32], n: usize) {
+        for i in 0..n {
+            self.apply(
+                &data[i * self.d..(i + 1) * self.d],
+                &mut out[i * self.d..(i + 1) * self.d],
+            );
+        }
+    }
+
+    /// Apply reverse rotation: x = R^T * y.
+    ///
+    /// This inverts `apply` when R is orthogonal.
+    pub fn apply_reverse(&self, y: &[f32], x: &mut [f32]) {
+        let d = self.d;
+        for i in 0..d {
+            let mut sum = 0.0f32;
+            for j in 0..d {
+                // Column `i` of R is row `i` of R^T.
+                sum += self.rotation[j * d + i] * y[j];
+            }
+            x[i] = sum;
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// After training, the learned matrix must still be orthogonal.
+    #[test]
+    fn test_rotation_orthogonality() {
+        let (d, m, n) = (8, 2, 500);
+
+        let mut rng = StdRng::seed_from_u64(42);
+        let data: Vec<f32> = (0..n * d).map(|_| rng.gen::<f32>()).collect();
+
+        let mut opq = OPQMatrix::new(d, m);
+        opq.niter = 5; // Reduce for test speed
+        let mut pq = ProductQuantizer::new(d, m);
+        opq.train(&data, n, &mut pq);
+
+        assert!(opq.is_trained);
+
+        // R * R^T ≈ I: rows must be pairwise orthonormal.
+        for i in 0..d {
+            let row_i = &opq.rotation[i * d..(i + 1) * d];
+            for j in 0..d {
+                let row_j = &opq.rotation[j * d..(j + 1) * d];
+                let mut dot = 0.0f32;
+                for (a, b) in row_i.iter().zip(row_j) {
+                    dot += a * b;
+                }
+                let expected = if i == j { 1.0 } else { 0.0 };
+                assert!(
+                    (dot - expected).abs() < 1e-4,
+                    "R*R^T[{},{}] = {}, expected {}",
+                    i,
+                    j,
+                    dot,
+                    expected
+                );
+            }
+        }
+    }
+
+    /// `apply_reverse` must undo `apply` for an orthogonal matrix.
+    #[test]
+    fn test_apply_reverse() {
+        let d = 4;
+        let mut opq = OPQMatrix::new(d, 2);
+        // A permutation matrix: swaps components 0<->1 and 2<->3.
+        #[rustfmt::skip]
+        let permutation = vec![
+            0.0, 1.0, 0.0, 0.0,
+            1.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 1.0,
+            0.0, 0.0, 1.0, 0.0,
+        ];
+        opq.rotation = permutation;
+
+        let x = [1.0, 2.0, 3.0, 4.0];
+        let mut y = [0.0f32; 4];
+        opq.apply(&x, &mut y);
+
+        let mut x_back = [0.0f32; 4];
+        opq.apply_reverse(&y, &mut x_back);
+
+        for (orig, back) in x.iter().zip(x_back.iter()) {
+            assert!((orig - back).abs() < 1e-6);
+        }
+    }
+}
diff --git a/core/src/shuffler.rs b/core/src/shuffler.rs
new file mode 100644
index 0000000..ac96d35
--- /dev/null
+++ b/core/src/shuffler.rs
@@ -0,0 +1,276 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Disk-based shuffler for large-scale IVF-PQ index building.
+//! Inspired by Lance's shuffler: write vectors sequentially with partition IDs,
+//! then read back grouped by partition for PQ encoding.
+
+use std::fs::File;
+use std::io::{self, BufReader, BufWriter, Read, Write};
+use std::path::PathBuf;
+use std::sync::atomic::{AtomicU64, Ordering};
+
+/// Per-partition row IDs paired with per-partition flat vector data: the
+/// outer `Vec` of each half is indexed by partition ID.
+type PartitionData = (Vec<Vec<i64>>, Vec<Vec<f32>>);
+
+/// Number of bytes that precede the vector payload in every record.
+///
+/// Record format: `[partition_id: u32][row_id: i64][vector: f32 * dim]`
+const RECORD_OVERHEAD: usize = std::mem::size_of::<u32>() + std::mem::size_of::<i64>();
+
+/// Disk-based shuffler that accumulates vectors with partition assignments,
+/// then reads them back grouped by partition.
+pub struct DiskShuffler {
+    /// Location of the backing temp file; removed when the shuffler is dropped.
+    path: PathBuf,
+    /// Buffered writer over the backing file; `None` once `finish_write` has taken it.
+    writer: Option<BufWriter<File>>,
+    /// Number of `f32` components every written vector must have.
+    dim: usize,
+    /// Size in bytes of one on-disk record: `RECORD_OVERHEAD + dim * 4`.
+    record_size: usize,
+    /// Total number of records written so far.
+    count: usize,
+    /// Number of records written per partition, indexed by partition ID.
+    partition_counts: Vec<usize>,
+}
+
+impl DiskShuffler {
+    /// Capacity, in bytes, of the buffered writer and of every buffered reader.
+    const IO_BUFFER_BYTES: usize = 8 * 1024 * 1024;
+
+    /// Create a new shuffler backed by a freshly created temp file.
+    ///
+    /// `dim` is the number of `f32` components every vector must have and
+    /// `nlist` is the number of partitions, so valid partition IDs are
+    /// `0..nlist`. The file is created inside [`std::env::temp_dir`]; its name
+    /// combines the process ID with a process-wide counter so that shufflers
+    /// created by the same process get distinct files.
+    ///
+    /// # Errors
+    ///
+    /// Returns the underlying I/O error if the temp file cannot be created.
+    pub fn new(dim: usize, nlist: usize) -> io::Result<Self> {
+        static COUNTER: AtomicU64 = AtomicU64::new(0);
+        let id = COUNTER.fetch_add(1, Ordering::Relaxed);
+        let file_name = format!("ivfpq-shuffle-{}-{}.bin", std::process::id(), id);
+        let path = std::env::temp_dir().join(file_name);
+        let file = File::create(&path)?;
+        let writer = BufWriter::with_capacity(Self::IO_BUFFER_BYTES, file);
+
+        Ok(DiskShuffler {
+            path,
+            writer: Some(writer),
+            dim,
+            record_size: RECORD_OVERHEAD + dim * 4,
+            count: 0,
+            partition_counts: vec![0; nlist],
+        })
+    }
+
+    /// Write a vector with its partition assignment and row ID.
+    ///
+    /// The record is appended as little-endian
+    /// `[partition_id: u32][row_id: i64][vector: f32 * dim]`.
+    ///
+    /// # Errors
+    ///
+    /// * `InvalidInput` if `vector.len()` differs from the configured `dim`
+    ///   or `partition_id` is not below `nlist`.
+    /// * An error of kind `Other` if [`finish_write`](Self::finish_write) has
+    ///   already been called.
+    /// * Any I/O error reported by the underlying writer.
+    pub fn write_vector(
+        &mut self,
+        partition_id: u32,
+        row_id: i64,
+        vector: &[f32],
+    ) -> io::Result<()> {
+        if vector.len() != self.dim {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!(
+                    "vector length {} does not match expected dim {}",
+                    vector.len(),
+                    self.dim
+                ),
+            ));
+        }
+        if partition_id as usize >= self.partition_counts.len() {
+            return Err(self.partition_out_of_range(partition_id));
+        }
+        // Writing after `finish_write` is a caller mistake; report it instead
+        // of panicking on the missing writer.
+        let writer = self
+            .writer
+            .as_mut()
+            .ok_or_else(|| io::Error::other("write_vector called after finish_write"))?;
+        writer.write_all(&partition_id.to_le_bytes())?;
+        writer.write_all(&row_id.to_le_bytes())?;
+        for &v in vector {
+            writer.write_all(&v.to_le_bytes())?;
+        }
+        // Counters are only bumped once the whole record was accepted.
+        self.partition_counts[partition_id as usize] += 1;
+        self.count += 1;
+        Ok(())
+    }
+
+    /// Finalize writing: flush all buffered records and close the writer.
+    ///
+    /// Calling this again after the writer has been closed is a no-op.
+    ///
+    /// # Errors
+    ///
+    /// Returns the I/O error raised while flushing. Dropping a `BufWriter`
+    /// silently discards such errors, so the flush is performed explicitly.
+    pub fn finish_write(&mut self) -> io::Result<()> {
+        if let Some(mut writer) = self.writer.take() {
+            writer.flush()?;
+        }
+        Ok(())
+    }
+
+    /// Read all vectors for a specific partition.
+    /// Returns (row_ids, vectors) where vectors is flat [count * dim].
+    ///
+    /// The whole file is scanned and records of other partitions are skipped.
+    /// Call [`finish_write`](Self::finish_write) first: records still held in
+    /// the write buffer have not reached the file yet.
+    ///
+    /// # Errors
+    ///
+    /// * `InvalidInput` if `partition_id` is not below `nlist`.
+    /// * Any I/O error from opening or reading the file, including
+    ///   `UnexpectedEof` when it holds fewer records than were written.
+    pub fn read_partition(&self, partition_id: u32) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        let count = match self.partition_counts.get(partition_id as usize) {
+            Some(&count) => count,
+            None => return Err(self.partition_out_of_range(partition_id)),
+        };
+        if count == 0 {
+            return Ok((Vec::new(), Vec::new()));
+        }
+
+        let mut ids = Vec::with_capacity(count);
+        let mut vectors = Vec::with_capacity(count * self.dim);
+
+        let mut reader = self.open_reader()?;
+        let mut record_buf = vec![0u8; self.record_size];
+
+        for _ in 0..self.count {
+            reader.read_exact(&mut record_buf)?;
+            let (pid, row_id) = Self::decode_header(&record_buf);
+            if pid == partition_id {
+                ids.push(row_id);
+                Self::decode_vector(&record_buf, &mut vectors);
+            }
+        }
+
+        Ok((ids, vectors))
+    }
+
+    /// Read all partitions at once (for moderate datasets that fit in memory after PQ encoding).
+    /// Returns (ids_per_list, vectors_per_list).
+    ///
+    /// Both outer vectors have `nlist` entries; entry `p` holds the row IDs,
+    /// respectively the flat vector data, of partition `p` in write order.
+    ///
+    /// # Errors
+    ///
+    /// * `InvalidData` if a record on disk carries a partition ID that is not
+    ///   below `nlist`.
+    /// * Any I/O error from opening or reading the file.
+    pub fn read_all_partitions(&self) -> io::Result<PartitionData> {
+        let nlist = self.partition_counts.len();
+
+        // Pre-allocate every partition from the counts gathered while writing.
+        let mut all_ids: Vec<Vec<i64>> = self
+            .partition_counts
+            .iter()
+            .map(|&n| Vec::with_capacity(n))
+            .collect();
+        let mut all_vectors: Vec<Vec<f32>> = self
+            .partition_counts
+            .iter()
+            .map(|&n| Vec::with_capacity(n * self.dim))
+            .collect();
+
+        let mut reader = self.open_reader()?;
+        let mut record_buf = vec![0u8; self.record_size];
+
+        for _ in 0..self.count {
+            reader.read_exact(&mut record_buf)?;
+            let (pid, row_id) = Self::decode_header(&record_buf);
+            let pid = pid as usize;
+            // The ID comes from disk; never index with it unchecked.
+            if pid >= nlist {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    format!(
+                        "record has partition_id {} out of range (nlist={})",
+                        pid, nlist
+                    ),
+                ));
+            }
+            all_ids[pid].push(row_id);
+            Self::decode_vector(&record_buf, &mut all_vectors[pid]);
+        }
+
+        Ok((all_ids, all_vectors))
+    }
+
+    /// Total number of records written so far.
+    pub fn total_count(&self) -> usize {
+        self.count
+    }
+
+    /// Number of records written per partition, indexed by partition ID.
+    pub fn partition_counts(&self) -> &[usize] {
+        &self.partition_counts
+    }
+
+    /// Build the `InvalidInput` error used for a partition ID outside `0..nlist`.
+    fn partition_out_of_range(&self, partition_id: u32) -> io::Error {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!(
+                "partition_id {} out of range (nlist={})",
+                partition_id,
+                self.partition_counts.len()
+            ),
+        )
+    }
+
+    /// Open the backing file for a sequential, buffered scan.
+    fn open_reader(&self) -> io::Result<BufReader<File>> {
+        let file = File::open(&self.path)?;
+        Ok(BufReader::with_capacity(Self::IO_BUFFER_BYTES, file))
+    }
+
+    /// Decode the little-endian `(partition_id, row_id)` header of one record.
+    fn decode_header(record: &[u8]) -> (u32, i64) {
+        let mut pid = [0u8; 4];
+        pid.copy_from_slice(&record[..4]);
+        let mut row_id = [0u8; 8];
+        row_id.copy_from_slice(&record[4..RECORD_OVERHEAD]);
+        (u32::from_le_bytes(pid), i64::from_le_bytes(row_id))
+    }
+
+    /// Append the little-endian `f32` payload of one record to `out`.
+    fn decode_vector(record: &[u8], out: &mut Vec<f32>) {
+        out.extend(
+            record[RECORD_OVERHEAD..]
+                .chunks_exact(4)
+                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])),
+        );
+    }
+}
+
+impl Drop for DiskShuffler {
+    /// Best-effort cleanup: release the file handle, then delete the temp file.
+    fn drop(&mut self) {
+        // Close the writer before removing the file so the file is no longer
+        // open at that point. `into_parts` returns the file without flushing,
+        // which avoids writing buffered records to a file about to be deleted.
+        if let Some(writer) = self.writer.take() {
+            let (file, _unflushed) = writer.into_parts();
+            drop(file);
+        }
+        // A failed removal cannot be reported from `drop`, so it is ignored.
+        let _ = std::fs::remove_file(&self.path);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// A vector whose length differs from `dim` is rejected as invalid input.
+    #[test]
+    fn test_write_vector_validates_dim() {
+        let mut shuffler = DiskShuffler::new(4, 2).unwrap();
+        let result = shuffler.write_vector(0, 1, &[1.0, 2.0]);
+        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidInput);
+    }
+
+    /// A partition ID at or beyond `nlist` is rejected as invalid input.
+    #[test]
+    fn test_write_vector_validates_partition_id() {
+        let mut shuffler = DiskShuffler::new(4, 2).unwrap();
+        let result = shuffler.write_vector(5, 1, &[1.0, 2.0, 3.0, 4.0]);
+        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidInput);
+    }
+
+    /// Records written in interleaved partition order come back grouped by
+    /// partition, preserving write order inside each partition.
+    #[test]
+    fn test_shuffler_roundtrip() {
+        let dim = 4;
+        let nlist = 3;
+        let mut shuffler = DiskShuffler::new(dim, nlist).unwrap();
+
+        // (partition_id, row_id, vector), spread over different partitions
+        let records: [(u32, i64, [f32; 4]); 4] = [
+            (0, 100, [1.0, 2.0, 3.0, 4.0]),
+            (1, 200, [5.0, 6.0, 7.0, 8.0]),
+            (0, 300, [9.0, 10.0, 11.0, 12.0]),
+            (2, 400, [13.0, 14.0, 15.0, 16.0]),
+        ];
+        for (partition_id, row_id, vector) in &records {
+            shuffler
+                .write_vector(*partition_id, *row_id, vector)
+                .unwrap();
+        }
+        shuffler.finish_write().unwrap();
+
+        assert_eq!(shuffler.partition_counts(), &[2, 1, 1]);
+
+        // Single-partition read
+        let (ids, vecs) = shuffler.read_partition(0).unwrap();
+        assert_eq!(ids, vec![100, 300]);
+        assert_eq!(vecs.len(), 2 * dim);
+        assert_eq!(&vecs[0..4], &[1.0, 2.0, 3.0, 4.0]);
+        assert_eq!(&vecs[4..8], &[9.0, 10.0, 11.0, 12.0]);
+
+        // Whole-file read
+        let (all_ids, all_vecs) = shuffler.read_all_partitions().unwrap();
+        assert_eq!(all_ids[0], vec![100, 300]);
+        assert_eq!(all_ids[1], vec![200]);
+        assert_eq!(all_ids[2], vec![400]);
+        assert_eq!(&all_vecs[1][..], &[5.0, 6.0, 7.0, 8.0]);
+    }
+}


Reply via email to