This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git


The following commit(s) were added to refs/heads/main by this push:
     new 2027edc  Add core foundation modules (#1)
2027edc is described below

commit 2027edc81aa813a1af8394266f53064fb6c9f377
Author: Jingsong Lee <[email protected]>
AuthorDate: Fri Jun 5 21:14:17 2026 +0800

    Add core foundation modules (#1)
    
    Bring over the basic building blocks: distance (L2/IP/Cosine + SIMD),
    blas (sgemm), kmeans (k-means++/hierarchical/streaming), and product
    quantizer (train/encode/decode/distance table). Skips higher-level
    modules (ivfpq, opq, fastscan, shuffler, io) and language bindings.
---
 core/Cargo.toml => .gitignore |  30 +-
 LICENSE                       | 201 ++++++++++
 NOTICE                        |   5 +
 copyright.txt                 |  17 +
 core/Cargo.toml               |   5 +
 core/src/blas.rs              |  91 +++++
 core/src/distance.rs          | 470 +++++++++++++++++++++++
 core/src/kmeans.rs            | 875 ++++++++++++++++++++++++++++++++++++++++++
 core/src/lib.rs               |   8 +
 core/src/pq.rs                | 542 ++++++++++++++++++++++++++
 10 files changed, 2239 insertions(+), 5 deletions(-)

diff --git a/core/Cargo.toml b/.gitignore
similarity index 77%
copy from core/Cargo.toml
copy to .gitignore
index b3b3233..d987f62 100644
--- a/core/Cargo.toml
+++ b/.gitignore
@@ -15,8 +15,28 @@
 # specific language governing permissions and limitations
 # under the License.
 
-[package]
-name = "paimon-vindex-core"
-version = "0.1.0"
-edition = "2021"
-license = "Apache-2.0"
+# Rust
+/target/
+Cargo.lock
+
+# Java
+java/target/
+java/src/main/resources/native/
+*.class
+
+# Python
+__pycache__/
+*.pyc
+*.egg-info/
+python/dist/
+python/build/
+.pytest_cache/
+
+# IDE
+.idea/
+*.iml
+.vscode/
+
+# OS
+.DS_Store
+Thumbs.db
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/NOTICE b/NOTICE
new file mode 100644
index 0000000..481e138
--- /dev/null
+++ b/NOTICE
@@ -0,0 +1,5 @@
+Apache Paimon Vector Index
+Copyright 2025 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
diff --git a/copyright.txt b/copyright.txt
new file mode 100644
index 0000000..2379dda
--- /dev/null
+++ b/copyright.txt
@@ -0,0 +1,17 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
diff --git a/core/Cargo.toml b/core/Cargo.toml
index b3b3233..69ba8a7 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -20,3 +20,8 @@ name = "paimon-vindex-core"
 version = "0.1.0"
 edition = "2021"
 license = "Apache-2.0"
+
+[dependencies]
+rand = "0.8"
+rayon = "1.10"
+matrixmultiply = "0.3"
diff --git a/core/src/blas.rs b/core/src/blas.rs
new file mode 100644
index 0000000..ddbe9a2
--- /dev/null
+++ b/core/src/blas.rs
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Matrix multiplication via `matrixmultiply` crate.
+//! Pure Rust with SIMD + cache blocking, no system BLAS dependency.
+
+/// C = alpha * A * B^T + beta * C
+///
+/// * `a` – A, `m × k`, row-major.
+/// * `b` – B, `n × k`, row-major; it is read as the `k × n` matrix B^T purely by
+///   swapping its strides, so no transposed copy is made.
+/// * `c` – C, `m × n`, row-major; overwritten with the result.
+///
+/// # Panics
+/// Panics if `a`, `b`, or `c` hold fewer than `m * k`, `n * k`, or `m * n` elements.
+/// `matrixmultiply::sgemm` reads and writes through raw pointers, so these checks are
+/// what keeps a short slice from becoming an out-of-bounds access.
+pub fn sgemm_a_bt(
+    m: usize,
+    n: usize,
+    k: usize,
+    alpha: f32,
+    a: &[f32],
+    b: &[f32],
+    beta: f32,
+    c: &mut [f32],
+) {
+    // `len >= rows * cols`, treating an overflowing product as "too short".
+    fn holds(len: usize, rows: usize, cols: usize) -> bool {
+        matches!(rows.checked_mul(cols), Some(need) if len >= need)
+    }
+    assert!(holds(a.len(), m, k), "sgemm_a_bt: `a` is shorter than m * k");
+    assert!(holds(b.len(), n, k), "sgemm_a_bt: `b` is shorter than n * k");
+    assert!(holds(c.len(), m, n), "sgemm_a_bt: `c` is shorter than m * n");
+
+    // SAFETY: the assertions above guarantee that every element addressed by the
+    // dimensions and strides below lies inside the corresponding slice.
+    unsafe {
+        matrixmultiply::sgemm(
+            m,
+            k,
+            n,
+            alpha,
+            a.as_ptr(),
+            k as isize, // row stride of A
+            1,          // col stride of A
+            b.as_ptr(),
+            1,          // row stride of B^T: walks along one row of B
+            k as isize, // col stride of B^T: next column of B^T = next row of B
+            beta,
+            c.as_mut_ptr(),
+            n as isize, // row stride of C
+            1,          // col stride of C
+        );
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::sgemm_a_bt;
+
+    /// Element-wise comparison with an absolute tolerance of 1e-5.
+    fn assert_close(actual: &[f32], expected: &[f32]) {
+        assert_eq!(actual.len(), expected.len());
+        for (got, want) in actual.iter().zip(expected) {
+            assert!((got - want).abs() < 1e-5);
+        }
+    }
+
+    #[test]
+    fn test_sgemm_a_bt() {
+        // A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]]
+        // A * B^T = [[1*5 + 2*6, 1*7 + 2*8], [3*5 + 4*6, 3*7 + 4*8]] = [[17, 23], [39, 53]]
+        let a = [1.0f32, 2.0, 3.0, 4.0];
+        let b = [5.0f32, 6.0, 7.0, 8.0];
+        let mut c = [0.0f32; 4];
+        sgemm_a_bt(2, 2, 2, 1.0, &a, &b, 0.0, &mut c);
+        assert_close(&c, &[17.0, 23.0, 39.0, 53.0]);
+    }
+
+    #[test]
+    fn test_sgemm_rectangular() {
+        // A = [[1, 2, 3]] (1x3), B = [[4, 5, 6], [7, 8, 9]] (2x3)
+        // A * B^T = [[1*4 + 2*5 + 3*6, 1*7 + 2*8 + 3*9]] = [[32, 50]]
+        let a = [1.0f32, 2.0, 3.0];
+        let b = [4.0f32, 5.0, 6.0, 7.0, 8.0, 9.0];
+        let mut c = [0.0f32; 2];
+        sgemm_a_bt(1, 2, 3, 1.0, &a, &b, 0.0, &mut c);
+        assert_close(&c, &[32.0, 50.0]);
+    }
+}
diff --git a/core/src/distance.rs b/core/src/distance.rs
new file mode 100644
index 0000000..f6174e1
--- /dev/null
+++ b/core/src/distance.rs
@@ -0,0 +1,470 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Vector similarity metric.
+///
+/// `#[repr(u32)]` pins each variant's discriminant to its numeric metric code.
+#[repr(u32)]
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub enum MetricType {
+    /// Squared Euclidean distance (code 0).
+    L2 = 0,
+    /// Inner product (code 1).
+    InnerProduct = 1,
+    /// Cosine distance (code 2).
+    Cosine = 2,
+}
+
+impl MetricType {
+    /// Decodes a numeric metric code: 0 = `L2`, 1 = `InnerProduct`, 2 = `Cosine`.
+    ///
+    /// Returns `None` for any other value.
+    pub fn from_code(code: u32) -> Option<Self> {
+        let metric = match code {
+            0 => Self::L2,
+            1 => Self::InnerProduct,
+            2 => Self::Cosine,
+            _ => return None,
+        };
+        Some(metric)
+    }
+}
+
+/// Squared Euclidean (L2) distance between `a` and `b`.
+///
+/// The slices are expected to have equal length (checked in debug builds only).
+pub fn fvec_l2sqr(a: &[f32], b: &[f32]) -> f32 {
+    debug_assert_eq!(a.len(), b.len());
+    a.iter().zip(&b[..a.len()]).fold(0.0f32, |acc, (&x, &y)| {
+        let diff = x - y;
+        acc + diff * diff
+    })
+}
+
+/// Squared L2 distance between the length-`len` windows `a[a_off..]` and `b[b_off..]`.
+pub fn fvec_l2sqr_sub(a: &[f32], a_off: usize, b: &[f32], b_off: usize, len: usize) -> f32 {
+    (0..len)
+        .map(|k| a[a_off + k] - b[b_off + k])
+        .fold(0.0f32, |acc, diff| acc + diff * diff)
+}
+
+/// Dot product of `a` and `b`.
+///
+/// The slices are expected to have equal length (checked in debug builds only).
+pub fn fvec_inner_product(a: &[f32], b: &[f32]) -> f32 {
+    debug_assert_eq!(a.len(), b.len());
+    let rhs = &b[..a.len()];
+    a.iter().zip(rhs).fold(0.0f32, |acc, (&x, &y)| acc + x * y)
+}
+
+/// Squared Euclidean norm of `a` (the sum of its squared components).
+pub fn fvec_norm_l2sqr(a: &[f32]) -> f32 {
+    a.iter().map(|&v| v * v).fold(0.0f32, |acc, sq| acc + sq)
+}
+
+/// Scales `v` in place to unit L2 length and returns its original L2 norm.
+///
+/// When the norm is not strictly positive (zero or NaN), `v` is left untouched.
+pub fn fvec_normalize(v: &mut [f32]) -> f32 {
+    let norm = v.iter().fold(0.0f32, |acc, &x| acc + x * x).sqrt();
+    if norm > 0.0 {
+        let scale = 1.0 / norm;
+        v.iter_mut().for_each(|x| *x *= scale);
+    }
+    norm
+}
+
+/// Compute `result[i] = a[i] + bf * b[i]` for every `i`. Used for precomputed table merging.
+/// Aligned with Faiss's fvec_madd.
+///
+/// # Panics
+/// Panics unless `a`, `b`, and `result` all have the same length.
+pub fn fvec_madd(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
+    // These must be real checks, not `debug_assert!`: the AVX2/NEON kernels behind
+    // `fvec_madd_simd` read `b` and write `result` through raw pointers for `a.len()`
+    // elements, so a length mismatch in a release build would be undefined behavior.
+    assert_eq!(a.len(), b.len(), "fvec_madd: `a` and `b` lengths differ");
+    assert_eq!(a.len(), result.len(), "fvec_madd: `a` and `result` lengths differ");
+    fvec_madd_simd(a, b, bf, result);
+}
+
+/// x86_64 dispatch for `fvec_madd`: AVX2 kernel when the CPU supports it, scalar otherwise.
+#[cfg(target_arch = "x86_64")]
+fn fvec_madd_simd(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
+    if !is_x86_feature_detected!("avx2") {
+        fvec_madd_scalar(a, b, bf, result);
+        return;
+    }
+    // SAFETY: AVX2 support was verified at runtime just above. The kernel also requires
+    // `b` and `result` to be at least `a.len()` long; callers must guarantee this.
+    unsafe { fvec_madd_avx2(a, b, bf, result) };
+}
+
+/// aarch64 dispatch for `fvec_madd`: always uses the NEON kernel.
+#[cfg(target_arch = "aarch64")]
+fn fvec_madd_simd(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
+    // SAFETY: NEON is part of the baseline feature set of Rust's standard aarch64
+    // targets, so no runtime detection is done. The kernel also requires `b` and
+    // `result` to be at least `a.len()` long; callers must guarantee this.
+    unsafe {
+        fvec_madd_neon(a, b, bf, result);
+    }
+}
+
+/// Portable dispatch for `fvec_madd` on architectures without a dedicated kernel.
+#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
+fn fvec_madd_simd(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
+    fvec_madd_scalar(a, b, bf, result)
+}
+
+/// Portable `result[i] = a[i] + bf * b[i]` over the first `a.len()` elements.
+// `dead_code` is allowed because no build references this on aarch64, where the
+// NEON kernel is always selected.
+#[inline]
+#[allow(dead_code)]
+fn fvec_madd_scalar(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
+    for (i, &x) in a.iter().enumerate() {
+        result[i] = x + bf * b[i];
+    }
+}
+
+/// AVX2 kernel for `fvec_madd`: eight lanes per step using unaligned loads/stores
+/// (`_mm256_loadu_ps` / `_mm256_storeu_ps`), then a scalar tail.
+///
+/// # Safety
+/// The CPU must support AVX2, and `b` and `result` must each hold at least `a.len()`
+/// elements: the vector loop reads and writes through raw pointers without bounds checks.
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn fvec_madd_avx2(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
+    use std::arch::x86_64::*;
+    let n = a.len();
+    let vbf = _mm256_set1_ps(bf);
+    let mut i = 0;
+    // Main loop: 8 floats per iteration, computed as a separate multiply then add
+    // (no fused multiply-add), mirroring the scalar expression.
+    while i + 8 <= n {
+        let va = _mm256_loadu_ps(a.as_ptr().add(i));
+        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
+        let vr = _mm256_add_ps(va, _mm256_mul_ps(vbf, vb));
+        _mm256_storeu_ps(result.as_mut_ptr().add(i), vr);
+        i += 8;
+    }
+    // Scalar tail for the remaining n % 8 elements (bounds-checked indexing).
+    while i < n {
+        result[i] = a[i] + bf * b[i];
+        i += 1;
+    }
+}
+
+/// NEON kernel for `fvec_madd`: four lanes per step via `vmlaq_f32` (`va + vbf * vb`),
+/// then a scalar tail.
+///
+/// # Safety
+/// NEON must be available, and `b` and `result` must each hold at least `a.len()`
+/// elements: the vector loop reads and writes through raw pointers without bounds checks.
+#[cfg(target_arch = "aarch64")]
+#[target_feature(enable = "neon")]
+unsafe fn fvec_madd_neon(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
+    use std::arch::aarch64::*;
+    let n = a.len();
+    let vbf = vdupq_n_f32(bf);
+    let mut i = 0;
+    // Main loop: 4 floats per iteration.
+    while i + 4 <= n {
+        let va = vld1q_f32(a.as_ptr().add(i));
+        let vb = vld1q_f32(b.as_ptr().add(i));
+        let vr = vmlaq_f32(va, vbf, vb);
+        vst1q_f32(result.as_mut_ptr().add(i), vr);
+        i += 4;
+    }
+    // Scalar tail for the remaining n % 4 elements (bounds-checked indexing).
+    while i < n {
+        result[i] = a[i] + bf * b[i];
+        i += 1;
+    }
+}
+
+/// Squared L2 distance from one query sub-vector to each of `ksub` centroids, i.e. one
+/// row of a PQ distance table.
+///
+/// `centroids` stores `ksub` sub-vectors of length `dsub` back to back; `result[j]`
+/// receives `|query_sub[..dsub] - centroid_j|^2`. This is a plain scalar loop; any
+/// vectorization is left to the compiler.
+///
+/// # Panics
+/// Panics if `query_sub.len() < dsub`, `centroids.len() < ksub * dsub`, or
+/// `result.len() < ksub`.
+pub fn fvec_l2sqr_batch(
+    query_sub: &[f32],
+    centroids: &[f32],
+    dsub: usize,
+    ksub: usize,
+    result: &mut [f32],
+) {
+    // Slice once up front so the inner loop works on exact-length slices.
+    let query = &query_sub[..dsub];
+    for (j, out) in result[..ksub].iter_mut().enumerate() {
+        let centroid = &centroids[j * dsub..(j + 1) * dsub];
+        *out = query.iter().zip(centroid).fold(0.0f32, |acc, (&q, &c)| {
+            let diff = q - c;
+            acc + diff * diff
+        });
+    }
+}
+
+/// Inner product of one query sub-vector with each of `ksub` centroids, i.e. one row of
+/// a PQ distance table for the inner-product metric.
+///
+/// `centroids` stores `ksub` sub-vectors of length `dsub` back to back; `result[j]`
+/// receives `dot(query_sub[..dsub], centroid_j)`. This is a plain scalar loop; any
+/// vectorization is left to the compiler.
+///
+/// # Panics
+/// Panics if `query_sub.len() < dsub`, `centroids.len() < ksub * dsub`, or
+/// `result.len() < ksub`.
+pub fn fvec_ip_batch(
+    query_sub: &[f32],
+    centroids: &[f32],
+    dsub: usize,
+    ksub: usize,
+    result: &mut [f32],
+) {
+    // Slice once up front so the inner loop works on exact-length slices.
+    let query = &query_sub[..dsub];
+    for (j, out) in result[..ksub].iter_mut().enumerate() {
+        let centroid = &centroids[j * dsub..(j + 1) * dsub];
+        *out = query.iter().zip(centroid).fold(0.0f32, |acc, (&q, &c)| acc + q * c);
+    }
+}
+
+/// Scan a batch of 4-bit PQ codes against a per-query distance table.
+///
+/// Approach (aligned with Lance/Faiss):
+///   1. Score the first `FLAT_NUM` (200) vectors exactly in `f32`; their maximum
+///      (`qmax`) calibrates the quantization range.
+///   2. Quantize the distance table to `u8`, rounding to the nearest step so the
+///      per-entry error is unbiased (truncation would bias every quantized distance
+///      low relative to the exact distances written in step 1).
+///   3. Accumulate the remaining vectors' distances in the integer domain.
+///   4. Dequantize back to `f32`.
+///
+/// The loops are scalar; any vectorization is left to the compiler.
+///
+/// * `sim_table` – `[m * 16]` f32 distance table (16 centroids per sub-quantizer).
+/// * `codes` – nibble-packed, `m / 2` bytes per vector, row-major; byte `p` holds
+///   sub-quantizer `2p` in its low nibble and `2p + 1` in its high nibble.
+/// * `count` – number of vectors to score.
+/// * `m` – number of sub-quantizers; expected to be even for this packing. With an odd
+///   `m` the last sub-quantizer has no storage and is ignored on both paths.
+/// * `dists` – receives one distance per vector in `dists[..count]`.
+///
+/// # Panics
+/// Panics if `codes`, `sim_table`, or `dists` are shorter than `count` and `m` require.
+pub fn scan_4bit_simd(
+    sim_table: &[f32],
+    codes: &[u8],
+    count: usize,
+    m: usize,
+    dists: &mut [f32],
+) {
+    const FLAT_NUM: usize = 200;
+
+    let cs = m / 2; // code_size = m/2 bytes per vector
+    // Sub-quantizers actually stored in `codes` (equals `m` when `m` is even).
+    let nsq = cs * 2;
+
+    // Step 1: exact f32 distances for the first FLAT_NUM vectors.
+    let flat_end = count.min(FLAT_NUM);
+    for i in 0..flat_end {
+        let base = i * cs;
+        let mut d = 0.0f32;
+        for pair in 0..cs {
+            let byte = codes[base + pair];
+            let lo = (byte & 0x0F) as usize;
+            let hi = (byte >> 4) as usize;
+            d += sim_table[(pair * 2) * 16 + lo];
+            d += sim_table[(pair * 2 + 1) * 16 + hi];
+        }
+        dists[i] = d;
+    }
+
+    if count <= FLAT_NUM {
+        return;
+    }
+
+    // Step 2: calibrate the u8 range on the largest exact distance, quantize the table.
+    let qmax = dists[..flat_end].iter().copied().fold(f32::MIN, f32::max);
+    let qmin = sim_table.iter().copied().fold(f32::INFINITY, f32::min);
+    let range = (qmax - qmin).max(1e-10);
+    let factor = 255.0 / range;
+
+    let qtable: Vec<u8> = sim_table
+        .iter()
+        .map(|&d| ((d - qmin) * factor).round().clamp(0.0, 255.0) as u8)
+        .collect();
+
+    // Step 3: accumulate the remaining vectors in the integer domain. Each sub-quantizer
+    // adds at most 255, so u32 cannot overflow for any realistic `m` (u16 capped m at 257).
+    let mut q_dists = vec![0u32; count - flat_end];
+    for pair in 0..cs {
+        let qtab_lo = &qtable[(pair * 2) * 16..(pair * 2 + 1) * 16];
+        let qtab_hi = &qtable[(pair * 2 + 1) * 16..(pair * 2 + 2) * 16];
+        for (j, acc) in q_dists.iter_mut().enumerate() {
+            let byte = codes[(flat_end + j) * cs + pair];
+            *acc += u32::from(qtab_lo[(byte & 0x0F) as usize])
+                + u32::from(qtab_hi[(byte >> 4) as usize]);
+        }
+    }
+
+    // Step 4: dequantize. Every scanned sub-quantizer contributed at least `qmin`.
+    let inv_factor = range / 255.0;
+    let base_dist = qmin * nsq as f32;
+    for (out, &q) in dists[flat_end..count].iter_mut().zip(&q_dists) {
+        *out = q as f32 * inv_factor + base_dist;
+    }
+}
+
+/// PQ distance of one encoded vector from a precomputed distance table.
+///
+/// `table` is laid out `[m][ksub]` and `codes` holds `m` one-byte codes, so the result
+/// is the sum over `i < m` of `table[i * ksub + codes[i]]`. The work is delegated to the
+/// architecture-specific `pq_distance_from_table_simd`.
+#[inline]
+pub fn pq_distance_from_table(table: &[f32], codes: &[u8], m: usize, ksub: usize) -> f32 {
+    self::pq_distance_from_table_simd(table, codes, m, ksub)
+}
+
+/// PQ distances for four encoded vectors at once.
+///
+/// `offsets[j]` is where the `j`-th code (`m` bytes) starts in `codes`; the result
+/// holds the four distances in the same order. The four lookups per sub-quantizer are
+/// interleaved to expose instruction-level parallelism.
+#[inline]
+pub fn pq_distance_four_codes(
+    table: &[f32],
+    codes: &[u8],
+    m: usize,
+    ksub: usize,
+    offsets: [usize; 4],
+) -> [f32; 4] {
+    let mut acc = [0.0f32; 4];
+    for sq in 0..m {
+        let row = sq * ksub;
+        for (slot, &start) in acc.iter_mut().zip(&offsets) {
+            *slot += table[row + codes[start + sq] as usize];
+        }
+    }
+    acc
+}
+
+/// x86_64 dispatch for PQ table lookup: AVX2 gather kernel when it applies and is safe,
+/// scalar loop otherwise.
+#[cfg(target_arch = "x86_64")]
+#[inline]
+fn pq_distance_from_table_simd(table: &[f32], codes: &[u8], m: usize, ksub: usize) -> f32 {
+    // `pq_distance_avx2` gathers from `table` with unchecked offsets, so it may only run
+    // when all `m` codes and the whole `[m][256]` table are present.
+    // `table.len() / 256 >= m` is `table.len() >= m * 256` without overflow risk.
+    // Anything else takes the bounds-checked scalar loop, which panics instead of
+    // reading out of bounds.
+    let use_avx2 = ksub == 256
+        && m >= 8
+        && codes.len() >= m
+        && table.len() / 256 >= m
+        && is_x86_feature_detected!("avx2");
+    if use_avx2 {
+        // SAFETY: AVX2 was detected at runtime, and the length checks above keep every
+        // gathered index (< m * 256) inside `table`.
+        unsafe { pq_distance_avx2(table, codes, m) }
+    } else {
+        pq_distance_scalar(table, codes, m, ksub)
+    }
+}
+
+/// aarch64 dispatch for PQ table lookup: NEON kernel when it applies and is safe,
+/// scalar loop otherwise.
+#[cfg(target_arch = "aarch64")]
+#[inline]
+fn pq_distance_from_table_simd(table: &[f32], codes: &[u8], m: usize, ksub: usize) -> f32 {
+    // `pq_distance_neon` uses `get_unchecked` on both `codes` and `table`, so it may only
+    // run when all `m` codes and the whole `[m][256]` table are present.
+    // `table.len() / 256 >= m` is `table.len() >= m * 256` without overflow risk.
+    if ksub == 256 && m >= 4 && codes.len() >= m && table.len() / 256 >= m {
+        // SAFETY: NEON is in the baseline feature set of Rust's standard aarch64 targets,
+        // and the length checks above keep every unchecked access in bounds.
+        unsafe { pq_distance_neon(table, codes, m) }
+    } else {
+        pq_distance_scalar(table, codes, m, ksub)
+    }
+}
+
+/// NEON PQ distance: manually gathers four table entries per step into a `float32x4_t`,
+/// accumulates them with `vaddq_f32`, reduces with `vaddvq_f32`, then adds the scalar
+/// tail for the last `m % 4` sub-quantizers.
+///
+/// # Safety
+/// NEON must be available, `codes.len() >= m`, and `table.len() >= m * 256`: neither
+/// slice is bounds-checked here.
+#[cfg(target_arch = "aarch64")]
+#[target_feature(enable = "neon")]
+unsafe fn pq_distance_neon(table: &[f32], codes: &[u8], m: usize) -> f32 {
+    use std::arch::aarch64::*;
+
+    const KSUB: usize = 256;
+    let full = m / 4 * 4;
+    let mut acc = vdupq_n_f32(0.0);
+
+    for base in (0..full).step_by(4) {
+        let lanes = [
+            *table.get_unchecked(base * KSUB + *codes.get_unchecked(base) as usize),
+            *table.get_unchecked((base + 1) * KSUB + *codes.get_unchecked(base + 1) as usize),
+            *table.get_unchecked((base + 2) * KSUB + *codes.get_unchecked(base + 2) as usize),
+            *table.get_unchecked((base + 3) * KSUB + *codes.get_unchecked(base + 3) as usize),
+        ];
+        acc = vaddq_f32(acc, vld1q_f32(lanes.as_ptr()));
+    }
+
+    let mut total = vaddvq_f32(acc);
+    for sq in full..m {
+        total += *table.get_unchecked(sq * KSUB + *codes.get_unchecked(sq) as usize);
+    }
+    total
+}
+
+/// Portable dispatch: architectures without a dedicated kernel always use the scalar loop.
+#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
+#[inline]
+fn pq_distance_from_table_simd(table: &[f32], codes: &[u8], m: usize, ksub: usize) -> f32 {
+    self::pq_distance_scalar(table, codes, m, ksub)
+}
+
+/// Reference PQ distance: sum over `i < m` of `table[i * ksub + codes[i]]`,
+/// with bounds-checked indexing.
+#[inline]
+fn pq_distance_scalar(table: &[f32], codes: &[u8], m: usize, ksub: usize) -> f32 {
+    codes[..m]
+        .iter()
+        .enumerate()
+        .fold(0.0f32, |acc, (sq, &code)| acc + table[sq * ksub + code as usize])
+}
+
+/// AVX2 PQ distance using gather instructions.
+/// Aligned with Faiss's pq_code_distance-avx2.h.
+///
+/// Handles eight sub-quantizers per step: builds eight `i32` offsets relative to row
+/// `i` of the `[m][256]` table, gathers them with `_mm256_i32gather_ps`, and accumulates
+/// them. The remaining `m % 8` sub-quantizers are summed with scalar indexing.
+///
+/// # Safety
+/// The CPU must support AVX2 and `table` must hold at least `m * 256` entries: the
+/// gather is not bounds-checked. (`codes` is read with checked indexing.)
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx2")]
+unsafe fn pq_distance_avx2(table: &[f32], codes: &[u8], m: usize) -> f32 {
+    use std::arch::x86_64::*;
+
+    let ksub = 256usize;
+    let mut sum = _mm256_setzero_ps();
+    let mut i = 0;
+
+    // Process 8 sub-quantizers at a time
+    while i + 8 <= m {
+        // Lane j addresses row i + j, column codes[i + j], relative to `tab_ptr`.
+        // `_mm256_set_epi32` takes its arguments from lane 7 down to lane 0.
+        let offsets = _mm256_set_epi32(
+            (7 * ksub + codes[i + 7] as usize) as i32,
+            (6 * ksub + codes[i + 6] as usize) as i32,
+            (5 * ksub + codes[i + 5] as usize) as i32,
+            (4 * ksub + codes[i + 4] as usize) as i32,
+            (3 * ksub + codes[i + 3] as usize) as i32,
+            (2 * ksub + codes[i + 2] as usize) as i32,
+            (ksub + codes[i + 1] as usize) as i32,
+            (codes[i] as usize) as i32,
+        );
+
+        let tab_ptr = table.as_ptr().add(i * ksub);
+        // Scale 4 = size_of::<f32>(); this load is not bounds-checked.
+        let gathered = _mm256_i32gather_ps::<4>(tab_ptr, offsets);
+        sum = _mm256_add_ps(sum, gathered);
+        i += 8;
+    }
+
+    // Horizontal sum of the 8 floats in sum
+    let hi = _mm256_extractf128_ps::<1>(sum);
+    let lo = _mm256_castps256_ps128(sum);
+    let sum128 = _mm_add_ps(lo, hi);
+    let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
+    let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps::<1>(sum64, sum64));
+    let mut result = _mm_cvtss_f32(sum32);
+
+    // Handle remaining sub-quantizers
+    while i < m {
+        result += table[i * ksub + codes[i] as usize];
+        i += 1;
+    }
+
+    result
+}
+
+/// Computes the distance from `query` to each of `n` row-major vectors of dimension `d`.
+///
+/// `distances[i]` receives the distance to `vectors[i * d..(i + 1) * d]`. For every
+/// metric a smaller value means "closer":
+/// * `L2` – squared Euclidean distance;
+/// * `InnerProduct` – the negated inner product;
+/// * `Cosine` – `1 - cos(query, vector)`, or `1.0` if either vector has zero norm.
+///
+/// This only fills `distances`; it performs no sorting or top-k selection.
+///
+/// # Panics
+/// Panics if `vectors` holds fewer than `n * d` values or `distances` fewer than `n`.
+pub fn fvec_distances_batch(
+    query: &[f32],
+    vectors: &[f32],
+    n: usize,
+    d: usize,
+    metric: MetricType,
+    distances: &mut [f32],
+) {
+    // The query norm is the same for every candidate, so compute it once instead of
+    // once per vector.
+    let query_norm = if metric == MetricType::Cosine {
+        fvec_norm_l2sqr(query).sqrt()
+    } else {
+        0.0
+    };
+    for (i, out) in distances[..n].iter_mut().enumerate() {
+        let vec = &vectors[i * d..(i + 1) * d];
+        *out = match metric {
+            MetricType::L2 => fvec_l2sqr(query, vec),
+            MetricType::InnerProduct => -fvec_inner_product(query, vec),
+            MetricType::Cosine => {
+                let dot = fvec_inner_product(query, vec);
+                let denom = query_norm * fvec_norm_l2sqr(vec).sqrt();
+                if denom > 0.0 {
+                    1.0 - dot / denom
+                } else {
+                    1.0
+                }
+            }
+        };
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Absolute tolerance shared by the floating-point checks below.
+    const TOL: f32 = 1e-6;
+
+    #[test]
+    fn test_l2sqr() {
+        let dist = fvec_l2sqr(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]);
+        assert!((dist - 27.0).abs() < TOL);
+    }
+
+    #[test]
+    fn test_inner_product() {
+        let dot = fvec_inner_product(&[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0]);
+        assert!((dot - 32.0).abs() < TOL);
+    }
+
+    #[test]
+    fn test_normalize() {
+        // (3, 4) has norm 5, so it normalizes to (0.6, 0.8).
+        let mut v = [3.0f32, 4.0];
+        fvec_normalize(&mut v);
+        let [x, y] = v;
+        assert!((x - 0.6).abs() < TOL);
+        assert!((y - 0.8).abs() < TOL);
+    }
+
+    #[test]
+    fn test_pq_distance_scalar() {
+        // 2 sub-quantizers x 4 centroids; codes [1, 3] select table[0 * 4 + 1] = 0.2
+        // and table[1 * 4 + 3] = 0.8.
+        let table = [0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
+        let dist = pq_distance_scalar(&table, &[1u8, 3u8], 2, 4);
+        assert!((dist - 1.0).abs() < TOL);
+    }
+}
diff --git a/core/src/kmeans.rs b/core/src/kmeans.rs
new file mode 100644
index 0000000..72d8069
--- /dev/null
+++ b/core/src/kmeans.rs
@@ -0,0 +1,875 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::blas::sgemm_a_bt;
+use crate::distance::{fvec_l2sqr, fvec_norm_l2sqr};
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
+
+/// Configuration for the k-means trainers in this module (see
+/// [`kmeans_train`] and [`kmeans_train_with_init`]).
+#[derive(Debug, Clone)]
+pub struct KMeansConfig {
+    /// Maximum number of assign/update iterations per run. A run stops early
+    /// once the objective changes by less than 1e-6 (relative).
+    pub niter: usize,
+    /// Number of independent k-means++-seeded runs; the run with the lowest
+    /// objective is kept. Ignored when initial centroids are supplied.
+    pub nredo: usize,
+    /// Training data is randomly subsampled to at most
+    /// `k * max_points_per_centroid` points before clustering.
+    pub max_points_per_centroid: usize,
+    /// Seed for the random number generator (subsampling, seeding and
+    /// empty-cluster repair).
+    pub seed: u64,
+    /// Balance factor: penalizes large clusters to produce more uniform partitions.
+    /// 0.0 = standard k-means. Higher values = more balanced: during assignment
+    /// `balance_factor * ln(cluster size)` is added to each candidate distance.
+    /// Typical value: 0.1 for IVF construction.
+    pub balance_factor: f32,
+}
+
+impl Default for KMeansConfig {
+    /// 25 iterations, a single run, at most 256 training points per centroid,
+    /// seed 1234 and no balance penalty.
+    fn default() -> Self {
+        Self {
+            seed: 1234,
+            niter: 25,
+            nredo: 1,
+            balance_factor: 0.0,
+            max_points_per_centroid: 256,
+        }
+    }
+}
+
+/// Relative perturbation used when `update_centroids` splits an empty cluster
+/// off a populated one: the two copies are scaled by `1 + EPS` / `1 - EPS` in
+/// alternating dimensions.
+const EPS: f32 = 1.0 / 1024.0;
+
+/// `kmeans_train` switches to hierarchical splitting when `k` exceeds this value
+/// (and there are more points than centroids).
+const HIERARCHICAL_THRESHOLD: usize = 256;
+
+/// Train `k` centroids of dimension `d` on the `n` row-major vectors in `data`.
+///
+/// When `k` exceeds `HIERARCHICAL_THRESHOLD` and there are more points than
+/// centroids, hierarchical splitting is used; otherwise a regular k-means run
+/// (k-means++ seeding). Returns `k * d` values, row-major.
+pub fn kmeans_train(
+    config: &KMeansConfig,
+    data: &[f32],
+    n: usize,
+    d: usize,
+    k: usize,
+) -> Vec<f32> {
+    let use_hierarchical = k > HIERARCHICAL_THRESHOLD && n > k;
+    if !use_hierarchical {
+        return kmeans_train_with_init(config, data, n, d, k, None);
+    }
+    kmeans_train_hierarchical(config, data, n, d, k)
+}
+
+/// Hierarchical k-means for large k (> 256).
+///
+/// Trains at most 16 top-level clusters on a (possibly subsampled) copy of the
+/// data. It then repeatedly pops the most populated cluster and splits it in two
+/// with a short local k-means run, until `target_k` clusters exist.
+///
+/// A cluster is finalized (kept as one centroid) in two cases:
+/// - it has fewer than two points;
+/// - the local split fails to separate its points, e.g. when they are all
+///   identical.
+///
+/// Re-queuing such a cluster would pop it again on the next iteration, so the
+/// loop would never terminate.
+///
+/// If fewer than `target_k` centroids can be produced, the rest is padded with
+/// zero vectors. Returns `target_k * d` values, row-major.
+fn kmeans_train_hierarchical(
+    config: &KMeansConfig,
+    data: &[f32],
+    n: usize,
+    d: usize,
+    target_k: usize,
+) -> Vec<f32> {
+    use std::cmp::Ordering;
+    use std::collections::BinaryHeap;
+
+    /// A pending cluster; heap order is by population so the largest pops first.
+    struct Cluster {
+        centroid: Vec<f32>,
+        indices: Vec<usize>,
+    }
+
+    impl Eq for Cluster {}
+    impl PartialEq for Cluster {
+        fn eq(&self, other: &Self) -> bool {
+            self.indices.len() == other.indices.len()
+        }
+    }
+    impl Ord for Cluster {
+        fn cmp(&self, other: &Self) -> Ordering {
+            self.indices.len().cmp(&other.indices.len())
+        }
+    }
+    impl PartialOrd for Cluster {
+        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+            Some(self.cmp(other))
+        }
+    }
+
+    if target_k == 0 {
+        return Vec::new();
+    }
+
+    let mut rng = StdRng::seed_from_u64(config.seed);
+
+    // Subsample for training. Never go below target_k points, so that
+    // max_points_per_centroid == 0 cannot empty the training set.
+    let max_n = target_k
+        .saturating_mul(config.max_points_per_centroid)
+        .max(target_k);
+    let (train_data, train_n) = if n > max_n {
+        (subsample(data, n, d, max_n, &mut rng), max_n)
+    } else {
+        (data.to_vec(), n)
+    };
+
+    // Step 1: train a small number of top-level clusters.
+    let initial_k = 16.min(target_k);
+    let initial_config = KMeansConfig {
+        niter: config.niter,
+        seed: config.seed,
+        ..KMeansConfig::default()
+    };
+    let initial_centroids =
+        kmeans_train_with_init(&initial_config, &train_data, train_n, d, initial_k, None);
+
+    let mut assignments = vec![0usize; train_n];
+    assign_clusters_fast(
+        &train_data,
+        train_n,
+        d,
+        &initial_centroids,
+        initial_k,
+        &mut assignments,
+        0.0,
+    );
+
+    // Bucket point indices by cluster in a single pass (indices stay ascending).
+    let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); initial_k];
+    for (i, &c) in assignments.iter().enumerate() {
+        buckets[c].push(i);
+    }
+    // Push one at a time (rather than heapifying) to keep the original
+    // tie-breaking order between equally sized clusters.
+    let mut heap: BinaryHeap<Cluster> = BinaryHeap::new();
+    for (c, indices) in buckets.into_iter().enumerate() {
+        let centroid = initial_centroids[c * d..(c + 1) * d].to_vec();
+        heap.push(Cluster { centroid, indices });
+    }
+
+    // Step 2: iteratively split the largest cluster.
+    let mut finalized: Vec<Vec<f32>> = Vec::new();
+    let split_k = 2; // Split into 2 each time
+
+    while finalized.len() + heap.len() < target_k {
+        let Some(largest) = heap.pop() else {
+            break;
+        };
+
+        // Fewer points than children: a split cannot produce split_k non-empty parts.
+        if largest.indices.len() < split_k {
+            finalized.push(largest.centroid);
+            continue;
+        }
+
+        // Extract sub-data for this cluster.
+        let sub_n = largest.indices.len();
+        let mut sub_data = Vec::with_capacity(sub_n * d);
+        for &orig_idx in &largest.indices {
+            sub_data.extend_from_slice(&train_data[orig_idx * d..(orig_idx + 1) * d]);
+        }
+
+        // Run k-means to split. Wrapping add: a plain `+` overflows (and panics
+        // in debug builds) for seeds near u64::MAX.
+        let sub_config = KMeansConfig {
+            niter: 10,
+            seed: config.seed.wrapping_add(finalized.len() as u64),
+            ..KMeansConfig::default()
+        };
+        let sub_centroids =
+            kmeans_train_with_init(&sub_config, &sub_data, sub_n, d, split_k, None);
+
+        // Reassign points in this cluster.
+        let mut sub_assignments = vec![0usize; sub_n];
+        assign_clusters_fast(
+            &sub_data,
+            sub_n,
+            d,
+            &sub_centroids,
+            split_k,
+            &mut sub_assignments,
+            0.0,
+        );
+
+        let mut parts: Vec<Vec<usize>> = vec![Vec::new(); split_k];
+        for (local, &sc) in sub_assignments.iter().enumerate() {
+            parts[sc].push(largest.indices[local]);
+        }
+        let mut children: Vec<Cluster> = Vec::with_capacity(split_k);
+        for (sc, indices) in parts.into_iter().enumerate() {
+            if !indices.is_empty() {
+                let centroid = sub_centroids[sc * d..(sc + 1) * d].to_vec();
+                children.push(Cluster { centroid, indices });
+            }
+        }
+
+        if children.len() >= 2 {
+            for child in children {
+                heap.push(child);
+            }
+        } else {
+            // The split did not separate the points (e.g. all duplicates).
+            // Finalize instead of re-queuing, which would loop forever.
+            let centroid = match children.pop() {
+                Some(only) => only.centroid,
+                None => largest.centroid,
+            };
+            finalized.push(centroid);
+        }
+    }
+
+    // Collect all centroids.
+    let mut result = Vec::with_capacity(target_k * d);
+    for c in finalized {
+        result.extend_from_slice(&c);
+    }
+    while result.len() < target_k * d {
+        let Some(cluster) = heap.pop() else {
+            break;
+        };
+        result.extend_from_slice(&cluster.centroid);
+    }
+
+    // Pad if needed.
+    result.resize(target_k * d, 0.0);
+    result
+}
+
+/// Train `k` centroids with k-means, optionally warm-started.
+///
+/// `data` holds `n` row-major vectors of dimension `d`. If there are more than
+/// `k * config.max_points_per_centroid` points, a random subsample of that size
+/// (but never fewer than `k` points) is used. When there are no more training
+/// points than centroids, the points themselves are returned, cycled to fill all
+/// `k` slots.
+///
+/// With `initial_centroids`, a single run starts from them. Otherwise
+/// `config.nredo` (at least 1) runs are seeded with k-means++, and the run with
+/// the lowest objective is kept.
+///
+/// Each run performs at most `config.niter` assign/update iterations and stops
+/// early once the objective changes by less than 1e-6 (relative). With
+/// `niter == 0` the seeded centroids are returned unchanged.
+///
+/// Returns `k * d` values, row-major (all zeros if `n == 0` or `k == 0`).
+///
+/// # Panics
+/// Panics if `initial_centroids` is used and does not hold exactly `k * d` values.
+pub fn kmeans_train_with_init(
+    config: &KMeansConfig,
+    data: &[f32],
+    n: usize,
+    d: usize,
+    k: usize,
+    initial_centroids: Option<&[f32]>,
+) -> Vec<f32> {
+    if n == 0 || k == 0 {
+        return vec![0.0; k * d];
+    }
+
+    let mut rng = StdRng::seed_from_u64(config.seed);
+
+    // Never subsample below k points: with max_points_per_centroid == 0 the
+    // training set would be empty and `i % train_n` below would divide by zero.
+    let max_n = k.saturating_mul(config.max_points_per_centroid).max(k);
+    let (train_data, train_n) = if n > max_n {
+        let sub = subsample(data, n, d, max_n, &mut rng);
+        (sub, max_n)
+    } else {
+        (data.to_vec(), n)
+    };
+
+    // Not more points than centroids: reuse the points, cycling through them.
+    if train_n <= k {
+        let mut centroids = vec![0.0f32; k * d];
+        for i in 0..k {
+            let src = i % train_n;
+            centroids[i * d..(i + 1) * d]
+                .copy_from_slice(&train_data[src * d..(src + 1) * d]);
+        }
+        return centroids;
+    }
+
+    let mut best_centroids = vec![0.0f32; k * d];
+    let mut best_obj = f32::MAX;
+
+    // A warm start gets exactly one run; otherwise at least one run, since
+    // nredo == 0 would leave best_centroids all zeros.
+    let nredo = if initial_centroids.is_some() {
+        1
+    } else {
+        config.nredo.max(1)
+    };
+
+    for redo in 0..nredo {
+        let mut centroids = match initial_centroids {
+            Some(init) if redo == 0 => {
+                assert_eq!(
+                    init.len(),
+                    k * d,
+                    "initial_centroids must contain k * d = {} values",
+                    k * d
+                );
+                init.to_vec()
+            }
+            _ => kmeans_plusplus_init(&train_data, train_n, d, k, &mut rng),
+        };
+        let mut assignments = vec![0usize; train_n];
+        let mut prev_obj = f32::MAX;
+
+        for _iter in 0..config.niter {
+            let obj = assign_clusters_fast(
+                &train_data,
+                train_n,
+                d,
+                &centroids,
+                k,
+                &mut assignments,
+                config.balance_factor,
+            );
+            update_centroids(
+                &train_data,
+                train_n,
+                d,
+                &mut centroids,
+                k,
+                &assignments,
+                &mut rng,
+            );
+
+            if prev_obj < f32::MAX {
+                let rel_change = (prev_obj - obj).abs() / prev_obj.max(1e-10);
+                if rel_change < 1e-6 {
+                    break;
+                }
+            }
+            prev_obj = obj;
+        }
+
+        // Always keep the first run: with niter == 0 (or a non-finite objective)
+        // `prev_obj < best_obj` never holds and the result would be all zeros.
+        if redo == 0 || prev_obj < best_obj {
+            best_obj = prev_obj;
+            best_centroids.copy_from_slice(&centroids);
+        }
+    }
+
+    best_centroids
+}
+
+/// k-means++ seeding.
+///
+/// The first centroid is a uniformly random point. Each following centroid is a
+/// point drawn with probability proportional to its squared distance to the
+/// closest centroid chosen so far. Returns `k * d` values, row-major.
+fn kmeans_plusplus_init(
+    data: &[f32],
+    n: usize,
+    d: usize,
+    k: usize,
+    rng: &mut StdRng,
+) -> Vec<f32> {
+    let row = |i: usize| &data[i * d..(i + 1) * d];
+
+    let mut centroids = vec![0.0f32; k * d];
+    centroids[..d].copy_from_slice(row(rng.gen_range(0..n)));
+
+    // Squared distance from each point to its closest centroid so far.
+    let mut nearest = vec![f32::MAX; n];
+
+    for c in 1..k {
+        // Fold in the centroid chosen in the previous round.
+        let last = &centroids[(c - 1) * d..c * d];
+        let mut total = 0.0f32;
+        for (i, slot) in nearest.iter_mut().enumerate() {
+            let dist = fvec_l2sqr(row(i), last);
+            if dist < *slot {
+                *slot = dist;
+            }
+            total += *slot;
+        }
+
+        // Inverse-CDF sampling over the running sum of weights; falls back to
+        // the last point if rounding keeps the sum below the target.
+        let target = rng.gen::<f32>() * total;
+        let pick = nearest
+            .iter()
+            .scan(0.0f32, |acc, &w| {
+                *acc += w;
+                Some(*acc)
+            })
+            .position(|acc| acc >= target)
+            .unwrap_or(n - 1);
+
+        centroids[c * d..(c + 1) * d].copy_from_slice(row(pick));
+    }
+
+    centroids
+}
+
+/// Assign each of the `n` points in `data` to its nearest centroid and return
+/// the summed objective.
+///
+/// Distances are expanded as ||x-c||² = ||x||² + ||c||² - 2·x·cᵀ, so all inner
+/// products come from one sgemm per chunk.
+///
+/// When `balance_factor > 0`, `balance_factor * ln(size)` is added to the
+/// distance of every non-empty cluster. `size` is the cluster's population
+/// according to the *incoming* `assignments`, counted over all points before
+/// any of them is reassigned. The returned objective includes that penalty.
+///
+/// Rows are processed in chunks so the inner-product matrix stays around 16 MB.
+/// Chunking does not change the result.
+fn assign_clusters_fast(
+    data: &[f32],
+    n: usize,
+    d: usize,
+    centroids: &[f32],
+    k: usize,
+    assignments: &mut [usize],
+    balance_factor: f32,
+) -> f32 {
+    // Cap the ip_matrix size to ~16MB (4M f32 values).
+    const MAX_MATRIX_ELEMS: usize = 4 * 1024 * 1024;
+
+    let c_norms: Vec<f32> = (0..k)
+        .map(|c| fvec_norm_l2sqr(&centroids[c * d..(c + 1) * d]))
+        .collect();
+
+    // Penalty sizes come from the incoming assignments of ALL points, not of
+    // the current chunk only.
+    let cluster_sizes: Option<Vec<u32>> = (balance_factor > 0.0).then(|| {
+        let mut sizes = vec![0u32; k];
+        for &a in assignments.iter() {
+            if a < k {
+                sizes[a] += 1;
+            }
+        }
+        sizes
+    });
+
+    // At least one row per chunk, even if k alone exceeds the cap, so that the
+    // loop always advances.
+    let chunk_n = (MAX_MATRIX_ELEMS / k.max(1)).max(1);
+    let mut total_obj = 0.0f32;
+    let mut offset = 0;
+    while offset < n {
+        let cn = (n - offset).min(chunk_n);
+        total_obj += assign_chunk(
+            &data[offset * d..(offset + cn) * d],
+            cn,
+            d,
+            centroids,
+            &c_norms,
+            k,
+            &mut assignments[offset..offset + cn],
+            cluster_sizes.as_deref(),
+            balance_factor,
+        );
+        offset += cn;
+    }
+    total_obj
+}
+
+/// Assign one chunk of `n` rows (see `assign_clusters_fast`) and return its
+/// objective. `cluster_sizes` is `Some` only when the balance penalty is active.
+fn assign_chunk(
+    data: &[f32],
+    n: usize,
+    d: usize,
+    centroids: &[f32],
+    c_norms: &[f32],
+    k: usize,
+    assignments: &mut [usize],
+    cluster_sizes: Option<&[u32]>,
+    balance_factor: f32,
+) -> f32 {
+    let mut ip_matrix = vec![0.0f32; n * k];
+    sgemm_a_bt(n, k, d, 1.0, data, centroids, 0.0, &mut ip_matrix);
+
+    let mut total_obj = 0.0f32;
+    for i in 0..n {
+        let x_norm = fvec_norm_l2sqr(&data[i * d..(i + 1) * d]);
+        let row = &ip_matrix[i * k..(i + 1) * k];
+        let mut best = 0;
+        let mut best_dist = f32::MAX;
+        for c in 0..k {
+            let mut dist = x_norm + c_norms[c] - 2.0 * row[c];
+            // Balance penalty: prefer smaller clusters.
+            if let Some(sizes) = cluster_sizes {
+                if sizes[c] > 0 {
+                    dist += balance_factor * (sizes[c] as f32).ln();
+                }
+            }
+            if dist < best_dist {
+                best_dist = dist;
+                best = c;
+            }
+        }
+        assignments[i] = best;
+        total_obj += best_dist;
+    }
+
+    total_obj
+}
+
+/// Update step of k-means: move every centroid to the mean of its assigned points.
+///
+/// Each empty cluster is repaired from the currently most populated cluster
+/// (the last one on ties), called the donor:
+/// - If the donor has at least two points, the empty centroid becomes a copy of
+///   the donor. Both are nudged apart by a relative `EPS`, in opposite
+///   directions that alternate per dimension, and the donor's count is shared
+///   between them.
+/// - Otherwise the empty centroid is set to a random data point.
+fn update_centroids(
+    data: &[f32],
+    n: usize,
+    d: usize,
+    centroids: &mut [f32],
+    k: usize,
+    assignments: &[usize],
+    rng: &mut StdRng,
+) {
+    let mut counts = vec![0usize; k];
+    let mut sums = vec![0.0f32; k * d];
+
+    for (i, &cluster) in assignments[..n].iter().enumerate() {
+        counts[cluster] += 1;
+        let point = &data[i * d..(i + 1) * d];
+        for (acc, &x) in sums[cluster * d..(cluster + 1) * d].iter_mut().zip(point) {
+            *acc += x;
+        }
+    }
+
+    for (c, &count) in counts.iter().enumerate() {
+        if count == 0 {
+            continue;
+        }
+        let inv = 1.0 / count as f32;
+        let span = c * d..(c + 1) * d;
+        for (dst, &s) in centroids[span.clone()].iter_mut().zip(&sums[span]) {
+            *dst = s * inv;
+        }
+    }
+
+    for c in 0..k {
+        if counts[c] != 0 {
+            continue;
+        }
+
+        let donor = counts
+            .iter()
+            .enumerate()
+            .max_by_key(|&(_, &cnt)| cnt)
+            .map_or(0, |(idx, _)| idx);
+
+        if counts[donor] < 2 {
+            let pick = rng.gen_range(0..n);
+            centroids[c * d..(c + 1) * d].copy_from_slice(&data[pick * d..(pick + 1) * d]);
+            continue;
+        }
+
+        // Duplicate the donor, then push the two copies apart.
+        centroids.copy_within(donor * d..(donor + 1) * d, c * d);
+        for j in 0..d {
+            let (up, down) = if j.is_multiple_of(2) {
+                (c, donor)
+            } else {
+                (donor, c)
+            };
+            centroids[up * d + j] *= 1.0 + EPS;
+            centroids[down * d + j] *= 1.0 - EPS;
+        }
+
+        let half = counts[donor] / 2;
+        counts[c] = half;
+        counts[donor] -= half;
+    }
+}
+
+/// Index of the centroid closest to `point` (squared L2) among the first `k`
+/// centroids of dimension `d`. Ties keep the lowest index; returns 0 when `k == 0`.
+pub fn find_nearest(point: &[f32], centroids: &[f32], k: usize, d: usize) -> usize {
+    let (nearest, _) = (0..k)
+        .map(|c| (c, fvec_l2sqr(point, &centroids[c * d..(c + 1) * d])))
+        .fold((0, f32::MAX), |(best, best_dist), (c, dist)| {
+            if dist < best_dist {
+                (c, dist)
+            } else {
+                (best, best_dist)
+            }
+        });
+    nearest
+}
+
+/// Keep the `nprobe` smallest `(distance, index)` pairs, sorted ascending, and
+/// split them into `(indices, distances)`.
+///
+/// Ties are broken by index, so the result equals a stable sort of the whole list
+/// followed by truncation. Distances are compared with `f32::total_cmp`, so NaN
+/// values no longer cause a panic.
+fn smallest_k(mut dists: Vec<(f32, usize)>, nprobe: usize) -> (Vec<usize>, Vec<f32>) {
+    let cmp = |a: &(f32, usize), b: &(f32, usize)| a.0.total_cmp(&b.0).then(a.1.cmp(&b.1));
+    if nprobe == 0 {
+        return (Vec::new(), Vec::new());
+    }
+    if nprobe < dists.len() {
+        // Partition around the nprobe-th element in O(len), then sort only the head.
+        dists.select_nth_unstable_by(nprobe - 1, cmp);
+        dists.truncate(nprobe);
+    }
+    dists.sort_unstable_by(cmp);
+    dists.into_iter().map(|(dist, idx)| (idx, dist)).unzip()
+}
+
+/// Return the `nprobe` centroids nearest to `point` (squared L2), closest first,
+/// as `(indices, distances)`. `nprobe` is clamped to `k`.
+pub fn find_topk(
+    point: &[f32],
+    centroids: &[f32],
+    k: usize,
+    d: usize,
+    nprobe: usize,
+) -> (Vec<usize>, Vec<f32>) {
+    let dists: Vec<(f32, usize)> = (0..k)
+        .map(|c| (fvec_l2sqr(point, &centroids[c * d..(c + 1) * d]), c))
+        .collect();
+    smallest_k(dists, nprobe.min(k))
+}
+
+/// Batch find top-nprobe nearest centroids for multiple queries using sgemm.
+///
+/// `queries` holds `nq` row-major queries of dimension `d`. Returns
+/// `(all_indices, all_distances)`: one inner `Vec` per query, each of length
+/// `min(nprobe, k)`, closest first.
+///
+/// Distances are squared L2. When `nq > 1` they are computed from the expanded
+/// form and clamped at 0 to absorb rounding error.
+pub fn find_topk_batch(
+    queries: &[f32],
+    nq: usize,
+    centroids: &[f32],
+    k: usize,
+    d: usize,
+    nprobe: usize,
+) -> (Vec<Vec<usize>>, Vec<Vec<f32>>) {
+    let nprobe = nprobe.min(k);
+
+    if nq == 1 {
+        let (indices, distances) = find_topk(&queries[..d], centroids, k, d, nprobe);
+        return (vec![indices], vec![distances]);
+    }
+
+    // Precompute norms.
+    let q_norms: Vec<f32> = (0..nq)
+        .map(|i| fvec_norm_l2sqr(&queries[i * d..(i + 1) * d]))
+        .collect();
+    let c_norms: Vec<f32> = (0..k)
+        .map(|c| fvec_norm_l2sqr(&centroids[c * d..(c + 1) * d]))
+        .collect();
+
+    // Batch inner products: ip[nq × k] = queries[nq × d] · centroids[k × d]^T
+    let mut ip_matrix = vec![0.0f32; nq * k];
+    sgemm_a_bt(nq, k, d, 1.0, queries, centroids, 0.0, &mut ip_matrix);
+
+    // Extract top-nprobe per query.
+    let mut all_indices = Vec::with_capacity(nq);
+    let mut all_distances = Vec::with_capacity(nq);
+
+    for qi in 0..nq {
+        let ips = &ip_matrix[qi * k..(qi + 1) * k];
+        let dists: Vec<(f32, usize)> = ips
+            .iter()
+            .zip(&c_norms)
+            .enumerate()
+            .map(|(c, (&ip, &c_norm))| ((q_norms[qi] + c_norm - 2.0 * ip).max(0.0), c))
+            .collect();
+        let (indices, distances) = smallest_k(dists, nprobe);
+        all_indices.push(indices);
+        all_distances.push(distances);
+    }
+
+    (all_indices, all_distances)
+}
+
+// --- Streaming Coreset K-means ---
+
+/// Streaming k-means trainer for very large datasets.
+/// Processes data in chunks, compresses each chunk into a weighted coreset,
+/// then trains final centroids on the accumulated coreset.
+pub struct StreamingKMeans {
+    /// Vector dimension.
+    pub d: usize,
+    /// Number of final centroids; also the maximum number of coreset centroids
+    /// produced per chunk.
+    pub k: usize,
+    /// Intended number of vectors per chunk. NOTE(review): only `new` sets it;
+    /// `add_chunk` processes whatever `n` it is given and does not read it.
+    pub chunk_size: usize,
+    /// Used for the final weighted k-means; its seed also derives per-chunk seeds.
+    config: KMeansConfig,
+    /// Accumulated coreset: (centroids, weights)
+    /// Centroids are stored flat, `d` values per coreset point.
+    coreset_centroids: Vec<f32>,
+    /// One weight per coreset point: the number of chunk vectors assigned to it.
+    coreset_weights: Vec<f32>,
+}
+
+impl StreamingKMeans {
+    /// Create a streaming k-means trainer producing `k` centroids of dimension `d`.
+    ///
+    /// `chunk_size` is the intended number of vectors per chunk (e.g. `k * 256`).
+    /// It is stored for callers, but [`add_chunk`](Self::add_chunk) accepts
+    /// chunks of any size. `config` is used by [`finalize`](Self::finalize), and
+    /// its seed also derives the per-chunk seeds.
+    pub fn new(d: usize, k: usize, chunk_size: usize, config: KMeansConfig) -> Self {
+        StreamingKMeans {
+            d,
+            k,
+            chunk_size,
+            config,
+            coreset_centroids: Vec::new(),
+            coreset_weights: Vec::new(),
+        }
+    }
+
+    /// Feed a chunk of `n` row-major training vectors. Can be called multiple times.
+    ///
+    /// The chunk is compressed into `min(k, n)` centroids, each weighted by the
+    /// number of chunk vectors assigned to it, and appended to the coreset.
+    /// Calls with `n == 0` (or on a trainer with `k == 0`) do nothing.
+    pub fn add_chunk(&mut self, data: &[f32], n: usize) {
+        let d = self.d;
+        // Zero whenever k == 0 or n == 0.
+        let chunk_k = self.k.min(n);
+        if chunk_k == 0 {
+            return;
+        }
+
+        // Distinct seed per chunk, derived from the coreset size so far. Wrapping
+        // add: a plain `+` overflows (and panics in debug builds) for seeds near
+        // u64::MAX.
+        let chunk_config = KMeansConfig {
+            niter: 15,
+            seed: self.config.seed.wrapping_add(self.coreset_weights.len() as u64),
+            ..KMeansConfig::default()
+        };
+        let centroids = kmeans_train_with_init(&chunk_config, data, n, d, chunk_k, None);
+
+        // Weight each coreset centroid by the number of chunk vectors it attracts.
+        let mut assignments = vec![0usize; n];
+        assign_clusters_fast(data, n, d, &centroids, chunk_k, &mut assignments, 0.0);
+        let mut weights = vec![0.0f32; chunk_k];
+        for &a in &assignments {
+            weights[a] += 1.0;
+        }
+
+        // Append to coreset.
+        self.coreset_centroids.extend_from_slice(&centroids);
+        self.coreset_weights.extend_from_slice(&weights);
+    }
+
+    /// Train the final `k` centroids on the accumulated weighted coreset.
+    ///
+    /// Returns `k * d` values, row-major. With no data fed, all zeros. With at
+    /// most `k` coreset points, those points followed by zero padding.
+    pub fn finalize(&self) -> Vec<f32> {
+        let d = self.d;
+        let coreset_n = self.coreset_weights.len();
+
+        if coreset_n == 0 {
+            return vec![0.0f32; self.k * d];
+        }
+
+        if coreset_n <= self.k {
+            let mut result = self.coreset_centroids.clone();
+            result.resize(self.k * d, 0.0);
+            return result;
+        }
+
+        // Weighted k-means on coreset.
+        weighted_kmeans_train(
+            &self.config,
+            &self.coreset_centroids,
+            &self.coreset_weights,
+            coreset_n,
+            d,
+            self.k,
+        )
+    }
+
+    /// Total vectors processed so far (sum of all coreset weights).
+    pub fn total_weight(&self) -> f32 {
+        self.coreset_weights.iter().sum()
+    }
+}
+
+/// Weighted k-means: each point's contribution to its centroid's mean is scaled
+/// by `weights[i]`.
+///
+/// Seeding (k-means++) and assignment ignore the weights; only the centroid
+/// update uses them. Clusters with zero total weight are re-seeded at a random
+/// point. Runs exactly `config.niter` iterations, with no early stop.
+///
+/// Returns `k * d` values, row-major. If `n <= k`, the points are cycled
+/// instead of clustered.
+fn weighted_kmeans_train(
+    config: &KMeansConfig,
+    data: &[f32],
+    weights: &[f32],
+    n: usize,
+    d: usize,
+    k: usize,
+) -> Vec<f32> {
+    if n <= k {
+        // Not more points than centroids: reuse the points, cycling to fill k slots.
+        let mut centroids = Vec::with_capacity(k * d);
+        for slot in 0..k {
+            let src = slot % n;
+            centroids.extend_from_slice(&data[src * d..(src + 1) * d]);
+        }
+        return centroids;
+    }
+
+    let mut rng = StdRng::seed_from_u64(config.seed);
+    let mut centroids = kmeans_plusplus_init(data, n, d, k, &mut rng);
+    let mut assignments = vec![0usize; n];
+    let mut sums = vec![0.0f32; k * d];
+    let mut mass = vec![0.0f32; k];
+
+    for _ in 0..config.niter {
+        // Assignment uses plain (unweighted) distances.
+        assign_clusters_fast(data, n, d, &centroids, k, &mut assignments, 0.0);
+
+        // Weighted accumulation.
+        sums.fill(0.0);
+        mass.fill(0.0);
+        for (i, &c) in assignments.iter().enumerate() {
+            let w = weights[i];
+            mass[c] += w;
+            let point = &data[i * d..(i + 1) * d];
+            for (acc, &x) in sums[c * d..(c + 1) * d].iter_mut().zip(point) {
+                *acc += w * x;
+            }
+        }
+
+        for c in 0..k {
+            let span = c * d..(c + 1) * d;
+            if mass[c] > 0.0 {
+                let inv = 1.0 / mass[c];
+                for (dst, &s) in centroids[span.clone()].iter_mut().zip(&sums[span]) {
+                    *dst = s * inv;
+                }
+            } else {
+                // Re-seed an empty cluster at a random point.
+                let src = rng.gen_range(0..n);
+                centroids[span].copy_from_slice(&data[src * d..(src + 1) * d]);
+            }
+        }
+    }
+
+    centroids
+}
+
+/// Draw `target_n` distinct rows (of dimension `d`) uniformly at random from the
+/// `n` rows of `data`, using a partial Fisher–Yates shuffle of the row indices.
+/// `target_n` must not exceed `n`.
+fn subsample(data: &[f32], n: usize, d: usize, target_n: usize, rng: &mut StdRng) -> Vec<f32> {
+    let mut order: Vec<usize> = (0..n).collect();
+    for pos in 0..target_n {
+        let pick = rng.gen_range(pos..n);
+        order.swap(pos, pick);
+    }
+
+    order[..target_n]
+        .iter()
+        .flat_map(|&row| data[row * d..(row + 1) * d].iter().copied())
+        .collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_two_clusters() {
+        let mut data = Vec::new();
+        for _ in 0..50 {
+            data.push(0.1);
+            data.push(0.1);
+        }
+        for _ in 0..50 {
+            data.push(10.1);
+            data.push(10.1);
+        }
+
+        let config = KMeansConfig::default();
+        let centroids = kmeans_train(&config, &data, 100, 2, 2);
+
+        let c0 = if centroids[0] < 5.0 {
+            &centroids[0..2]
+        } else {
+            &centroids[2..4]
+        };
+        let c1 = if centroids[0] < 5.0 {
+            &centroids[2..4]
+        } else {
+            &centroids[0..2]
+        };
+
+        assert!(c0[0] < 2.0 && c0[1] < 2.0);
+        assert!(c1[0] > 8.0 && c1[1] > 8.0);
+    }
+
+    #[test]
+    fn test_find_topk() {
+        let centroids = [0.0, 0.0, 10.0, 0.0, 5.0, 5.0];
+        let query = [1.0, 1.0];
+        let (indices, _) = find_topk(&query, &centroids, 3, 2, 2);
+        assert_eq!(indices[0], 0);
+    }
+
+    #[test]
+    fn test_find_topk_batch_matches_single_query() {
+        // 4 centroids in 2-d; neither query has a distance tie.
+        let centroids = [0.0, 0.0, 10.0, 0.0, 5.0, 5.0, 1.0, 1.0];
+        let queries = [0.2, 0.3, 9.0, 1.0];
+        let (batch_idx, batch_dist) = find_topk_batch(&queries, 2, &centroids, 4, 2, 2);
+
+        assert_eq!(batch_idx, vec![vec![0, 3], vec![1, 2]]);
+        for qi in 0..2 {
+            let (idx, dist) = find_topk(&queries[qi * 2..qi * 2 + 2], &centroids, 4, 2, 2);
+            assert_eq!(batch_idx[qi], idx);
+            for (a, b) in batch_dist[qi].iter().zip(&dist) {
+                assert!((a - b).abs() < 1e-4, "batch {} vs single {}", a, b);
+            }
+        }
+    }
+
+    #[test]
+    fn test_hot_start_converges_faster() {
+        let mut rng = StdRng::seed_from_u64(42);
+        let n = 500;
+        let d = 4;
+        let k = 4;
+
+        let data: Vec<f32> = (0..n * d).map(|_| rng.gen::<f32>() * 10.0).collect();
+
+        let config = KMeansConfig {
+            niter: 25,
+            ..KMeansConfig::default()
+        };
+        let centroids = kmeans_train(&config, &data, n, d, k);
+
+        // Hot-start with previous centroids should converge in fewer iterations
+        let config2 = KMeansConfig {
+            niter: 3,
+            ..KMeansConfig::default()
+        };
+        let centroids2 = kmeans_train_with_init(&config2, &data, n, d, k, Some(&centroids));
+
+        // Should be very close to the original since it started from converged state
+        let mut total_diff = 0.0f32;
+        for i in 0..k * d {
+            total_diff += (centroids[i] - centroids2[i]).abs();
+        }
+        assert!(
+            total_diff < 1.0,
+            "Hot-start centroids drifted too much: {}",
+            total_diff
+        );
+    }
+
+    #[test]
+    fn test_streaming_coreset_kmeans() {
+        let n = 5000;
+        let d = 4;
+        let k = 10;
+        let chunk_size = 1000;
+
+        let mut rng = StdRng::seed_from_u64(42);
+        // Generate clustered data
+        let mut data = Vec::new();
+        for cluster in 0..k {
+            let cx = cluster as f32 * 20.0;
+            let cy = cluster as f32 * 20.0;
+            for _ in 0..n / k {
+                data.push(cx + rng.gen::<f32>() * 2.0);
+                data.push(cy + rng.gen::<f32>() * 2.0);
+                data.push(rng.gen::<f32>());
+                data.push(rng.gen::<f32>());
+            }
+        }
+
+        let config = KMeansConfig::default();
+        let mut streaming = StreamingKMeans::new(d, k, chunk_size, config);
+
+        // Feed data in chunks
+        for chunk_start in (0..n).step_by(chunk_size) {
+            let chunk_end = (chunk_start + chunk_size).min(n);
+            let chunk_n = chunk_end - chunk_start;
+            streaming.add_chunk(&data[chunk_start * d..chunk_end * d], chunk_n);
+        }
+
+        assert!((streaming.total_weight() - n as f32).abs() < 1.0);
+
+        let centroids = streaming.finalize();
+        assert_eq!(centroids.len(), k * d);
+
+        // Centroids should be diverse
+        let first = &centroids[0..d];
+        let mut diverse = false;
+        for i in 1..k {
+            if fvec_l2sqr(&centroids[i * d..(i + 1) * d], first) > 1.0 {
+                diverse = true;
+                break;
+            }
+        }
+        assert!(diverse, "Streaming centroids are not diverse");
+    }
+
+    #[test]
+    fn test_hierarchical_kmeans() {
+        let n = 2000;
+        let d = 4;
+        let k = 300; // > 256, triggers hierarchical
+
+        let mut rng = StdRng::seed_from_u64(42);
+        let data: Vec<f32> = (0..n * d).map(|_| rng.gen::<f32>() * 100.0).collect();
+
+        let config = KMeansConfig::default();
+        let centroids = kmeans_train(&config, &data, n, d, k);
+
+        assert_eq!(centroids.len(), k * d);
+
+        // All centroids should be finite
+        for &v in &centroids {
+            assert!(v.is_finite(), "Non-finite centroid value: {}", v);
+        }
+
+        // Centroids should be diverse (not all the same)
+        let first = &centroids[0..d];
+        let mut all_same = true;
+        for i in 1..k {
+            if &centroids[i * d..(i + 1) * d] != first {
+                all_same = false;
+                break;
+            }
+        }
+        assert!(!all_same, "All centroids are identical");
+    }
+}
diff --git a/core/src/lib.rs b/core/src/lib.rs
index b248758..a2df43f 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -14,3 +14,11 @@
 // KIND, either express or implied.  See the License for the
 // specific language governing permissions and limitations
 // under the License.
