This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new 2027edc Add core foundation modules (#1)
2027edc is described below
commit 2027edc81aa813a1af8394266f53064fb6c9f377
Author: Jingsong Lee <[email protected]>
AuthorDate: Fri Jun 5 21:14:17 2026 +0800
Add core foundation modules (#1)
Bring over the basic building blocks: distance (L2/IP/Cosine + SIMD),
blas (sgemm), kmeans (k-means++/hierarchical/streaming), and product
quantizer (train/encode/decode/distance table). Skips higher-level
modules (ivfpq, opq, fastscan, shuffler, io) and language bindings.
---
core/Cargo.toml => .gitignore | 30 +-
LICENSE | 201 ++++++++++
NOTICE | 5 +
copyright.txt | 17 +
core/Cargo.toml | 5 +
core/src/blas.rs | 91 +++++
core/src/distance.rs | 470 +++++++++++++++++++++++
core/src/kmeans.rs | 875 ++++++++++++++++++++++++++++++++++++++++++
core/src/lib.rs | 8 +
core/src/pq.rs | 542 ++++++++++++++++++++++++++
10 files changed, 2239 insertions(+), 5 deletions(-)
diff --git a/core/Cargo.toml b/.gitignore
similarity index 77%
copy from core/Cargo.toml
copy to .gitignore
index b3b3233..d987f62 100644
--- a/core/Cargo.toml
+++ b/.gitignore
@@ -15,8 +15,28 @@
# specific language governing permissions and limitations
# under the License.
-[package]
-name = "paimon-vindex-core"
-version = "0.1.0"
-edition = "2021"
-license = "Apache-2.0"
+# Rust
+/target/
+Cargo.lock
+
+# Java
+java/target/
+java/src/main/resources/native/
+*.class
+
+# Python
+__pycache__/
+*.pyc
+*.egg-info/
+python/dist/
+python/build/
+.pytest_cache/
+
+# IDE
+.idea/
+*.iml
+.vscode/
+
+# OS
+.DS_Store
+Thumbs.db
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/NOTICE b/NOTICE
new file mode 100644
index 0000000..481e138
--- /dev/null
+++ b/NOTICE
@@ -0,0 +1,5 @@
+Apache Paimon Vector Index
+Copyright 2025 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
diff --git a/copyright.txt b/copyright.txt
new file mode 100644
index 0000000..2379dda
--- /dev/null
+++ b/copyright.txt
@@ -0,0 +1,17 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
diff --git a/core/Cargo.toml b/core/Cargo.toml
index b3b3233..69ba8a7 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -20,3 +20,8 @@ name = "paimon-vindex-core"
version = "0.1.0"
edition = "2021"
license = "Apache-2.0"
+
+[dependencies]
+rand = "0.8"
+rayon = "1.10"
+matrixmultiply = "0.3"
diff --git a/core/src/blas.rs b/core/src/blas.rs
new file mode 100644
index 0000000..ddbe9a2
--- /dev/null
+++ b/core/src/blas.rs
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Matrix multiplication via `matrixmultiply` crate.
+//! Pure Rust with SIMD + cache blocking, no system BLAS dependency.
+
+/// C = alpha * A * B^T + beta * C
+/// A: [m × k] row-major
+/// B: [n × k] row-major (transposed in the multiply)
+/// C: [m × n] row-major
+pub fn sgemm_a_bt(
+ m: usize,
+ n: usize,
+ k: usize,
+ alpha: f32,
+ a: &[f32],
+ b: &[f32],
+ beta: f32,
+ c: &mut [f32],
+) {
+ unsafe {
+ matrixmultiply::sgemm(
+ m,
+ k,
+ n,
+ alpha,
+ a.as_ptr(),
+ k as isize, // row stride of A
+ 1, // col stride of A
+ b.as_ptr(),
+ 1, // B^T: col stride = 1 means we read B row-wise as
columns
+ k as isize, // B^T: row stride = k means next "column of B^T" =
next row of B
+ beta,
+ c.as_mut_ptr(),
+ n as isize, // row stride of C
+ 1, // col stride of C
+ );
+ }
+}
+
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two slices have equal length and agree element-wise
    /// within an absolute tolerance of 1e-5.
    fn assert_all_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    #[test]
    fn test_sgemm_a_bt() {
        // A = [[1, 2], [3, 4]]  (2x2)
        // B = [[5, 6], [7, 8]]  (2x2)
        // A * B^T = [[1*5+2*6, 1*7+2*8], [3*5+4*6, 3*7+4*8]]
        //         = [[17, 23], [39, 53]]
        let lhs = [1.0f32, 2.0, 3.0, 4.0];
        let rhs = [5.0f32, 6.0, 7.0, 8.0];
        let mut out = [0.0f32; 4];

        sgemm_a_bt(2, 2, 2, 1.0, &lhs, &rhs, 0.0, &mut out);

        assert_all_close(&out, &[17.0, 23.0, 39.0, 53.0]);
    }

    #[test]
    fn test_sgemm_rectangular() {
        // A = [[1, 2, 3]]             (1x3)
        // B = [[4, 5, 6], [7, 8, 9]]  (2x3)
        // A * B^T = [[1*4+2*5+3*6, 1*7+2*8+3*9]] = [[32, 50]]
        let lhs = [1.0f32, 2.0, 3.0];
        let rhs = [4.0f32, 5.0, 6.0, 7.0, 8.0, 9.0];
        let mut out = [0.0f32; 2];

        sgemm_a_bt(1, 2, 3, 1.0, &lhs, &rhs, 0.0, &mut out);

        assert_all_close(&out, &[32.0, 50.0]);
    }
}
diff --git a/core/src/distance.rs b/core/src/distance.rs
new file mode 100644
index 0000000..f6174e1
--- /dev/null
+++ b/core/src/distance.rs
@@ -0,0 +1,470 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
/// Distance metric selector. Each variant carries a fixed `u32` code
/// (`#[repr(u32)]`), which `from_code` maps back to the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum MetricType {
    /// Code `0`.
    L2 = 0,
    /// Code `1`.
    InnerProduct = 1,
    /// Code `2`.
    Cosine = 2,
}

impl MetricType {
    /// Looks up the variant whose numeric code equals `code`.
    ///
    /// Returns `None` when `code` does not correspond to any variant.
    pub fn from_code(code: u32) -> Option<Self> {
        const KNOWN: [MetricType; 3] = [
            MetricType::L2,
            MetricType::InnerProduct,
            MetricType::Cosine,
        ];
        KNOWN.iter().copied().find(|metric| *metric as u32 == code)
    }
}
+
/// Returns the squared Euclidean distance `Σ (a[i] - b[i])²`.
///
/// The sum runs over the `a.len()` leading elements, accumulated left to
/// right in `f32`. Debug builds assert that both slices have the same
/// length; in any build the call panics if `b` is shorter than `a`.
pub fn fvec_l2sqr(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front bounds check instead of one per element.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| {
        let diff = x - y;
        acc + diff * diff
    })
}
+
/// Returns the squared Euclidean distance between two sub-vectors of
/// length `len`: `a[a_off..a_off + len]` and `b[b_off..b_off + len]`.
///
/// Panics if either window reaches past the end of its slice.
pub fn fvec_l2sqr_sub(a: &[f32], a_off: usize, b: &[f32], b_off: usize, len: usize) -> f32 {
    (0..len).fold(0.0f32, |acc, i| {
        let diff = a[a_off + i] - b[b_off + i];
        acc + diff * diff
    })
}
+
/// Returns the dot product `Σ a[i] * b[i]`.
///
/// The sum runs over the `a.len()` leading elements, accumulated left to
/// right in `f32`. Debug builds assert that both slices have the same
/// length; in any build the call panics if `b` is shorter than `a`.
pub fn fvec_inner_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front bounds check instead of one per element.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y)
}
+
/// Returns the squared Euclidean norm `Σ a[i]²` (no square root is taken).
pub fn fvec_norm_l2sqr(a: &[f32]) -> f32 {
    a.iter().fold(0.0f32, |acc, &v| acc + v * v)
}
+
/// Scales `v` in place to unit Euclidean length and returns the norm it had
/// before scaling.
///
/// When the norm is not strictly positive (for example an all-zero or empty
/// vector) `v` is left untouched and that norm is returned as is.
pub fn fvec_normalize(v: &mut [f32]) -> f32 {
    // Squared norm accumulated left to right, then the square root.
    let norm = v.iter().fold(0.0f32, |acc, &x| acc + x * x).sqrt();
    if !(norm > 0.0) {
        return norm;
    }
    let inv = 1.0 / norm;
    v.iter_mut().for_each(|x| *x *= inv);
    norm
}
+
/// Computes `result[i] = a[i] + bf * b[i]` for every index of `a`.
/// Used for precomputed table merging. Aligned with Faiss's fvec_madd.
///
/// # Panics
///
/// Panics if `b` or `result` does not have the same length as `a`. The check
/// is performed in every build profile because the SIMD kernels below access
/// memory through raw pointers.
pub fn fvec_madd(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
    assert_eq!(a.len(), b.len(), "fvec_madd: `a` and `b` differ in length");
    assert_eq!(
        a.len(),
        result.len(),
        "fvec_madd: `a` and `result` differ in length"
    );
    fvec_madd_simd(a, b, bf, result);
}

/// x86_64 dispatcher: uses the AVX2 kernel when the CPU reports AVX2 at
/// runtime, otherwise the scalar loop.
#[cfg(target_arch = "x86_64")]
fn fvec_madd_simd(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was confirmed on the line above; the kernel
        // limits itself to the shortest of the three slices.
        unsafe { fvec_madd_avx2(a, b, bf, result) };
    } else {
        fvec_madd_scalar(a, b, bf, result);
    }
}

/// aarch64 dispatcher: always uses the NEON kernel.
#[cfg(target_arch = "aarch64")]
fn fvec_madd_simd(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
    // SAFETY: the kernel limits itself to the shortest of the three slices.
    // NOTE(review): NEON is assumed to be present on every aarch64 target
    // this crate is built for, as in the original code — confirm.
    unsafe { fvec_madd_neon(a, b, bf, result) }
}

/// Dispatcher for every other architecture: scalar loop only.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
fn fvec_madd_simd(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
    fvec_madd_scalar(a, b, bf, result);
}

/// Portable reference implementation of `result[i] = a[i] + bf * b[i]`.
#[inline]
#[allow(dead_code)] // not referenced on targets whose dispatcher never falls back to it
fn fvec_madd_scalar(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
    for ((out, &x), &y) in result.iter_mut().zip(a).zip(b) {
        *out = x + bf * y;
    }
}

/// AVX2 kernel: eight lanes per iteration, scalar loop for the remainder.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn fvec_madd_avx2(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
    use std::arch::x86_64::*;
    // Bounding by the shortest slice keeps every pointer access in range
    // even if a caller bypasses the length assertions in `fvec_madd`.
    let n = a.len().min(b.len()).min(result.len());
    let vbf = _mm256_set1_ps(bf);
    let mut i = 0;
    while i + 8 <= n {
        // SAFETY (pointer accesses): `i + 8 <= n` and `n` does not exceed
        // the length of any of the three slices.
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        let vr = _mm256_add_ps(va, _mm256_mul_ps(vbf, vb));
        _mm256_storeu_ps(result.as_mut_ptr().add(i), vr);
        i += 8;
    }
    while i < n {
        result[i] = a[i] + bf * b[i];
        i += 1;
    }
}

/// NEON kernel: four lanes per iteration, scalar loop for the remainder.
///
/// # Safety
///
/// The CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn fvec_madd_neon(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
    use std::arch::aarch64::*;
    // Bounding by the shortest slice keeps every pointer access in range
    // even if a caller bypasses the length assertions in `fvec_madd`.
    let n = a.len().min(b.len()).min(result.len());
    let vbf = vdupq_n_f32(bf);
    let mut i = 0;
    while i + 4 <= n {
        // SAFETY (pointer accesses): `i + 4 <= n` and `n` does not exceed
        // the length of any of the three slices.
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        let vr = vmlaq_f32(va, vbf, vb);
        vst1q_f32(result.as_mut_ptr().add(i), vr);
        i += 4;
    }
    while i < n {
        result[i] = a[i] + bf * b[i];
        i += 1;
    }
}
+
/// Fills `result[j]` with the squared L2 distance between the first `dsub`
/// elements of `query_sub` and the `j`-th `dsub`-long row of `centroids`,
/// for `j` in `0..ksub` (used by the PQ distance table).
///
/// Written as plain scalar loops; no explicit SIMD intrinsics are used here.
pub fn fvec_l2sqr_batch(
    query_sub: &[f32],
    centroids: &[f32],
    dsub: usize,
    ksub: usize,
    result: &mut [f32],
) {
    for j in 0..ksub {
        let row_start = j * dsub;
        result[j] = (0..dsub).fold(0.0f32, |acc, t| {
            let diff = query_sub[t] - centroids[row_start + t];
            acc + diff * diff
        });
    }
}
+
/// Fills `result[j]` with the inner product of the first `dsub` elements of
/// `query_sub` and the `j`-th `dsub`-long row of `centroids`, for `j` in
/// `0..ksub` (used by the PQ distance table).
///
/// Written as plain scalar loops; no explicit SIMD intrinsics are used here.
pub fn fvec_ip_batch(
    query_sub: &[f32],
    centroids: &[f32],
    dsub: usize,
    ksub: usize,
    result: &mut [f32],
) {
    for j in 0..ksub {
        let row_start = j * dsub;
        result[j] = (0..dsub).fold(0.0f32, |acc, t| {
            acc + query_sub[t] * centroids[row_start + t]
        });
    }
}
+
/// Scans a batch of 4-bit PQ codes and writes one distance per vector.
///
/// Approach (aligned with Lance/Faiss):
/// 1. The first `FLAT_NUM` vectors are summed exactly in `f32`; their
///    maximum calibrates the quantization range (`qmax`).
/// 2. The whole distance table is quantized to `u8`.
/// 3. The remaining vectors are accumulated in the integer domain.
/// 4. The integer sums are mapped back to `f32`.
///
/// * `sim_table` – `[m * 16]` `f32` distance table, 16 entries per
///   sub-quantizer.
/// * `codes` – nibble-packed, row-major, `m / 2` bytes per vector; the low
///   nibble of byte `p` indexes sub-quantizer `2p`, the high nibble `2p + 1`.
/// * `count` – number of vectors to scan.
/// * `dists` – output; entries `0..count` are written.
///
/// Only `2 * (m / 2)` sub-quantizers are read, so an odd `m` ignores the
/// last one. The code is written as plain loops; no explicit SIMD intrinsics
/// are used here.
///
/// # Panics
///
/// Panics if `sim_table`, `codes` or `dists` is too short for `count`/`m`.
pub fn scan_4bit_simd(sim_table: &[f32], codes: &[u8], count: usize, m: usize, dists: &mut [f32]) {
    const FLAT_NUM: usize = 200;
    const KSUB: usize = 16;

    let cs = m / 2; // code size: bytes per vector

    // Step 1: exact f32 distances for the first FLAT_NUM vectors.
    let flat_end = count.min(FLAT_NUM);
    for i in 0..flat_end {
        let row = &codes[i * cs..(i + 1) * cs];
        let mut d = 0.0f32;
        for (pair, &byte) in row.iter().enumerate() {
            let lo = (byte & 0x0F) as usize;
            let hi = (byte >> 4) as usize;
            d += sim_table[(pair * 2) * KSUB + lo];
            d += sim_table[(pair * 2 + 1) * KSUB + hi];
        }
        dists[i] = d;
    }

    if count <= FLAT_NUM {
        return;
    }

    // Step 2: quantization range. `qmax` is the largest exact distance seen
    // so far, `qmin` the smallest single table entry.
    let qmax = dists[..flat_end].iter().copied().fold(f32::MIN, f32::max);
    let qmin = sim_table.iter().copied().fold(f32::INFINITY, f32::min);
    let range = (qmax - qmin).max(1e-10);
    let factor = 255.0 / range;

    let qtable: Vec<u8> = sim_table
        .iter()
        .map(|&d| ((d - qmin) * factor).clamp(0.0, 255.0) as u8)
        .collect();

    // Step 3: accumulate the remaining vectors in the integer domain. One
    // pair adds at most 510, so more than 128 pairs could exceed u16::MAX;
    // saturating keeps the sum monotone instead of panicking or wrapping.
    let mut q_dists = vec![0u16; count - flat_end];
    for pair in 0..cs {
        let qtab_lo = &qtable[(pair * 2) * KSUB..(pair * 2 + 1) * KSUB];
        let qtab_hi = &qtable[(pair * 2 + 1) * KSUB..(pair * 2 + 2) * KSUB];

        // Pair-major order: both 16-entry tables stay fixed while the codes
        // of all remaining vectors are visited.
        for (t, acc) in q_dists.iter_mut().enumerate() {
            let byte = codes[(flat_end + t) * cs + pair];
            let lo = (byte & 0x0F) as usize;
            let hi = (byte >> 4) as usize;
            let step = u16::from(qtab_lo[lo]) + u16::from(qtab_hi[hi]);
            *acc = acc.saturating_add(step);
        }
    }

    // Step 4: back to f32. Each of the `2 * cs` sub-quantizers that were
    // actually summed had `qmin` subtracted before quantization, so that is
    // the offset to restore (not `m`, which differs when `m` is odd).
    let inv_factor = range / 255.0;
    let base_dist = qmin * (2 * cs) as f32;
    for (out, &q) in dists[flat_end..count].iter_mut().zip(&q_dists) {
        *out = f32::from(q) * inv_factor + base_dist;
    }
}
+
+/// Compute PQ distance from a precomputed distance table.
+/// table layout: [M][ksub], codes: M bytes.
+/// Each code[m] indexes into table[m * ksub + code[m]].
+#[inline]
+pub fn pq_distance_from_table(table: &[f32], codes: &[u8], m: usize, ksub:
usize) -> f32 {
+ pq_distance_from_table_simd(table, codes, m, ksub)
+}
+
/// Computes the table-lookup distance for four PQ codes in one pass, for
/// better instruction-level parallelism.
///
/// `offsets[j]` is the index in `codes` where the `j`-th code (of `m` bytes)
/// starts; `table` is laid out as `[m][ksub]`. Returns the four sums in the
/// order of `offsets`.
#[inline]
pub fn pq_distance_four_codes(
    table: &[f32],
    codes: &[u8],
    m: usize,
    ksub: usize,
    offsets: [usize; 4],
) -> [f32; 4] {
    let mut acc = [0.0f32; 4];
    for sub in 0..m {
        let row = sub * ksub;
        // All four codes read from the same table row before moving on.
        for (slot, &start) in acc.iter_mut().zip(offsets.iter()) {
            *slot += table[row + codes[start + sub] as usize];
        }
    }
    acc
}
+
/// Returns `true` when `codes` holds at least `m` bytes and `table` at least
/// `m * ksub` entries, i.e. when every lookup `table[i * ksub + codes[i]]`
/// with `codes[i] < ksub` stays in bounds.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
#[inline]
fn pq_inputs_cover(table: &[f32], codes: &[u8], m: usize, ksub: usize) -> bool {
    match m.checked_mul(ksub) {
        Some(needed) => codes.len() >= m && table.len() >= needed,
        None => false,
    }
}