+
+#![allow(clippy::needless_range_loop)]
+#![allow(clippy::too_many_arguments)]
+
+pub mod blas;
+pub mod distance;
+pub mod kmeans;
+pub mod pq;
diff --git a/core/src/pq.rs b/core/src/pq.rs
new file mode 100644
index 0000000..bb2e292
--- /dev/null
+++ b/core/src/pq.rs
@@ -0,0 +1,542 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::blas::sgemm_a_bt;
+use crate::distance::{
+    fvec_ip_batch, fvec_l2sqr_batch, fvec_l2sqr_sub, fvec_norm_l2sqr, 
pq_distance_from_table,
+    MetricType,
+};
+use crate::kmeans::{self, KMeansConfig};
+use rayon::prelude::*;
+
+/// Product Quantizer aligned with Faiss's ProductQuantizer.
+///
+/// Splits D-dimensional vectors into M sub-vectors of dimension dsub = D/M,
+/// and independently quantizes each sub-vector with ksub centroids.
+///
+/// Centroid layout: flat [M * ksub * dsub], row-major.
+/// centroids[m][j][d] is at index: m * ksub * dsub + j * dsub + d
+#[derive(Debug, Clone)]
+pub struct ProductQuantizer {
+    /// Full vector dimension.
+    pub d: usize,
+    /// Number of sub-quantizers (sub-vectors per vector).
+    pub m: usize,
+    /// Bits per sub-quantizer code; the constructors only accept 4 or 8.
+    pub nbits: usize,
+    /// Sub-vector dimension, `d / m`.
+    pub dsub: usize,
+    /// Centroids per sub-quantizer, `2^nbits`.
+    pub ksub: usize,
+    /// Codebooks, flat `[m * ksub * dsub]`; empty until trained.
+    pub centroids: Vec<f32>,
+    /// Pre-computed squared norms of each centroid: `[m * ksub]`.
+    ///
+    /// Lets L2 distance tables skip recomputing centroid norms per query.
+    /// When empty, norms are computed on the fly; when non-empty the values
+    /// are trusted as-is, so call `rebuild_norms_cache` after replacing
+    /// `centroids` directly.
+    pub centroid_norms_cache: Vec<f32>,
+}
+
+impl ProductQuantizer {
+    pub fn new(d: usize, m: usize) -> Self {
+        Self::with_nbits(d, m, 8)
+    }
+
+    pub fn with_nbits(d: usize, m: usize, nbits: usize) -> Self {
+        assert!(
+            d.is_multiple_of(m),
+            "dimension {} must be divisible by m={}",
+            d,
+            m
+        );
+        assert!(
+            nbits == 4 || nbits == 8,
+            "nbits must be 4 or 8, got {}",
+            nbits
+        );
+        if nbits == 4 {
+            assert!(
+                m.is_multiple_of(2),
+                "m must be even for 4-bit PQ, got {}",
+                m
+            );
+        }
+        let dsub = d / m;
+        let ksub = 1 << nbits;
+        ProductQuantizer {
+            d,
+            m,
+            nbits,
+            dsub,
+            ksub,
+            centroids: Vec::new(),
+            centroid_norms_cache: Vec::new(),
+        }
+    }
+
+    /// Train the codebooks from training data.
+    /// data: flat [n * d], n training vectors.
+    pub fn train(&mut self, data: &[f32], n: usize) {
+        self.train_with_config(data, n, &KMeansConfig::default());
+    }
+
+    pub fn train_with_config(&mut self, data: &[f32], n: usize, km_config: 
&KMeansConfig) {
+        self.train_hot_start(data, n, km_config, false);
+    }
+
+    /// Train with optional hot-start: reuse existing centroids as k-means 
initial values.
+    /// Parallelizes across M sub-quantizers with rayon.
+    pub fn train_hot_start(
+        &mut self,
+        data: &[f32],
+        n: usize,
+        km_config: &KMeansConfig,
+        hot_start: bool,
+    ) {
+        let prev_centroids = if hot_start && !self.centroids.is_empty() {
+            Some(self.centroids.clone())
+        } else {
+            None
+        };
+
+        let m = self.m;
+        let d = self.d;
+        let dsub = self.dsub;
+        let ksub = self.ksub;
+
+        // Train all M sub-quantizers in parallel
+        let sub_results: Vec<Vec<f32>> = (0..m)
+            .into_par_iter()
+            .map(|sub| {
+                let offset = sub * dsub;
+
+                let mut sub_data = vec![0.0f32; n * dsub];
+                for i in 0..n {
+                    sub_data[i * dsub..(i + 1) * dsub]
+                        .copy_from_slice(&data[i * d + offset..i * d + offset 
+ dsub]);
+                }
+
+                let init: Option<Vec<f32>> = prev_centroids.as_ref().map(|pc| {
+                    let src = sub * ksub * dsub;
+                    pc[src..src + ksub * dsub].to_vec()
+                });
+
+                kmeans::kmeans_train_with_init(km_config, &sub_data, n, dsub, 
ksub, init.as_deref())
+            })
+            .collect();
+
+        self.centroids = vec![0.0f32; m * ksub * dsub];
+        for (sub, sub_centroids) in sub_results.into_iter().enumerate() {
+            let dst_offset = sub * ksub * dsub;
+            self.centroids[dst_offset..dst_offset + ksub * 
dsub].copy_from_slice(&sub_centroids);
+        }
+        self.rebuild_norms_cache();
+    }
+
+    /// Rebuild the centroid norms cache. Called after training or loading 
centroids.
+    pub fn rebuild_norms_cache(&mut self) {
+        self.centroid_norms_cache = vec![0.0f32; self.m * self.ksub];
+        for sub in 0..self.m {
+            let c_base = sub * self.ksub * self.dsub;
+            for j in 0..self.ksub {
+                let c_off = c_base + j * self.dsub;
+                self.centroid_norms_cache[sub * self.ksub + j] =
+                    fvec_norm_l2sqr(&self.centroids[c_off..c_off + self.dsub]);
+            }
+        }
+    }
+
+    /// Bytes per encoded vector.
+    pub fn code_size(&self) -> usize {
+        if self.nbits == 4 {
+            self.m / 2
+        } else {
+            self.m
+        }
+    }
+
+    /// Encode a single vector into PQ codes.
+    /// For nbits=8: codes has length M (one byte per sub-quantizer).
+    /// For nbits=4: codes has length M/2 (two nibbles per byte).
+    pub fn encode(&self, x: &[f32], codes: &mut [u8]) {
+        if self.nbits == 4 {
+            self.encode_4bit(x, codes);
+        } else {
+            self.encode_8bit(x, codes);
+        }
+    }
+
+    fn encode_8bit(&self, x: &[f32], codes: &mut [u8]) {
+        for sub in 0..self.m {
+            let x_off = sub * self.dsub;
+            let c_base = sub * self.ksub * self.dsub;
+
+            let mut best = 0u8;
+            let mut best_dist = f32::MAX;
+            for j in 0..self.ksub {
+                let c_off = c_base + j * self.dsub;
+                let dist = fvec_l2sqr_sub(x, x_off, &self.centroids, c_off, 
self.dsub);
+                if dist < best_dist {
+                    best_dist = dist;
+                    best = j as u8;
+                }
+            }
+            codes[sub] = best;
+        }
+    }
+
+    fn encode_4bit(&self, x: &[f32], codes: &mut [u8]) {
+        for pair in 0..self.m / 2 {
+            let sub_lo = pair * 2;
+            let sub_hi = pair * 2 + 1;
+
+            let mut best_lo = 0u8;
+            let mut best_dist_lo = f32::MAX;
+            let x_off_lo = sub_lo * self.dsub;
+            let c_base_lo = sub_lo * self.ksub * self.dsub;
+            for j in 0..self.ksub {
+                let dist = fvec_l2sqr_sub(
+                    x,
+                    x_off_lo,
+                    &self.centroids,
+                    c_base_lo + j * self.dsub,
+                    self.dsub,
+                );
+                if dist < best_dist_lo {
+                    best_dist_lo = dist;
+                    best_lo = j as u8;
+                }
+            }
+
+            let mut best_hi = 0u8;
+            let mut best_dist_hi = f32::MAX;
+            let x_off_hi = sub_hi * self.dsub;
+            let c_base_hi = sub_hi * self.ksub * self.dsub;
+            for j in 0..self.ksub {
+                let dist = fvec_l2sqr_sub(
+                    x,
+                    x_off_hi,
+                    &self.centroids,
+                    c_base_hi + j * self.dsub,
+                    self.dsub,
+                );
+                if dist < best_dist_hi {
+                    best_dist_hi = dist;
+                    best_hi = j as u8;
+                }
+            }
+
+            // Pack: low nibble + high nibble
+            codes[pair] = best_lo | (best_hi << 4);
+        }
+    }
+
+    /// Encode multiple vectors in parallel.
+    pub fn encode_batch(&self, data: &[f32], n: usize, codes: &mut [u8]) {
+        let d = self.d;
+        let cs = self.code_size();
+
+        codes
+            .par_chunks_mut(cs)
+            .enumerate()
+            .for_each(|(i, code_chunk)| {
+                if i < n {
+                    self.encode(&data[i * d..(i + 1) * d], code_chunk);
+                }
+            });
+    }
+
+    /// Decode PQ codes back to an approximate vector.
+    pub fn decode(&self, codes: &[u8], x: &mut [f32]) {
+        if self.nbits == 4 {
+            for pair in 0..self.m / 2 {
+                let byte = codes[pair];
+                let code_lo = (byte & 0x0F) as usize;
+                let code_hi = ((byte >> 4) & 0x0F) as usize;
+
+                let sub_lo = pair * 2;
+                let sub_hi = pair * 2 + 1;
+
+                let c_off_lo = sub_lo * self.ksub * self.dsub + code_lo * 
self.dsub;
+                let x_off_lo = sub_lo * self.dsub;
+                x[x_off_lo..x_off_lo + self.dsub]
+                    .copy_from_slice(&self.centroids[c_off_lo..c_off_lo + 
self.dsub]);
+
+                let c_off_hi = sub_hi * self.ksub * self.dsub + code_hi * 
self.dsub;
+                let x_off_hi = sub_hi * self.dsub;
+                x[x_off_hi..x_off_hi + self.dsub]
+                    .copy_from_slice(&self.centroids[c_off_hi..c_off_hi + 
self.dsub]);
+            }
+        } else {
+            for sub in 0..self.m {
+                let c_off = sub * self.ksub * self.dsub + (codes[sub] as 
usize) * self.dsub;
+                let x_off = sub * self.dsub;
+                x[x_off..x_off + self.dsub]
+                    .copy_from_slice(&self.centroids[c_off..c_off + 
self.dsub]);
+            }
+        }
+    }
+
+    /// Precompute the distance table from a query to all PQ centroids.
+    /// Uses sgemm for dsub >= 4 (L2: ||q-c||²=||q||²+||c||²-2q·cᵀ).
+    pub fn compute_distance_table(&self, query: &[f32], metric: MetricType, 
table: &mut [f32]) {
+        if self.dsub >= 4 {
+            self.compute_distance_table_sgemm(query, metric, table);
+        } else {
+            self.compute_distance_table_loop(query, metric, table);
+        }
+    }
+
+    fn compute_distance_table_sgemm(&self, query: &[f32], metric: MetricType, 
table: &mut [f32]) {
+        for sub in 0..self.m {
+            let q_off = sub * self.dsub;
+            let c_base = sub * self.ksub * self.dsub;
+            let t_base = sub * self.ksub;
+
+            // Inner product: ip[ksub] = query_sub[1×dsub] · 
centroids_sub[ksub×dsub]ᵀ
+            sgemm_a_bt(
+                1,
+                self.ksub,
+                self.dsub,
+                1.0,
+                &query[q_off..q_off + self.dsub],
+                &self.centroids[c_base..c_base + self.ksub * self.dsub],
+                0.0,
+                &mut table[t_base..t_base + self.ksub],
+            );
+
+            match metric {
+                MetricType::L2 | MetricType::Cosine => {
+                    // ||q-c||² = ||q||² + ||c||² - 2·q·c
+                    // Use pre-cached centroid norms (avoids recomputing per 
query)
+                    let q_norm = fvec_norm_l2sqr(&query[q_off..q_off + 
self.dsub]);
+                    let norms_base = sub * self.ksub;
+                    for j in 0..self.ksub {
+                        let c_norm = if !self.centroid_norms_cache.is_empty() {
+                            self.centroid_norms_cache[norms_base + j]
+                        } else {
+                            let c_off = c_base + j * self.dsub;
+                            fvec_norm_l2sqr(&self.centroids[c_off..c_off + 
self.dsub])
+                        };
+                        table[t_base + j] = q_norm + c_norm - 2.0 * 
table[t_base + j];
+                    }
+                }
+                MetricType::InnerProduct => {
+                    for j in 0..self.ksub {
+                        table[t_base + j] = -table[t_base + j];
+                    }
+                }
+            }
+        }
+    }
+
+    fn compute_distance_table_loop(&self, query: &[f32], metric: MetricType, 
table: &mut [f32]) {
+        for sub in 0..self.m {
+            let q_off = sub * self.dsub;
+            let c_base = sub * self.ksub * self.dsub;
+            let t_base = sub * self.ksub;
+
+            match metric {
+                MetricType::L2 | MetricType::Cosine => {
+                    fvec_l2sqr_batch(
+                        &query[q_off..q_off + self.dsub],
+                        &self.centroids[c_base..c_base + self.ksub * 
self.dsub],
+                        self.dsub,
+                        self.ksub,
+                        &mut table[t_base..t_base + self.ksub],
+                    );
+                }
+                MetricType::InnerProduct => {
+                    fvec_ip_batch(
+                        &query[q_off..q_off + self.dsub],
+                        &self.centroids[c_base..c_base + self.ksub * 
self.dsub],
+                        self.dsub,
+                        self.ksub,
+                        &mut table[t_base..t_base + self.ksub],
+                    );
+                    for j in 0..self.ksub {
+                        table[t_base + j] = -table[t_base + j];
+                    }
+                }
+            }
+        }
+    }
+
+    /// Compute inner product table: ip_table[m * ksub + j] = <query_m, 
centroid_m_j>.
+    pub fn compute_inner_product_table(&self, query: &[f32], table: &mut 
[f32]) {
+        for sub in 0..self.m {
+            let q_off = sub * self.dsub;
+            let c_base = sub * self.ksub * self.dsub;
+            let t_base = sub * self.ksub;
+
+            fvec_ip_batch(
+                &query[q_off..q_off + self.dsub],
+                &self.centroids[c_base..c_base + self.ksub * self.dsub],
+                self.dsub,
+                self.ksub,
+                &mut table[t_base..t_base + self.ksub],
+            );
+        }
+    }
+
+    /// Compute the approximate distance from a distance table.
+    #[inline]
+    pub fn distance_from_table(&self, table: &[f32], codes: &[u8]) -> f32 {
+        if self.nbits == 4 {
+            self.distance_from_table_4bit(table, codes)
+        } else {
+            pq_distance_from_table(table, codes, self.m, self.ksub)
+        }
+    }
+
+    /// 4-bit PQ distance: unpack nibbles and accumulate from 16-entry tables.
+    #[inline]
+    fn distance_from_table_4bit(&self, table: &[f32], codes: &[u8]) -> f32 {
+        let mut dist = 0.0f32;
+        for pair in 0..self.m / 2 {
+            let byte = codes[pair];
+            let code_lo = (byte & 0x0F) as usize;
+            let code_hi = ((byte >> 4) & 0x0F) as usize;
+
+            let sub_lo = pair * 2;
+            let sub_hi = pair * 2 + 1;
+
+            dist += table[sub_lo * self.ksub + code_lo];
+            dist += table[sub_hi * self.ksub + code_hi];
+        }
+        dist
+    }
+
+    /// Compute squared norms of all PQ centroids.
+    /// Uses cache if available, otherwise computes from scratch.
+    pub fn compute_centroid_norms(&self) -> Vec<f32> {
+        if !self.centroid_norms_cache.is_empty() {
+            return self.centroid_norms_cache.clone();
+        }
+        let mut norms = vec![0.0f32; self.m * self.ksub];
+        for sub in 0..self.m {
+            let c_base = sub * self.ksub * self.dsub;
+            for j in 0..self.ksub {
+                let c_off = c_base + j * self.dsub;
+                norms[sub * self.ksub + j] =
+                    fvec_norm_l2sqr(&self.centroids[c_off..c_off + self.dsub]);
+            }
+        }
+        norms
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rand::rngs::StdRng;
+    use rand::{Rng, SeedableRng};
+
+    /// `n * d` uniform values in [0, 1) from the fixed seed 42.
+    fn random_data(n: usize, d: usize) -> Vec<f32> {
+        let mut rng = StdRng::seed_from_u64(42);
+        (0..n * d).map(|_| rng.gen::<f32>()).collect()
+    }
+
+    /// Trains a quantizer on `random_data(n, d)` and returns it with the data.
+    fn trained_pq(d: usize, m: usize, nbits: usize, n: usize) -> (ProductQuantizer, Vec<f32>) {
+        let data = random_data(n, d);
+        let mut pq = ProductQuantizer::with_nbits(d, m, nbits);
+        pq.train(&data, n);
+        (pq, data)
+    }
+
+    /// `pq.decode(codes)` as a fresh vector.
+    fn reconstruct(pq: &ProductQuantizer, codes: &[u8]) -> Vec<f32> {
+        let mut decoded = vec![0.0f32; pq.d];
+        pq.decode(codes, &mut decoded);
+        decoded
+    }
+
+    fn dot(a: &[f32], b: &[f32]) -> f32 {
+        a.iter().zip(b).map(|(x, y)| x * y).sum()
+    }
+
+    /// Asserts `actual` is within a small relative tolerance of `expected`.
+    /// The sgemm path expands ||q - c||², so bit-exact equality is not expected.
+    fn assert_close(actual: f32, expected: f32) {
+        let tol = 1e-4 * expected.abs().max(1.0);
+        assert!((actual - expected).abs() <= tol, "{} vs {}", actual, expected);
+    }
+
+    #[test]
+    fn test_encode_decode_roundtrip() {
+        let (d, m, n) = (8, 2, 100);
+        let (pq, data) = trained_pq(d, m, 8, n);
+
+        let original = &data[0..d];
+        let mut codes = vec![0u8; m];
+        pq.encode(original, &mut codes);
+
+        // Decoded should be a reasonable approximation
+        let decoded = reconstruct(&pq, &codes);
+        let error = fvec_l2sqr_sub(original, 0, &decoded, 0, d);
+        assert!(error < 10.0); // PQ introduces quantization error
+    }
+
+    #[test]
+    fn test_distance_table() {
+        // dsub = 4: exercises the sgemm path.
+        let (d, m) = (8, 2);
+        let (pq, data) = trained_pq(d, m, 8, 100);
+        let query = &data[0..d];
+
+        let mut table = vec![0.0f32; m * pq.ksub];
+        pq.compute_distance_table(query, MetricType::L2, &mut table);
+
+        let mut codes = vec![0u8; m];
+        pq.encode(query, &mut codes);
+
+        let dist = pq.distance_from_table(&table, &codes);
+        assert!(dist >= 0.0);
+        // The table lookup must equal the exact distance to the reconstruction.
+        let decoded = reconstruct(&pq, &codes);
+        assert_close(dist, fvec_l2sqr_sub(query, 0, &decoded, 0, d));
+    }
+
+    #[test]
+    fn test_distance_table_loop_path() {
+        // dsub = 2: exercises the non-sgemm loop path for both metrics.
+        let (d, m) = (4, 2);
+        let (pq, data) = trained_pq(d, m, 8, 500);
+        let query = &data[d..2 * d];
+
+        let mut codes = vec![0u8; pq.code_size()];
+        pq.encode(query, &mut codes);
+        let decoded = reconstruct(&pq, &codes);
+
+        let mut table = vec![0.0f32; m * pq.ksub];
+        pq.compute_distance_table(query, MetricType::L2, &mut table);
+        let l2 = pq.distance_from_table(&table, &codes);
+        assert_close(l2, fvec_l2sqr_sub(query, 0, &decoded, 0, d));
+
+        pq.compute_distance_table(query, MetricType::InnerProduct, &mut table);
+        let ip = pq.distance_from_table(&table, &codes);
+        assert_close(ip, -dot(query, &decoded));
+    }
+
+    #[test]
+    fn test_inner_product_table() {
+        // dsub = 4: the InnerProduct distance table goes through sgemm.
+        let (d, m) = (8, 2);
+        let (pq, data) = trained_pq(d, m, 8, 500);
+        let query = &data[0..d];
+
+        let mut ip_dist = vec![0.0f32; m * pq.ksub];
+        pq.compute_distance_table(query, MetricType::InnerProduct, &mut ip_dist);
+        let mut ip = vec![0.0f32; m * pq.ksub];
+        pq.compute_inner_product_table(query, &mut ip);
+        // The InnerProduct distance table is the negated inner-product table.
+        for (&neg, &pos) in ip_dist.iter().zip(&ip) {
+            assert_close(-neg, pos);
+        }
+
+        let mut codes = vec![0u8; pq.code_size()];
+        pq.encode(query, &mut codes);
+        let decoded = reconstruct(&pq, &codes);
+        assert_close(pq.distance_from_table(&ip_dist, &codes), -dot(query, &decoded));
+    }
+
+    #[test]
+    fn test_4bit_encode_decode() {
+        let d = 8;
+        let m = 4; // must be even for 4-bit
+        let n = 200;
+        let data = random_data(n, d);
+
+        let mut pq = ProductQuantizer::with_nbits(d, m, 4);
+        assert_eq!(pq.ksub, 16);
+        assert_eq!(pq.code_size(), 2); // m/2 = 2 bytes per vector
+
+        pq.train(&data, n);
+
+        let original = &data[0..d];
+        let mut codes = vec![0u8; pq.code_size()];
+        pq.encode(original, &mut codes);
+
+        // Verify codes are non-trivial (not all zeros)
+        assert!(codes.iter().any(|&b| b != 0));
+
+        // Should be a reasonable approximation
+        let decoded = reconstruct(&pq, &codes);
+        let error = fvec_l2sqr_sub(original, 0, &decoded, 0, d);
+        assert!(error < 20.0); // 4-bit has higher error than 8-bit
+
+        // Distance table: nibble unpacking must agree with decode.
+        let mut table = vec![0.0f32; m * pq.ksub];
+        pq.compute_distance_table(original, MetricType::L2, &mut table);
+        let dist = pq.distance_from_table(&table, &codes);
+        assert!(dist >= 0.0);
+        assert_close(dist, error);
+    }
+
+    #[test]
+    fn test_4bit_batch_encode() {
+        let (d, m, n) = (16, 8, 100);
+        let (pq, data) = trained_pq(d, m, 4, n);
+
+        let cs = pq.code_size(); // m/2 = 4
+        let mut codes = vec![0u8; n * cs];
+        pq.encode_batch(&data, n, &mut codes);
+
+        // Verify codes are non-trivial (not all zeros)
+        assert!(codes.iter().any(|&b| b != 0));
+
+        // Batch encoding must agree with encoding each vector on its own.
+        for (i, code) in codes.chunks_exact(cs).enumerate() {
+            let mut single = vec![0u8; cs];
+            pq.encode(&data[i * d..(i + 1) * d], &mut single);
+            assert_eq!(code, &single[..]);
+        }
+    }
+}

Reply via email to