/// SIMD-accelerated PQ distance table lookup (x86_64).
///
/// Takes the AVX2 gather kernel when `ksub == 256`, `m >= 8`, the inputs are
/// large enough for unchecked access and the CPU reports AVX2; otherwise the
/// bounds-checked scalar loop runs (and panics on undersized inputs).
#[cfg(target_arch = "x86_64")]
#[inline]
fn pq_distance_from_table_simd(table: &[f32], codes: &[u8], m: usize, ksub: usize) -> f32 {
    if m >= 8
        && ksub == 256
        && pq_inputs_cover(table, codes, m, ksub)
        && is_x86_feature_detected!("avx2")
    {
        // SAFETY: AVX2 was detected, and `pq_inputs_cover` guarantees
        // `table.len() >= m * 256` and `codes.len() >= m`, which is what the
        // kernel's gather loads require (a `u8` code is always < 256).
        unsafe { pq_distance_avx2(table, codes, m) }
    } else {
        pq_distance_scalar(table, codes, m, ksub)
    }
}

/// SIMD-accelerated PQ distance table lookup (aarch64).
///
/// Takes the NEON kernel when `ksub == 256`, `m >= 4` and the inputs are
/// large enough for unchecked access; otherwise the bounds-checked scalar
/// loop runs (and panics on undersized inputs).
#[cfg(target_arch = "aarch64")]
#[inline]
fn pq_distance_from_table_simd(table: &[f32], codes: &[u8], m: usize, ksub: usize) -> f32 {
    if ksub == 256 && m >= 4 && pq_inputs_cover(table, codes, m, ksub) {
        // SAFETY: `pq_inputs_cover` guarantees `table.len() >= m * 256` and
        // `codes.len() >= m`, so every `get_unchecked` in the kernel is in
        // bounds (a `u8` code is always < 256).
        // NOTE(review): NEON is assumed to be present on every aarch64
        // target this crate is built for, as in the original — confirm.
        unsafe { pq_distance_neon(table, codes, m) }
    } else {
        pq_distance_scalar(table, codes, m, ksub)
    }
}

/// NEON-accelerated PQ distance with manual gather + vaddq_f32 accumulation.
///
/// # Safety
///
/// Requires NEON, `codes.len() >= m` and `table.len() >= m * 256`; elements
/// are read without bounds checks.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn pq_distance_neon(table: &[f32], codes: &[u8], m: usize) -> f32 {
    use std::arch::aarch64::*;

    let ksub = 256usize;
    let mut sum = vdupq_n_f32(0.0);
    let mut i = 0;

    // Four sub-quantizers per iteration: gather by hand, add as one vector.
    while i + 4 <= m {
        let d0 = *table.get_unchecked(i * ksub + *codes.get_unchecked(i) as usize);
        let d1 = *table.get_unchecked((i + 1) * ksub + *codes.get_unchecked(i + 1) as usize);
        let d2 = *table.get_unchecked((i + 2) * ksub + *codes.get_unchecked(i + 2) as usize);
        let d3 = *table.get_unchecked((i + 3) * ksub + *codes.get_unchecked(i + 3) as usize);

        let arr = [d0, d1, d2, d3];
        let v = vld1q_f32(arr.as_ptr());
        sum = vaddq_f32(sum, v);
        i += 4;
    }

    let mut result = vaddvq_f32(sum);

    // Remaining (fewer than four) sub-quantizers.
    while i < m {
        result += *table.get_unchecked(i * ksub + *codes.get_unchecked(i) as usize);
        i += 1;
    }

    result
}

/// PQ distance table lookup for architectures without a dedicated kernel.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
#[inline]
fn pq_distance_from_table_simd(table: &[f32], codes: &[u8], m: usize, ksub: usize) -> f32 {
    pq_distance_scalar(table, codes, m, ksub)
}

/// Bounds-checked reference implementation:
/// `Σ table[i * ksub + codes[i]]` for `i` in `0..m`.
///
/// Panics if `codes` or `table` is too short for a lookup.
#[inline]
fn pq_distance_scalar(table: &[f32], codes: &[u8], m: usize, ksub: usize) -> f32 {
    let mut dist = 0.0f32;
    for i in 0..m {
        dist += table[i * ksub + codes[i] as usize];
    }
    dist
}

/// AVX2 PQ distance using gather instructions (`ksub` fixed at 256).
/// Aligned with Faiss's pq_code_distance-avx2.h.
///
/// # Safety
///
/// Requires AVX2 and `table.len() >= m * 256`; the gather reads `table`
/// through a raw pointer. `codes` is indexed with bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn pq_distance_avx2(table: &[f32], codes: &[u8], m: usize) -> f32 {
    use std::arch::x86_64::*;

    let ksub = 256usize;
    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    // Process 8 sub-quantizers at a time. Offsets are relative to the table
    // row of sub-quantizer `i`; `_mm256_set_epi32` takes lanes high-to-low.
    while i + 8 <= m {
        let offsets = _mm256_set_epi32(
            (7 * ksub + codes[i + 7] as usize) as i32,
            (6 * ksub + codes[i + 6] as usize) as i32,
            (5 * ksub + codes[i + 5] as usize) as i32,
            (4 * ksub + codes[i + 4] as usize) as i32,
            (3 * ksub + codes[i + 3] as usize) as i32,
            (2 * ksub + codes[i + 2] as usize) as i32,
            (ksub + codes[i + 1] as usize) as i32,
            (codes[i] as usize) as i32,
        );

        let tab_ptr = table.as_ptr().add(i * ksub);
        let gathered = _mm256_i32gather_ps::<4>(tab_ptr, offsets);
        sum = _mm256_add_ps(sum, gathered);
        i += 8;
    }

    // Horizontal sum of the 8 floats in sum
    let hi = _mm256_extractf128_ps::<1>(sum);
    let lo = _mm256_castps256_ps128(sum);
    let sum128 = _mm_add_ps(lo, hi);
    let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps::<1>(sum64, sum64));
    let mut result = _mm_cvtss_f32(sum32);

    // Handle remaining sub-quantizers
    while i < m {
        result += table[i * ksub + codes[i] as usize];
        i += 1;
    }

    result
}
+
+/// Compute distance between query and a set of vectors, return top-k.
+pub fn fvec_distances_batch(
+ query: &[f32],
+ vectors: &[f32],
+ n: usize,
+ d: usize,
+ metric: MetricType,
+ distances: &mut [f32],
+) {
+ for i in 0..n {
+ let vec = &vectors[i * d..(i + 1) * d];
+ distances[i] = match metric {
+ MetricType::L2 => fvec_l2sqr(query, vec),
+ MetricType::InnerProduct => -fvec_inner_product(query, vec),
+ MetricType::Cosine => {
+ let dot = fvec_inner_product(query, vec);
+ let na = fvec_norm_l2sqr(query).sqrt();
+ let nb = fvec_norm_l2sqr(vec).sqrt();
+ let denom = na * nb;
+ if denom > 0.0 {
+ 1.0 - dot / denom
+ } else {
+ 1.0
+ }
+ }
+ };
+ }
+}
+
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by all checks in this module.
    const TOL: f32 = 1e-6;

    /// Asserts that `actual` is within `TOL` of `expected`.
    fn assert_near(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn test_l2sqr() {
        let lhs = [1.0, 2.0, 3.0];
        let rhs = [4.0, 5.0, 6.0];
        assert_near(fvec_l2sqr(&lhs, &rhs), 27.0);
    }

    #[test]
    fn test_inner_product() {
        let lhs = [1.0, 2.0, 3.0];
        let rhs = [4.0, 5.0, 6.0];
        assert_near(fvec_inner_product(&lhs, &rhs), 32.0);
    }

    #[test]
    fn test_normalize() {
        let mut v = [3.0, 4.0];
        fvec_normalize(&mut v);
        assert_near(v[0], 0.6);
        assert_near(v[1], 0.8);
    }

    #[test]
    fn test_pq_distance_scalar() {
        // 2 sub-quantizers, 4 centroids each.
        let table = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let codes = [1u8, 3u8];
        // table[0*4 + 1] + table[1*4 + 3] = 0.2 + 0.8 = 1.0
        assert_near(pq_distance_scalar(&table, &codes, 2, 4), 1.0);
    }
}
diff --git a/core/src/kmeans.rs b/core/src/kmeans.rs
new file mode 100644
index 0000000..72d8069
--- /dev/null
+++ b/core/src/kmeans.rs
@@ -0,0 +1,875 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::blas::sgemm_a_bt;
+use crate::distance::{fvec_l2sqr, fvec_norm_l2sqr};
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
+
+pub struct KMeansConfig {
+ pub niter: usize,
+ pub nredo: usize,
+ pub max_points_per_centroid: usize,
+ pub seed: u64,
+ /// Balance factor: penalizes large clusters to produce more uniform
partitions.
+ /// 0.0 = standard k-means. Higher values = more balanced.
+ /// Typical value: 0.1 for IVF construction.
+ pub balance_factor: f32,
+}
+
+impl Default for KMeansConfig {
+ fn default() -> Self {
+ KMeansConfig {
+ niter: 25,
+ nredo: 1,
+ max_points_per_centroid: 256,
+ seed: 1234,
+ balance_factor: 0.0,
+ }
+ }
+}
+
+const EPS: f32 = 1.0 / 1024.0;
+
+/// Threshold above which hierarchical k-means is used.
+const HIERARCHICAL_THRESHOLD: usize = 256;
+
+pub fn kmeans_train(config: &KMeansConfig, data: &[f32], n: usize, d: usize,
k: usize) -> Vec<f32> {
+ if k > HIERARCHICAL_THRESHOLD && n > k {
+ kmeans_train_hierarchical(config, data, n, d, k)
+ } else {
+ kmeans_train_with_init(config, data, n, d, k, None)
+ }
+}
+
+/// Hierarchical k-means for large k (> 256).
+/// Starts with initial_k clusters and iteratively splits the largest until
target k is reached.
+fn kmeans_train_hierarchical(
+ config: &KMeansConfig,
+ data: &[f32],
+ n: usize,
+ d: usize,
+ target_k: usize,
+) -> Vec<f32> {
+ use std::cmp::Ordering;
+ use std::collections::BinaryHeap;
+
+ #[derive(Clone)]
+ struct Cluster {
+ centroid: Vec<f32>,
+ indices: Vec<usize>,
+ }
+
+ impl Eq for Cluster {}
+ impl PartialEq for Cluster {
+ fn eq(&self, other: &Self) -> bool {
+ self.indices.len() == other.indices.len()
+ }
+ }
+ impl Ord for Cluster {
+ fn cmp(&self, other: &Self) -> Ordering {
+ self.indices.len().cmp(&other.indices.len())
+ }
+ }
+ impl PartialOrd for Cluster {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+ }
+
+ let mut rng = StdRng::seed_from_u64(config.seed);
+
+ // Subsample for training
+ let max_n = target_k * config.max_points_per_centroid;
+ let (train_data, train_n) = if n > max_n {
+ let sub = subsample(data, n, d, max_n, &mut rng);
+ (sub, max_n)
+ } else {
+ (data.to_vec(), n)
+ };
+
+ // Step 1: Train initial_k clusters
+ let initial_k = 16.min(target_k);
+ let initial_config = KMeansConfig {
+ niter: config.niter,
+ seed: config.seed,
+ ..KMeansConfig::default()
+ };
+ let initial_centroids =
+ kmeans_train_with_init(&initial_config, &train_data, train_n, d,
initial_k, None);
+
+ // Assign all points to initial clusters
+ let mut assignments = vec![0usize; train_n];
+ assign_clusters_fast(
+ &train_data,
+ train_n,
+ d,
+ &initial_centroids,
+ initial_k,
+ &mut assignments,
+ 0.0,
+ );
+
+ // Build initial clusters
+ let mut heap: BinaryHeap<Cluster> = BinaryHeap::new();
+ for c in 0..initial_k {
+ let indices: Vec<usize> = (0..train_n).filter(|&i| assignments[i] ==
c).collect();
+ let centroid = initial_centroids[c * d..(c + 1) * d].to_vec();
+ heap.push(Cluster { centroid, indices });
+ }
+
+ // Step 2: Iteratively split the largest cluster
+ let mut finalized: Vec<Vec<f32>> = Vec::new();
+ let split_k = 2; // Split into 2 each time
+
+ while finalized.len() + heap.len() < target_k {
+ let largest = match heap.pop() {
+ Some(c) => c,
+ None => break,
+ };
+
+ if largest.indices.len() < split_k * 2 {
+ finalized.push(largest.centroid);
+ continue;
+ }
+
+ // Extract sub-data for this cluster
+ let sub_n = largest.indices.len();
+ let mut sub_data = vec![0.0f32; sub_n * d];
+ for (new_idx, &orig_idx) in largest.indices.iter().enumerate() {
+ sub_data[new_idx * d..(new_idx + 1) * d]
+ .copy_from_slice(&train_data[orig_idx * d..(orig_idx + 1) *
d]);
+ }
+
+ // Run k-means to split
+ let sub_config = KMeansConfig {
+ niter: 10,
+ seed: config.seed + finalized.len() as u64,
+ ..KMeansConfig::default()
+ };
+ let sub_centroids = kmeans_train_with_init(&sub_config, &sub_data,
sub_n, d, split_k, None);
+
+ // Reassign points in this cluster
+ let mut sub_assignments = vec![0usize; sub_n];
+ assign_clusters_fast(
+ &sub_data,
+ sub_n,
+ d,
+ &sub_centroids,
+ split_k,
+ &mut sub_assignments,
+ 0.0,
+ );
+
+ for sc in 0..split_k {
+ let sub_indices: Vec<usize> = (0..sub_n)
+ .filter(|&i| sub_assignments[i] == sc)
+ .map(|i| largest.indices[i])
+ .collect();
+ let centroid = sub_centroids[sc * d..(sc + 1) * d].to_vec();
+ if !sub_indices.is_empty() {
+ heap.push(Cluster {
+ centroid,
+ indices: sub_indices,
+ });
+ }
+ }
+ }
+
+ // Collect all centroids
+ let mut result = Vec::with_capacity(target_k * d);
+ for c in finalized {
+ result.extend_from_slice(&c);
+ }
+ while let Some(cluster) = heap.pop() {
+ result.extend_from_slice(&cluster.centroid);
+ if result.len() >= target_k * d {
+ break;
+ }
+ }
+
+ // Pad if needed
+ result.resize(target_k * d, 0.0);
+ result
+}
+
+pub fn kmeans_train_with_init(
+ config: &KMeansConfig,
+ data: &[f32],
+ n: usize,
+ d: usize,
+ k: usize,
+ initial_centroids: Option<&[f32]>,
+) -> Vec<f32> {
+ if n == 0 || k == 0 {
+ return vec![0.0; k * d];
+ }
+
+ let mut rng = StdRng::seed_from_u64(config.seed);
+
+ let max_n = k * config.max_points_per_centroid;
+ let (train_data, train_n) = if n > max_n {
+ let sub = subsample(data, n, d, max_n, &mut rng);
+ (sub, max_n)
+ } else {
+ (data.to_vec(), n)
+ };
+
+ if train_n <= k {
+ let mut centroids = vec![0.0f32; k * d];
+ for i in 0..k {
+ let src = i % train_n;
+ centroids[i * d..(i + 1) * d].copy_from_slice(&train_data[src *
d..(src + 1) * d]);
+ }
+ return centroids;
+ }
+
+ let mut best_centroids = vec![0.0f32; k * d];
+ let mut best_obj = f32::MAX;
+
+ let nredo = if initial_centroids.is_some() {
+ 1
+ } else {
+ config.nredo
+ };
+
+ for redo in 0..nredo {
+ let mut centroids = if redo == 0 {
+ if let Some(init) = initial_centroids {
+ init.to_vec()
+ } else {
+ kmeans_plusplus_init(&train_data, train_n, d, k, &mut rng)
+ }
+ } else {
+ kmeans_plusplus_init(&train_data, train_n, d, k, &mut rng)
+ };
+ let mut assignments = vec![0usize; train_n];
+ let mut prev_obj = f32::MAX;
+
+ for _iter in 0..config.niter {
+ let obj = assign_clusters_fast(
+ &train_data,
+ train_n,
+ d,
+ &centroids,
+ k,
+ &mut assignments,
+ config.balance_factor,
+ );
+ update_centroids(
+ &train_data,
+ train_n,
+ d,
+ &mut centroids,
+ k,
+ &assignments,
+ &mut rng,
+ );
+
+ if prev_obj < f32::MAX {
+ let rel_change = (prev_obj - obj).abs() / prev_obj.max(1e-10);
+ if rel_change < 1e-6 {
+ break;
+ }
+ }
+ prev_obj = obj;
+ }
+
+ if prev_obj < best_obj {
+ best_obj = prev_obj;
+ best_centroids.copy_from_slice(&centroids);
+ }
+ }
+
+ best_centroids
+}
+
+fn kmeans_plusplus_init(data: &[f32], n: usize, d: usize, k: usize, rng: &mut
StdRng) -> Vec<f32> {
+ let mut centroids = vec![0.0f32; k * d];
+
+ let first = rng.gen_range(0..n);
+ centroids[..d].copy_from_slice(&data[first * d..(first + 1) * d]);
+
+ let mut min_dists = vec![f32::MAX; n];
+
+ for c in 1..k {
+ let prev = &centroids[(c - 1) * d..c * d];
+ let mut total = 0.0f32;
+ for i in 0..n {
+ let dist = fvec_l2sqr(&data[i * d..(i + 1) * d], prev);
+ if dist < min_dists[i] {
+ min_dists[i] = dist;
+ }
+ total += min_dists[i];
+ }
+
+ let target = rng.gen::<f32>() * total;
+ let mut cumulative = 0.0f32;
+ let mut selected = n - 1;
+ for i in 0..n {
+ cumulative += min_dists[i];
+ if cumulative >= target {
+ selected = i;
+ break;
+ }
+ }
+
+ centroids[c * d..(c + 1) * d].copy_from_slice(&data[selected *
d..(selected + 1) * d]);
+ }
+
+ centroids
+}
+
+/// Fast assignment using sgemm: ||x-c||² = ||x||² + ||c||² - 2·x·cᵀ.
+/// Supports balance_factor to penalize large clusters.
+fn assign_clusters_fast(
+ data: &[f32],
+ n: usize,
+ d: usize,
+ centroids: &[f32],
+ k: usize,
+ assignments: &mut [usize],
+ balance_factor: f32,
+) -> f32 {
+ // Cap ip_matrix size to ~16MB. Chunk if n*k would be too large.
+ const MAX_MATRIX_ELEMS: usize = 4 * 1024 * 1024; // 16MB / 4 bytes
+ if n * k > MAX_MATRIX_ELEMS {
+ let chunk_n = MAX_MATRIX_ELEMS / k;
+ let mut total_obj = 0.0f32;
+ let mut offset = 0;
+ while offset < n {
+ let cn = (n - offset).min(chunk_n);
+ total_obj += assign_clusters_fast(
+ &data[offset * d..(offset + cn) * d],
+ cn,
+ d,
+ centroids,
+ k,
+ &mut assignments[offset..offset + cn],
+ balance_factor,
+ );
+ offset += cn;
+ }
+ return total_obj;
+ }
+
+ let x_norms: Vec<f32> = (0..n)
+ .map(|i| fvec_norm_l2sqr(&data[i * d..(i + 1) * d]))
+ .collect();
+ let c_norms: Vec<f32> = (0..k)
+ .map(|c| fvec_norm_l2sqr(&centroids[c * d..(c + 1) * d]))
+ .collect();
+
+ let mut ip_matrix = vec![0.0f32; n * k];
+ sgemm_a_bt(n, k, d, 1.0, data, centroids, 0.0, &mut ip_matrix);
+
+ // Compute cluster sizes for balance penalty
+ let mut cluster_sizes = vec![0u32; k];
+ if balance_factor > 0.0 {
+ for &a in assignments.iter() {
+ if a < k {
+ cluster_sizes[a] += 1;
+ }
+ }
+ }
+
+ let mut total_obj = 0.0f32;
+ for i in 0..n {
+ let mut best = 0;
+ let mut best_dist = f32::MAX;
+ let row = i * k;
+ for c in 0..k {
+ let mut dist = x_norms[i] + c_norms[c] - 2.0 * ip_matrix[row + c];
+ // Balance penalty: prefer smaller clusters
+ if balance_factor > 0.0 && cluster_sizes[c] > 0 {
+ dist += balance_factor * (cluster_sizes[c] as f32).ln();
+ }
+ if dist < best_dist {
+ best_dist = dist;
+ best = c;
+ }
+ }
+ assignments[i] = best;
+ total_obj += best_dist;
+ }
+
+ total_obj
+}
+
+fn update_centroids(
+ data: &[f32],
+ n: usize,
+ d: usize,
+ centroids: &mut [f32],
+ k: usize,
+ assignments: &[usize],
+ rng: &mut StdRng,
+) {
+ let mut counts = vec![0usize; k];
+ let mut sums = vec![0.0f32; k * d];
+
+ for i in 0..n {
+ let c = assignments[i];
+ counts[c] += 1;
+ for j in 0..d {
+ sums[c * d + j] += data[i * d + j];
+ }
+ }
+
+ for c in 0..k {
+ if counts[c] > 0 {
+ let inv = 1.0 / counts[c] as f32;
+ for j in 0..d {
+ centroids[c * d + j] = sums[c * d + j] * inv;
+ }
+ }
+ }
+
+ for c in 0..k {
+ if counts[c] > 0 {
+ continue;
+ }
+
+ let donor = counts
+ .iter()
+ .enumerate()
+ .max_by_key(|(_, &cnt)| cnt)
+ .map(|(idx, _)| idx)
+ .unwrap_or(0);
+
+ if counts[donor] <= 1 {
+ let idx = rng.gen_range(0..n);
+ centroids[c * d..(c + 1) * d].copy_from_slice(&data[idx * d..(idx
+ 1) * d]);
+ continue;
+ }
+
+ let donor_copy: Vec<f32> = centroids[donor * d..(donor + 1) *
d].to_vec();
+ centroids[c * d..(c + 1) * d].copy_from_slice(&donor_copy);
+
+ for j in 0..d {
+ if j.is_multiple_of(2) {
+ centroids[c * d + j] *= 1.0 + EPS;
+ centroids[donor * d + j] *= 1.0 - EPS;
+ } else {
+ centroids[c * d + j] *= 1.0 - EPS;
+ centroids[donor * d + j] *= 1.0 + EPS;
+ }
+ }
+
+ counts[c] = counts[donor] / 2;
+ counts[donor] -= counts[c];
+ }
+}
+
+pub fn find_nearest(point: &[f32], centroids: &[f32], k: usize, d: usize) ->
usize {
+ let mut best = 0;
+ let mut best_dist = f32::MAX;
+ for c in 0..k {
+ let dist = fvec_l2sqr(point, &centroids[c * d..(c + 1) * d]);
+ if dist < best_dist {
+ best_dist = dist;
+ best = c;
+ }
+ }
+ best
+}
+
+pub fn find_topk(
+ point: &[f32],
+ centroids: &[f32],
+ k: usize,
+ d: usize,
+ nprobe: usize,
+) -> (Vec<usize>, Vec<f32>) {
+ let nprobe = nprobe.min(k);
+ let mut dists: Vec<(f32, usize)> = (0..k)
+ .map(|c| (fvec_l2sqr(point, &centroids[c * d..(c + 1) * d]), c))
+ .collect();
+ dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
+ let indices: Vec<usize> = dists[..nprobe].iter().map(|&(_, i)|
i).collect();
+ let distances: Vec<f32> = dists[..nprobe].iter().map(|&(d, _)|
d).collect();
+ (indices, distances)
+}
+
+/// Batch find top-nprobe nearest centroids for multiple queries using sgemm.
+/// Returns (all_indices, all_distances): each holds nq inner vectors, and every inner vector has nprobe entries (nprobe is clamped to k).
+pub fn find_topk_batch(
+ queries: &[f32],
+ nq: usize,
+ centroids: &[f32],
+ k: usize,
+ d: usize,
+ nprobe: usize,
+) -> (Vec<Vec<usize>>, Vec<Vec<f32>>) {
+ let nprobe = nprobe.min(k);
+
+ if nq == 1 {
+ let (indices, distances) = find_topk(&queries[..d], centroids, k, d,
nprobe);
+ return (vec![indices], vec![distances]);
+ }
+
+ // Precompute norms
+ let q_norms: Vec<f32> = (0..nq)
+ .map(|i| fvec_norm_l2sqr(&queries[i * d..(i + 1) * d]))
+ .collect();
+ let c_norms: Vec<f32> = (0..k)
+ .map(|c| fvec_norm_l2sqr(&centroids[c * d..(c + 1) * d]))
+ .collect();
+
+ // Batch inner products: ip[nq × k] = queries[nq × d] · centroids[k × d]^T
+ let mut ip_matrix = vec![0.0f32; nq * k];
+ sgemm_a_bt(nq, k, d, 1.0, queries, centroids, 0.0, &mut ip_matrix);
+
+ // Extract top-nprobe per query
+ let mut all_indices = Vec::with_capacity(nq);
+ let mut all_distances = Vec::with_capacity(nq);
+
+ for qi in 0..nq {
+ let row = qi * k;
+ let mut dists: Vec<(f32, usize)> = (0..k)
+ .map(|c| {
+ let dist = q_norms[qi] + c_norms[c] - 2.0 * ip_matrix[row + c];
+ (dist.max(0.0), c)
+ })
+ .collect();
+ dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
+
+ all_indices.push(dists[..nprobe].iter().map(|&(_, i)| i).collect());
+ all_distances.push(dists[..nprobe].iter().map(|&(d, _)| d).collect());
+ }
+
+ (all_indices, all_distances)
+}
+
+// --- Streaming Coreset K-means ---
+
+/// Streaming k-means trainer for very large datasets.
+/// Processes data in chunks, compresses each chunk into a weighted coreset,
+/// then trains final centroids on the accumulated coreset.
+pub struct StreamingKMeans {
+ pub d: usize,
+ pub k: usize,
+ pub chunk_size: usize,
+ config: KMeansConfig,
+ /// Accumulated coreset: (centroids, weights)
+ coreset_centroids: Vec<f32>,
+ coreset_weights: Vec<f32>,
+}
+
+impl StreamingKMeans {
+ /// Create a streaming k-means trainer.
+ /// chunk_size: number of vectors per chunk (e.g., k * 256)
+ pub fn new(d: usize, k: usize, chunk_size: usize, config: KMeansConfig) ->
Self {
+ StreamingKMeans {
+ d,
+ k,
+ chunk_size,
+ config,
+ coreset_centroids: Vec::new(),
+ coreset_weights: Vec::new(),
+ }
+ }
+
+ /// Feed a chunk of training data. Can be called multiple times.
+ /// Each chunk is compressed into min(k, n) weighted centroids (coreset).
+ pub fn add_chunk(&mut self, data: &[f32], n: usize) {
+ let d = self.d;
+ let chunk_k = self.k.min(n);
+
+ if chunk_k == 0 || n == 0 {
+ return;
+ }
+
+ // Train k-means on this chunk
+ let chunk_config = KMeansConfig {
+ niter: 15,
+ seed: self.config.seed + self.coreset_weights.len() as u64,
+ ..KMeansConfig::default()
+ };
+ let centroids = kmeans_train_with_init(&chunk_config, data, n, d,
chunk_k, None);
+
+ // Assign points to centroids to compute weights
+ let mut assignments = vec![0usize; n];
+ assign_clusters_fast(data, n, d, &centroids, chunk_k, &mut assignments, 0.0);
+
+ let mut weights = vec![0.0f32; chunk_k];
+ for &a in &assignments {
+ weights[a] += 1.0;
+ }
+
+ // Append to coreset
+ self.coreset_centroids.extend_from_slice(&centroids);
+ self.coreset_weights.extend_from_slice(&weights);
+ }
+
+ /// Finalize: train final centroids on the accumulated weighted coreset.
+ pub fn finalize(&self) -> Vec<f32> {
+ let d = self.d;
+ let coreset_n = self.coreset_weights.len();
+
+ if coreset_n == 0 {
+ return vec![0.0f32; self.k * d];
+ }
+
+ if coreset_n <= self.k {
+ let mut result = self.coreset_centroids.clone();
+ result.resize(self.k * d, 0.0);
+ return result;
+ }
+
+ // Weighted k-means on coreset
+ weighted_kmeans_train(
+ &self.config,
+ &self.coreset_centroids,
+ &self.coreset_weights,
+ coreset_n,
+ d,
+ self.k,
+ )
+ }
+
+ /// Total vectors processed so far.
+ pub fn total_weight(&self) -> f32 {
+ self.coreset_weights.iter().sum()
+ }
+}
+
+/// Weighted k-means: each point has a weight that affects centroid
computation.
+fn weighted_kmeans_train(
+ config: &KMeansConfig,
+ data: &[f32],
+ weights: &[f32],
+ n: usize,
+ d: usize,
+ k: usize,
+) -> Vec<f32> {
+ let mut rng = StdRng::seed_from_u64(config.seed);
+
+ if n <= k {
+ let mut centroids = vec![0.0f32; k * d];
+ for i in 0..k {
+ let src = i % n;
+ centroids[i * d..(i + 1) * d].copy_from_slice(&data[src * d..(src
+ 1) * d]);
+ }
+ return centroids;
+ }
+
+ let mut centroids = kmeans_plusplus_init(data, n, d, k, &mut rng);
+ let mut assignments = vec![0usize; n];
+
+ for _iter in 0..config.niter {
+ // Assign (unweighted distance)
+ assign_clusters_fast(data, n, d, &centroids, k, &mut assignments, 0.0);
+
+ // Update with weights
+ let mut sums = vec![0.0f32; k * d];
+ let mut total_weights = vec![0.0f32; k];
+
+ for i in 0..n {
+ let c = assignments[i];
+ let w = weights[i];
+ total_weights[c] += w;
+ for j in 0..d {
+ sums[c * d + j] += w * data[i * d + j];
+ }
+ }
+
+ for c in 0..k {
+ if total_weights[c] > 0.0 {
+ let inv = 1.0 / total_weights[c];
+ for j in 0..d {
+ centroids[c * d + j] = sums[c * d + j] * inv;
+ }
+ } else {
+ // Reinit empty cluster
+ let idx = rng.gen_range(0..n);
+ centroids[c * d..(c + 1) * d].copy_from_slice(&data[idx *
d..(idx + 1) * d]);
+ }
+ }
+ }
+
+ centroids
+}
+
+fn subsample(data: &[f32], n: usize, d: usize, target_n: usize, rng: &mut
StdRng) -> Vec<f32> {
+ let mut indices: Vec<usize> = (0..n).collect();
+ for i in 0..target_n {
+ let j = rng.gen_range(i..n);
+ indices.swap(i, j);
+ }
+ let mut result = vec![0.0f32; target_n * d];
+ for (out_i, &src_i) in indices[..target_n].iter().enumerate() {
+ result[out_i * d..(out_i + 1) * d].copy_from_slice(&data[src_i *
d..(src_i + 1) * d]);
+ }
+ result
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_two_clusters() {
+ let mut data = Vec::new();
+ for _ in 0..50 {
+ data.push(0.1);
+ data.push(0.1);
+ }
+ for _ in 0..50 {
+ data.push(10.1);
+ data.push(10.1);
+ }
+
+ let config = KMeansConfig::default();
+ let centroids = kmeans_train(&config, &data, 100, 2, 2);
+
+ let c0 = if centroids[0] < 5.0 {
+ &centroids[0..2]
+ } else {
+ &centroids[2..4]
+ };
+ let c1 = if centroids[0] < 5.0 {
+ &centroids[2..4]
+ } else {
+ &centroids[0..2]
+ };
+
+ assert!(c0[0] < 2.0 && c0[1] < 2.0);
+ assert!(c1[0] > 8.0 && c1[1] > 8.0);
+ }
+
+ #[test]
+ fn test_find_topk() {
+ let centroids = [0.0, 0.0, 10.0, 0.0, 5.0, 5.0];
+ let query = [1.0, 1.0];
+ let (indices, _) = find_topk(&query, &centroids, 3, 2, 2);
+ assert_eq!(indices[0], 0);
+ }
+
+ #[test]
+ fn test_hot_start_converges_faster() {
+ let mut rng = StdRng::seed_from_u64(42);
+ let n = 500;
+ let d = 4;
+ let k = 4;
+
+ let data: Vec<f32> = (0..n * d).map(|_| rng.gen::<f32>() *
10.0).collect();
+
+ let config = KMeansConfig {
+ niter: 25,
+ ..KMeansConfig::default()
+ };
+ let centroids = kmeans_train(&config, &data, n, d, k);
+
+ // Hot-start with previous centroids should converge in fewer
iterations
+ let config2 = KMeansConfig {
+ niter: 3,
+ ..KMeansConfig::default()
+ };
+ let centroids2 = kmeans_train_with_init(&config2, &data, n, d, k, Some(&centroids));
+
+ // Should be very close to the original since it started from
converged state
+ let mut total_diff = 0.0f32;
+ for i in 0..k * d {
+ total_diff += (centroids[i] - centroids2[i]).abs();
+ }
+ assert!(
+ total_diff < 1.0,
+ "Hot-start centroids drifted too much: {}",
+ total_diff
+ );
+ }
+
+ #[test]
+ fn test_streaming_coreset_kmeans() {
+ let n = 5000;
+ let d = 4;
+ let k = 10;
+ let chunk_size = 1000;
+
+ let mut rng = StdRng::seed_from_u64(42);
+ // Generate clustered data
+ let mut data = Vec::new();
+ for cluster in 0..k {
+ let cx = cluster as f32 * 20.0;
+ let cy = cluster as f32 * 20.0;
+ for _ in 0..n / k {
+ data.push(cx + rng.gen::<f32>() * 2.0);
+ data.push(cy + rng.gen::<f32>() * 2.0);
+ data.push(rng.gen::<f32>());
+ data.push(rng.gen::<f32>());
+ }
+ }
+
+ let config = KMeansConfig::default();
+ let mut streaming = StreamingKMeans::new(d, k, chunk_size, config);
+
+ // Feed data in chunks
+ for chunk_start in (0..n).step_by(chunk_size) {
+ let chunk_end = (chunk_start + chunk_size).min(n);
+ let chunk_n = chunk_end - chunk_start;
+ streaming.add_chunk(&data[chunk_start * d..chunk_end * d],
chunk_n);
+ }
+
+ assert!((streaming.total_weight() - n as f32).abs() < 1.0);
+
+ let centroids = streaming.finalize();
+ assert_eq!(centroids.len(), k * d);
+
+ // Centroids should be diverse
+ let first = &centroids[0..d];
+ let mut diverse = false;
+ for i in 1..k {
+ if fvec_l2sqr(&centroids[i * d..(i + 1) * d], first) > 1.0 {
+ diverse = true;
+ break;
+ }
+ }
+ assert!(diverse, "Streaming centroids are not diverse");
+ }
+
+ #[test]
+ fn test_hierarchical_kmeans() {
+ let n = 2000;
+ let d = 4;
+ let k = 300; // > 256, triggers hierarchical
+
+ let mut rng = StdRng::seed_from_u64(42);
+ let data: Vec<f32> = (0..n * d).map(|_| rng.gen::<f32>() *
100.0).collect();
+
+ let config = KMeansConfig::default();
+ let centroids = kmeans_train(&config, &data, n, d, k);
+
+ assert_eq!(centroids.len(), k * d);
+
+ // All centroids should be finite
+ for &v in &centroids {
+ assert!(v.is_finite(), "Non-finite centroid value: {}", v);
+ }
+
+ // Centroids should be diverse (not all the same)
+ let first = &centroids[0..d];
+ let mut all_same = true;
+ for i in 1..k {
+ if &centroids[i * d..(i + 1) * d] != first {
+ all_same = false;
+ break;
+ }
+ }
+ assert!(!all_same, "All centroids are identical");
+ }
+}
diff --git a/core/src/lib.rs b/core/src/lib.rs
index b248758..a2df43f 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -14,3 +14,11 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
+
+#![allow(clippy::needless_range_loop)]
+#![allow(clippy::too_many_arguments)]
+
+pub mod blas;
+pub mod distance;
+pub mod kmeans;
+pub mod pq;
diff --git a/core/src/pq.rs b/core/src/pq.rs
new file mode 100644
index 0000000..bb2e292
--- /dev/null
+++ b/core/src/pq.rs
@@ -0,0 +1,542 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::blas::sgemm_a_bt;
+use crate::distance::{
+ fvec_ip_batch, fvec_l2sqr_batch, fvec_l2sqr_sub, fvec_norm_l2sqr,
pq_distance_from_table,
+ MetricType,
+};
+use crate::kmeans::{self, KMeansConfig};
+use rayon::prelude::*;
+
+/// Product Quantizer aligned with Faiss's ProductQuantizer.
+///
+/// Splits D-dimensional vectors into M sub-vectors of dimension dsub = D/M,
+/// and independently quantizes each sub-vector with ksub centroids.
+///
+/// Centroid layout: flat [M * ksub * dsub], row-major.
+/// centroids[m][j][d] is at index: m * ksub * dsub + j * dsub + d
+pub struct ProductQuantizer {
+ pub d: usize,
+ pub m: usize,
+ pub nbits: usize,
+ pub dsub: usize,
+ pub ksub: usize,
+ pub centroids: Vec<f32>,
+ /// Pre-computed squared norms of each centroid: [M * ksub].
+ /// Avoids recomputing per query for L2 distance table.
+ pub centroid_norms_cache: Vec<f32>,
+}
+
+impl ProductQuantizer {
+ pub fn new(d: usize, m: usize) -> Self {
+ Self::with_nbits(d, m, 8)
+ }
+
+ pub fn with_nbits(d: usize, m: usize, nbits: usize) -> Self {
+ assert!(
+ d.is_multiple_of(m),
+ "dimension {} must be divisible by m={}",
+ d,
+ m
+ );
+ assert!(
+ nbits == 4 || nbits == 8,
+ "nbits must be 4 or 8, got {}",
+ nbits
+ );
+ if nbits == 4 {
+ assert!(
+ m.is_multiple_of(2),
+ "m must be even for 4-bit PQ, got {}",
+ m
+ );
+ }
+ let dsub = d / m;
+ let ksub = 1 << nbits;
+ ProductQuantizer {
+ d,
+ m,
+ nbits,
+ dsub,
+ ksub,
+ centroids: Vec::new(),
+ centroid_norms_cache: Vec::new(),
+ }
+ }
+
+ /// Train the codebooks from training data.
+ /// data: flat [n * d], n training vectors.
+ pub fn train(&mut self, data: &[f32], n: usize) {
+ self.train_with_config(data, n, &KMeansConfig::default());
+ }
+
+ pub fn train_with_config(&mut self, data: &[f32], n: usize, km_config:
&KMeansConfig) {
+ self.train_hot_start(data, n, km_config, false);
+ }
+
+ /// Train with optional hot-start: reuse existing centroids as k-means
initial values.
+ /// Parallelizes across M sub-quantizers with rayon.
+ pub fn train_hot_start(
+ &mut self,
+ data: &[f32],
+ n: usize,
+ km_config: &KMeansConfig,
+ hot_start: bool,
+ ) {
+ let prev_centroids = if hot_start && !self.centroids.is_empty() {
+ Some(self.centroids.clone())
+ } else {
+ None
+ };
+
+ let m = self.m;
+ let d = self.d;
+ let dsub = self.dsub;
+ let ksub = self.ksub;
+
+ // Train all M sub-quantizers in parallel
+ let sub_results: Vec<Vec<f32>> = (0..m)
+ .into_par_iter()
+ .map(|sub| {
+ let offset = sub * dsub;
+
+ let mut sub_data = vec![0.0f32; n * dsub];
+ for i in 0..n {
+ sub_data[i * dsub..(i + 1) * dsub]
+ .copy_from_slice(&data[i * d + offset..i * d + offset
+ dsub]);
+ }
+
+ let init: Option<Vec<f32>> = prev_centroids.as_ref().map(|pc| {
+ let src = sub * ksub * dsub;
+ pc[src..src + ksub * dsub].to_vec()
+ });
+
+ kmeans::kmeans_train_with_init(km_config, &sub_data, n, dsub,
ksub, init.as_deref())
+ })
+ .collect();
+
+ self.centroids = vec![0.0f32; m * ksub * dsub];
+ for (sub, sub_centroids) in sub_results.into_iter().enumerate() {
+ let dst_offset = sub * ksub * dsub;
+ self.centroids[dst_offset..dst_offset + ksub *
dsub].copy_from_slice(&sub_centroids);
+ }
+ self.rebuild_norms_cache();
+ }
+
+ /// Rebuild the centroid norms cache. Called after training or loading
centroids.
+ pub fn rebuild_norms_cache(&mut self) {
+ self.centroid_norms_cache = vec![0.0f32; self.m * self.ksub];
+ for sub in 0..self.m {
+ let c_base = sub * self.ksub * self.dsub;
+ for j in 0..self.ksub {
+ let c_off = c_base + j * self.dsub;
+ self.centroid_norms_cache[sub * self.ksub + j] =
+ fvec_norm_l2sqr(&self.centroids[c_off..c_off + self.dsub]);
+ }
+ }
+ }
+
+ /// Bytes per encoded vector.
+ pub fn code_size(&self) -> usize {
+ if self.nbits == 4 {
+ self.m / 2
+ } else {
+ self.m
+ }
+ }
+
+ /// Encode a single vector into PQ codes.
+ /// For nbits=8: codes has length M (one byte per sub-quantizer).
+ /// For nbits=4: codes has length M/2 (two nibbles per byte).
+ pub fn encode(&self, x: &[f32], codes: &mut [u8]) {
+ if self.nbits == 4 {
+ self.encode_4bit(x, codes);
+ } else {
+ self.encode_8bit(x, codes);
+ }
+ }
+
+ fn encode_8bit(&self, x: &[f32], codes: &mut [u8]) {
+ for sub in 0..self.m {
+ let x_off = sub * self.dsub;
+ let c_base = sub * self.ksub * self.dsub;
+
+ let mut best = 0u8;
+ let mut best_dist = f32::MAX;
+ for j in 0..self.ksub {
+ let c_off = c_base + j * self.dsub;
+ let dist = fvec_l2sqr_sub(x, x_off, &self.centroids, c_off,
self.dsub);
+ if dist < best_dist {
+ best_dist = dist;
+ best = j as u8;
+ }
+ }
+ codes[sub] = best;
+ }
+ }
+
+ fn encode_4bit(&self, x: &[f32], codes: &mut [u8]) {
+ for pair in 0..self.m / 2 {
+ let sub_lo = pair * 2;
+ let sub_hi = pair * 2 + 1;
+
+ let mut best_lo = 0u8;
+ let mut best_dist_lo = f32::MAX;
+ let x_off_lo = sub_lo * self.dsub;
+ let c_base_lo = sub_lo * self.ksub * self.dsub;
+ for j in 0..self.ksub {
+ let dist = fvec_l2sqr_sub(
+ x,
+ x_off_lo,
+ &self.centroids,
+ c_base_lo + j * self.dsub,
+ self.dsub,
+ );
+ if dist < best_dist_lo {
+ best_dist_lo = dist;
+ best_lo = j as u8;
+ }
+ }
+
+ let mut best_hi = 0u8;
+ let mut best_dist_hi = f32::MAX;
+ let x_off_hi = sub_hi * self.dsub;
+ let c_base_hi = sub_hi * self.ksub * self.dsub;
+ for j in 0..self.ksub {
+ let dist = fvec_l2sqr_sub(
+ x,
+ x_off_hi,
+ &self.centroids,
+ c_base_hi + j * self.dsub,
+ self.dsub,
+ );
+ if dist < best_dist_hi {
+ best_dist_hi = dist;
+ best_hi = j as u8;
+ }
+ }
+
+ // Pack: low nibble + high nibble
+ codes[pair] = best_lo | (best_hi << 4);
+ }
+ }
+
+ /// Encode multiple vectors in parallel.
+ pub fn encode_batch(&self, data: &[f32], n: usize, codes: &mut [u8]) {
+ let d = self.d;
+ let cs = self.code_size();
+
+ codes
+ .par_chunks_mut(cs)
+ .enumerate()
+ .for_each(|(i, code_chunk)| {
+ if i < n {
+ self.encode(&data[i * d..(i + 1) * d], code_chunk);
+ }
+ });
+ }
+
+ /// Decode PQ codes back to an approximate vector.
+ pub fn decode(&self, codes: &[u8], x: &mut [f32]) {
+ if self.nbits == 4 {
+ for pair in 0..self.m / 2 {
+ let byte = codes[pair];
+ let code_lo = (byte & 0x0F) as usize;
+ let code_hi = ((byte >> 4) & 0x0F) as usize;
+
+ let sub_lo = pair * 2;
+ let sub_hi = pair * 2 + 1;
+
+ let c_off_lo = sub_lo * self.ksub * self.dsub + code_lo *
self.dsub;
+ let x_off_lo = sub_lo * self.dsub;
+ x[x_off_lo..x_off_lo + self.dsub]
+ .copy_from_slice(&self.centroids[c_off_lo..c_off_lo +
self.dsub]);
+
+ let c_off_hi = sub_hi * self.ksub * self.dsub + code_hi *
self.dsub;
+ let x_off_hi = sub_hi * self.dsub;
+ x[x_off_hi..x_off_hi + self.dsub]
+ .copy_from_slice(&self.centroids[c_off_hi..c_off_hi +
self.dsub]);
+ }
+ } else {
+ for sub in 0..self.m {
+ let c_off = sub * self.ksub * self.dsub + (codes[sub] as
usize) * self.dsub;
+ let x_off = sub * self.dsub;
+ x[x_off..x_off + self.dsub]
+ .copy_from_slice(&self.centroids[c_off..c_off +
self.dsub]);
+ }
+ }
+ }
+
+ /// Precompute the distance table from a query to all PQ centroids.
+ /// Uses sgemm for dsub >= 4 (L2: ||q-c||²=||q||²+||c||²-2q·cᵀ).
+ pub fn compute_distance_table(&self, query: &[f32], metric: MetricType,
table: &mut [f32]) {
+ if self.dsub >= 4 {
+ self.compute_distance_table_sgemm(query, metric, table);
+ } else {
+ self.compute_distance_table_loop(query, metric, table);
+ }
+ }
+
+ fn compute_distance_table_sgemm(&self, query: &[f32], metric: MetricType,
table: &mut [f32]) {
+ for sub in 0..self.m {
+ let q_off = sub * self.dsub;
+ let c_base = sub * self.ksub * self.dsub;
+ let t_base = sub * self.ksub;
+
+ // Inner product: ip[ksub] = query_sub[1×dsub] ·
centroids_sub[ksub×dsub]ᵀ
+ sgemm_a_bt(
+ 1,
+ self.ksub,
+ self.dsub,
+ 1.0,
+ &query[q_off..q_off + self.dsub],
+ &self.centroids[c_base..c_base + self.ksub * self.dsub],
+ 0.0,
+ &mut table[t_base..t_base + self.ksub],
+ );
+
+ match metric {
+ MetricType::L2 | MetricType::Cosine => {
+ // ||q-c||² = ||q||² + ||c||² - 2·q·c
+ // Use pre-cached centroid norms (avoids recomputing per
query)
+ let q_norm = fvec_norm_l2sqr(&query[q_off..q_off +
self.dsub]);
+ let norms_base = sub * self.ksub;
+ for j in 0..self.ksub {
+ let c_norm = if !self.centroid_norms_cache.is_empty() {
+ self.centroid_norms_cache[norms_base + j]
+ } else {
+ let c_off = c_base + j * self.dsub;
+ fvec_norm_l2sqr(&self.centroids[c_off..c_off +
self.dsub])
+ };
+ table[t_base + j] = q_norm + c_norm - 2.0 *
table[t_base + j];
+ }
+ }
+ MetricType::InnerProduct => {
+ for j in 0..self.ksub {
+ table[t_base + j] = -table[t_base + j];
+ }
+ }
+ }
+ }
+ }
+
+    /// Fills `table` one sub-quantizer at a time with the batch distance
+    /// kernels (no sgemm).
+    ///
+    /// `MetricType::L2` and `MetricType::Cosine` both store the output of
+    /// `fvec_l2sqr_batch`; `MetricType::InnerProduct` stores the negated
+    /// output of `fvec_ip_batch`.
+    fn compute_distance_table_loop(
+        &self,
+        query: &[f32],
+        metric: MetricType,
+        table: &mut [f32],
+    ) {
+        let (dsub, ksub) = (self.dsub, self.ksub);
+
+        for sub in 0..self.m {
+            let q_start = sub * dsub;
+            let c_start = sub * ksub * dsub;
+            let t_start = sub * ksub;
+
+            let q_sub = &query[q_start..q_start + dsub];
+            let c_sub = &self.centroids[c_start..c_start + ksub * dsub];
+            let out = &mut table[t_start..t_start + ksub];
+
+            match metric {
+                MetricType::InnerProduct => {
+                    fvec_ip_batch(q_sub, c_sub, dsub, ksub, &mut *out);
+                    // Flip the sign of every inner product just written.
+                    out.iter_mut().for_each(|v| *v = -*v);
+                }
+                MetricType::L2 | MetricType::Cosine => {
+                    fvec_l2sqr_batch(q_sub, c_sub, dsub, ksub, out);
+                }
+            }
+        }
+    }
+
+    /// Compute the inner product table:
+    /// `table[sub * ksub + j] = <query_sub, centroid_sub_j>`.
+    ///
+    /// Unlike the distance table, the values are stored without negation.
+    pub fn compute_inner_product_table(
+        &self,
+        query: &[f32],
+        table: &mut [f32],
+    ) {
+        let (dsub, ksub) = (self.dsub, self.ksub);
+        let sub_centroids_len = ksub * dsub;
+
+        for sub in 0..self.m {
+            let q_start = sub * dsub;
+            let c_start = sub * ksub * dsub;
+            let t_start = sub * ksub;
+
+            let q_sub = &query[q_start..q_start + dsub];
+            let c_sub =
+                &self.centroids[c_start..c_start + sub_centroids_len];
+            let out = &mut table[t_start..t_start + ksub];
+
+            fvec_ip_batch(q_sub, c_sub, dsub, ksub, out);
+        }
+    }
+
+    /// Looks up the approximate distance of an encoded vector in a
+    /// precomputed distance table.
+    ///
+    /// 4-bit quantizers take the nibble-unpacking path; every other bit
+    /// width is delegated to `pq_distance_from_table`.
+    #[inline]
+    pub fn distance_from_table(&self, table: &[f32], codes: &[u8]) -> f32 {
+        match self.nbits {
+            4 => self.distance_from_table_4bit(table, codes),
+            _ => pq_distance_from_table(table, codes, self.m, self.ksub),
+        }
+    }
+
+    /// 4-bit PQ distance: each byte of `codes` packs two codes, the low
+    /// nibble for the even sub-quantizer and the high nibble for the odd
+    /// one that follows it.
+    ///
+    /// Only `m / 2` bytes are read, so with an odd `m` the last
+    /// sub-quantizer is not visited.
+    #[inline]
+    fn distance_from_table_4bit(
+        &self,
+        table: &[f32],
+        codes: &[u8],
+    ) -> f32 {
+        let ksub = self.ksub;
+        (0..self.m / 2).fold(0.0f32, |acc, pair| {
+            let packed = codes[pair];
+            let lo = usize::from(packed & 0x0F);
+            let hi = usize::from(packed >> 4);
+
+            // Table row of the even sub-quantizer; the odd one starts
+            // `ksub` entries later.
+            let even_row = 2 * pair * ksub;
+            acc + table[even_row + lo] + table[even_row + ksub + hi]
+        })
+    }
+
+    /// Returns the squared norm of every PQ centroid, laid out as
+    /// `norms[sub * ksub + j]`.
+    ///
+    /// A non-empty cache is returned as a clone; otherwise the norms are
+    /// computed from the centroids.
+    pub fn compute_centroid_norms(&self) -> Vec<f32> {
+        if self.centroid_norms_cache.is_empty() {
+            let dsub = self.dsub;
+            // Flat index `sub * ksub + j` maps to centroid offset
+            // `(sub * ksub + j) * dsub`.
+            (0..self.m * self.ksub)
+                .map(|idx| {
+                    let start = idx * dsub;
+                    fvec_norm_l2sqr(&self.centroids[start..start + dsub])
+                })
+                .collect()
+        } else {
+            self.centroid_norms_cache.clone()
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rand::rngs::StdRng;
+    use rand::{Rng, SeedableRng};
+
+    /// Generates `n * d` random `f32` values from an RNG seeded with 42,
+    /// so every test works on the same reproducible data.
+    fn seeded_vectors(n: usize, d: usize) -> Vec<f32> {
+        let mut rng = StdRng::seed_from_u64(42);
+        (0..n * d).map(|_| rng.gen::<f32>()).collect()
+    }
+
+    #[test]
+    fn test_encode_decode_roundtrip() {
+        let (d, m, n) = (8, 2, 100);
+        let data = seeded_vectors(n, d);
+
+        let mut pq = ProductQuantizer::new(d, m);
+        pq.train(&data, n);
+
+        let original = &data[..d];
+        let mut codes = vec![0u8; m];
+        pq.encode(original, &mut codes);
+
+        let mut decoded = vec![0.0f32; d];
+        pq.decode(&codes, &mut decoded);
+
+        // PQ is lossy: only require a bounded reconstruction error.
+        let error = fvec_l2sqr_sub(original, 0, &decoded, 0, d);
+        assert!(error < 10.0);
+    }
+
+    #[test]
+    fn test_distance_table() {
+        let (d, m, n) = (8, 2, 100);
+        let data = seeded_vectors(n, d);
+
+        let mut pq = ProductQuantizer::new(d, m);
+        pq.train(&data, n);
+
+        // Use the first training vector as the query.
+        let query = &data[..d];
+        let mut table = vec![0.0f32; m * pq.ksub];
+        pq.compute_distance_table(query, MetricType::L2, &mut table);
+
+        let mut codes = vec![0u8; m];
+        pq.encode(query, &mut codes);
+
+        let dist = pq.distance_from_table(&table, &codes);
+        assert!(dist >= 0.0);
+    }
+
+    #[test]
+    fn test_4bit_encode_decode() {
+        // m must be even for 4-bit codes.
+        let (d, m, n) = (8, 4, 200);
+        let data = seeded_vectors(n, d);
+
+        let mut pq = ProductQuantizer::with_nbits(d, m, 4);
+        assert_eq!(pq.ksub, 16);
+        // m / 2 = 2 bytes per vector.
+        assert_eq!(pq.code_size(), 2);
+
+        pq.train(&data, n);
+
+        let original = &data[..d];
+        let mut codes = vec![0u8; pq.code_size()];
+        pq.encode(original, &mut codes);
+
+        // The codes must not be all zeros.
+        assert!(codes.iter().any(|&b| b != 0));
+
+        let mut decoded = vec![0.0f32; d];
+        pq.decode(&codes, &mut decoded);
+
+        // 4-bit codes are allowed a larger error than 8-bit ones.
+        let error = fvec_l2sqr_sub(original, 0, &decoded, 0, d);
+        assert!(error < 20.0);
+
+        // The distance-table lookup must work on the packed codes too.
+        let mut table = vec![0.0f32; m * pq.ksub];
+        pq.compute_distance_table(original, MetricType::L2, &mut table);
+        let dist = pq.distance_from_table(&table, &codes);
+        assert!(dist >= 0.0);
+    }
+
+    #[test]
+    fn test_4bit_batch_encode() {
+        let (d, m, n) = (16, 8, 100);
+        let data = seeded_vectors(n, d);
+
+        let mut pq = ProductQuantizer::with_nbits(d, m, 4);
+        pq.train(&data, n);
+
+        // m / 2 = 4 bytes per vector.
+        let cs = pq.code_size();
+        let mut codes = vec![0u8; n * cs];
+        pq.encode_batch(&data, n, &mut codes);
+
+        // The codes must not be all zeros.
+        assert!(codes.iter().any(|&b| b != 0));
+    }
+}