This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new 9612a72 Add IVF_FLAT vector index (#18)
9612a72 is described below
commit 9612a724b195b3918411b91154f3f3ed59a2c6bb
Author: Jingsong Lee <[email protected]>
AuthorDate: Tue Jun 9 12:00:42 2026 +0800
Add IVF_FLAT vector index (#18)
---
core/Cargo.toml | 4 +
core/benches/recall_bench.rs | 178 ++++
core/src/distance.rs | 45 +-
core/src/ivfflat.rs | 277 ++++++
core/src/ivfflat_io.rs | 1047 ++++++++++++++++++++
core/src/lib.rs | 2 +
.../paimon/index/ivfpq/IVFPQJavaApiTest.java | 26 +
.../apache/paimon/index/ivfpq/IVFFlatNative.java | 52 +
.../apache/paimon/index/ivfpq/IVFFlatReader.java | 115 +++
.../apache/paimon/index/ivfpq/IVFFlatWriter.java | 108 ++
jni/src/lib.rs | 433 ++++++++
python/src/lib.rs | 351 +++++++
12 files changed, 2623 insertions(+), 15 deletions(-)
diff --git a/core/Cargo.toml b/core/Cargo.toml
index 0b51ead..af0d65e 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -34,3 +34,7 @@ criterion = "0.5"
[[bench]]
name = "pq4_bench"
harness = false
+
+[[bench]]
+name = "recall_bench"
+harness = false
diff --git a/core/benches/recall_bench.rs b/core/benches/recall_bench.rs
new file mode 100644
index 0000000..03415be
--- /dev/null
+++ b/core/benches/recall_bench.rs
@@ -0,0 +1,178 @@
+use paimon_vindex_core::distance::{fvec_distance, MetricType};
+use paimon_vindex_core::ivfflat::IVFFlatIndex;
+use paimon_vindex_core::ivfpq::IVFPQIndex;
+use std::collections::HashSet;
+use std::time::Instant;
+
+fn main() {
+ run_scenario(Scenario {
+ name: "small-lists",
+ d: 64,
+ n: 20_000,
+ nq: 50,
+ k: 10,
+ nlist: 64,
+ pq_m: 8,
+ nprobes: &[1, 4, 8, 16, 32, 64],
+ });
+
+ println!();
+
+ run_scenario(Scenario {
+ name: "large-lists",
+ d: 64,
+ n: 50_000,
+ nq: 50,
+ k: 10,
+ nlist: 8,
+ pq_m: 8,
+ nprobes: &[1, 2, 4, 8],
+ });
+}
+
/// Parameters for one benchmark run.
struct Scenario<'s> {
    /// Label printed in the scenario banner.
    name: &'s str,
    /// Vector dimension.
    d: usize,
    /// Number of database vectors.
    n: usize,
    /// Number of queries (`run_scenario` uses the first `nq` database vectors).
    nq: usize,
    /// Neighbours retrieved per query.
    k: usize,
    /// Number of IVF inverted lists.
    nlist: usize,
    /// PQ `m` parameter passed to `IVFPQIndex::new`.
    pq_m: usize,
    /// `nprobe` values swept for both indexes.
    nprobes: &'s [usize],
}
+
+fn run_scenario(s: Scenario<'_>) {
+ println!("=== IVF Recall Attribution Benchmark ===");
+ println!(
+ "scenario: {}, n={}, nq={}, d={}, nlist={}, avg_list={}, k={},
metric=L2",
+ s.name,
+ s.n,
+ s.nq,
+ s.d,
+ s.nlist,
+ s.n / s.nlist,
+ s.k
+ );
+
+ let data = generate_clustered_data(s.n, s.d, 32, 42);
+ let ids: Vec<i64> = (0..s.n as i64).collect();
+ let queries = &data[..s.nq * s.d];
+
+ let start = Instant::now();
+ let ground_truth = brute_force_ground_truth(&data, queries, s.n, s.nq,
s.d, s.k);
+ println!("ground truth: {:.2}s", start.elapsed().as_secs_f64());
+
+ let start = Instant::now();
+ let mut ivfpq = IVFPQIndex::new(s.d, s.nlist, s.pq_m, MetricType::L2,
false);
+ ivfpq.train(&data, s.n);
+ ivfpq.add(&data, &ids, s.n);
+ ivfpq.build_precomputed_table();
+ println!("build IVF-PQ: {:.2}s", start.elapsed().as_secs_f64());
+
+ let start = Instant::now();
+ let mut ivfflat = IVFFlatIndex::new(s.d, s.nlist, MetricType::L2);
+ ivfflat.train(&data, s.n);
+ ivfflat.add(&data, &ids, s.n);
+ println!("build IVF-FLAT: {:.2}s", start.elapsed().as_secs_f64());
+
+ println!();
+ println!("index nprobe recall@{} query_ms us/query", s.k);
+ println!("--------- ------ --------- -------- --------");
+
+ for &nprobe in s.nprobes {
+ let mut distances = vec![0.0f32; s.nq * s.k];
+ let mut labels = vec![0i64; s.nq * s.k];
+ let start = Instant::now();
+ ivfpq.search(queries, s.nq, s.k, nprobe, &mut distances, &mut labels);
+ let elapsed = start.elapsed();
+ print_row(
+ "IVF-PQ",
+ nprobe,
+ recall_at_k(&labels, &ground_truth, s.nq, s.k),
+ elapsed,
+ s.nq,
+ );
+
+ let mut distances = vec![0.0f32; s.nq * s.k];
+ let mut labels = vec![0i64; s.nq * s.k];
+ let start = Instant::now();
+ ivfflat.search(queries, s.nq, s.k, nprobe, &mut distances, &mut
labels);
+ let elapsed = start.elapsed();
+ print_row(
+ "IVF-FLAT",
+ nprobe,
+ recall_at_k(&labels, &ground_truth, s.nq, s.k),
+ elapsed,
+ s.nq,
+ );
+ }
+}
+
/// Prints one aligned result row: index name, `nprobe`, recall as a
/// percentage, total query time in milliseconds, and mean microseconds per query.
fn print_row(index: &str, nprobe: usize, recall: f64, elapsed: std::time::Duration, nq: usize) {
    let total_ms = elapsed.as_secs_f64() * 1000.0;
    let per_query_us = total_ms * 1000.0 / nq as f64;
    let recall_pct = recall * 100.0;
    println!(
        "{:<9} {:>6} {:>8.2}% {:>8.2} {:>8.1}",
        index, nprobe, recall_pct, total_ms, per_query_us
    );
}
+
/// Recall@k: the fraction of the `nq * k` returned labels that appear in their
/// query's ground-truth set.
///
/// `labels` is row-major `nq x k`. Labels absent from the ground truth, such
/// as `-1` padding, count as misses. Returns `0.0` when `nq * k == 0` instead
/// of dividing zero by zero.
fn recall_at_k(labels: &[i64], ground_truth: &[Vec<i64>], nq: usize, k: usize) -> f64 {
    let total = nq * k;
    if total == 0 {
        return 0.0;
    }
    let mut hits = 0usize;
    for (qi, gt_row) in ground_truth[..nq].iter().enumerate() {
        let gt: HashSet<i64> = gt_row.iter().copied().collect();
        // `filter` hands the closure `&&i64`; destructure to a plain `i64` so
        // `contains` gets `&i64` (passing `&&i64` needs `i64: Borrow<&i64>`).
        hits += labels[qi * k..(qi + 1) * k]
            .iter()
            .filter(|&&id| gt.contains(&id))
            .count();
    }
    hits as f64 / total as f64
}
+
+fn brute_force_ground_truth(
+ data: &[f32],
+ queries: &[f32],
+ n: usize,
+ nq: usize,
+ d: usize,
+ k: usize,
+) -> Vec<Vec<i64>> {
+ (0..nq)
+ .map(|qi| {
+ let query = &queries[qi * d..(qi + 1) * d];
+ let mut distances: Vec<(f32, i64)> = (0..n)
+ .map(|i| {
+ let vector = &data[i * d..(i + 1) * d];
+ (fvec_distance(query, vector, MetricType::L2), i as i64)
+ })
+ .collect();
+ distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
+ distances[..k].iter().map(|&(_, id)| id).collect()
+ })
+ .collect()
+}
+
/// Deterministically generates `n` row-major vectors of dimension `d` grouped
/// around `num_clusters` random centres.
///
/// Centre coordinates are roughly uniform in `[-30, 30]`. Each point adds
/// roughly uniform noise in `[-1, 1]`, and point `i` belongs to cluster
/// `i % num_clusters`. Randomness comes from a 64-bit LCG seeded with `seed`.
fn generate_clustered_data(n: usize, d: usize, num_clusters: usize, seed: u64) -> Vec<f32> {
    let mut rng_state = seed;
    // `rng_state >> 33` keeps the top 31 bits, a value in [0, 2^31). Scale by
    // 2^31 (the original divided by u32::MAX, which only reached [0, 0.5) and
    // therefore put every sample in [-1, 0) after `* 2 - 1`).
    let mut next = || {
        rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
        ((rng_state >> 33) as f32) / ((1u64 << 31) as f32) * 2.0 - 1.0
    };

    let mut centers = vec![0.0f32; num_clusters * d];
    for value in &mut centers {
        *value = next() * 30.0;
    }

    let mut data = vec![0.0f32; n * d];
    for i in 0..n {
        let cluster = i % num_clusters;
        for j in 0..d {
            data[i * d + j] = centers[cluster * d + j] + next();
        }
    }
    data
}
diff --git a/core/src/distance.rs b/core/src/distance.rs
index f6174e1..6c86d3c 100644
--- a/core/src/distance.rs
+++ b/core/src/distance.rs
@@ -86,6 +86,25 @@ pub fn fvec_normalize(v: &mut [f32]) -> f32 {
norm
}
+/// Distance used for ranking search results. Lower is better for all metrics.
+pub fn fvec_distance(query: &[f32], vector: &[f32], metric: MetricType) -> f32
{
+ match metric {
+ MetricType::L2 => fvec_l2sqr(query, vector),
+ MetricType::InnerProduct => -fvec_inner_product(query, vector),
+ MetricType::Cosine => {
+ let dot = fvec_inner_product(query, vector);
+ let nq = fvec_norm_l2sqr(query).sqrt();
+ let nv = fvec_norm_l2sqr(vector).sqrt();
+ let denom = nq * nv;
+ if denom > 0.0 {
+ 1.0 - dot / denom
+ } else {
+ 1.0
+ }
+ }
+ }
+}
+
/// Compute result[i] = a[i] + bf * b[i]. Used for precomputed table merging.
/// Aligned with Faiss's fvec_madd.
pub fn fvec_madd(a: &[f32], b: &[f32], bf: f32, result: &mut [f32]) {
@@ -415,21 +434,7 @@ pub fn fvec_distances_batch(
) {
for i in 0..n {
let vec = &vectors[i * d..(i + 1) * d];
- distances[i] = match metric {
- MetricType::L2 => fvec_l2sqr(query, vec),
- MetricType::InnerProduct => -fvec_inner_product(query, vec),
- MetricType::Cosine => {
- let dot = fvec_inner_product(query, vec);
- let na = fvec_norm_l2sqr(query).sqrt();
- let nb = fvec_norm_l2sqr(vec).sqrt();
- let denom = na * nb;
- if denom > 0.0 {
- 1.0 - dot / denom
- } else {
- 1.0
- }
- }
- };
+ distances[i] = fvec_distance(query, vec, metric);
}
}
@@ -451,6 +456,16 @@ mod tests {
assert!((fvec_inner_product(&a, &b) - 32.0).abs() < 1e-6);
}
+ #[test]
+ fn test_fvec_distance_by_metric() {
+ let a = [1.0, 0.0];
+ let b = [0.0, 1.0];
+
+ assert!((fvec_distance(&a, &b, MetricType::L2) - 2.0).abs() < 1e-6);
+ assert!((fvec_distance(&a, &b, MetricType::InnerProduct) - 0.0).abs()
< 1e-6);
+ assert!((fvec_distance(&a, &b, MetricType::Cosine) - 1.0).abs() <
1e-6);
+ }
+
#[test]
fn test_normalize() {
let mut v = [3.0, 4.0];
diff --git a/core/src/ivfflat.rs b/core/src/ivfflat.rs
new file mode 100644
index 0000000..185bdf1
--- /dev/null
+++ b/core/src/ivfflat.rs
@@ -0,0 +1,277 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{fvec_distance, fvec_normalize, MetricType};
+use crate::ivfpq::RowIdFilter;
+use crate::kmeans::{self, KMeansConfig};
+
+/// IVF-FLAT index. Stores raw vectors in each IVF list for exact per-list
scan.
+pub struct IVFFlatIndex {
+ pub d: usize,
+ pub nlist: usize,
+ pub metric: MetricType,
+ pub quantizer_centroids: Vec<f32>,
+ pub ids: Vec<Vec<i64>>,
+ pub vectors: Vec<Vec<f32>>,
+}
+
+impl IVFFlatIndex {
+ pub fn new(d: usize, nlist: usize, metric: MetricType) -> Self {
+ IVFFlatIndex {
+ d,
+ nlist,
+ metric,
+ quantizer_centroids: Vec::new(),
+ ids: vec![Vec::new(); nlist],
+ vectors: vec![Vec::new(); nlist],
+ }
+ }
+
+ pub fn train(&mut self, data: &[f32], n: usize) {
+ let train_data = self.preprocess_vectors(data, n);
+ self.quantizer_centroids =
+ kmeans::kmeans_train(&KMeansConfig::default(), &train_data, n,
self.d, self.nlist);
+ }
+
+ pub fn add(&mut self, data: &[f32], ids: &[i64], n: usize) {
+ let processed = self.preprocess_vectors(data, n);
+ for i in 0..n {
+ let vector = &processed[i * self.d..(i + 1) * self.d];
+ let list_id =
+ kmeans::find_nearest(vector, &self.quantizer_centroids,
self.nlist, self.d);
+ self.ids[list_id].push(ids[i]);
+ self.vectors[list_id].extend_from_slice(vector);
+ }
+ }
+
+ pub fn total_vectors(&self) -> usize {
+ self.ids.iter().map(Vec::len).sum()
+ }
+
+ pub fn search(
+ &self,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+ result_distances: &mut [f32],
+ result_labels: &mut [i64],
+ ) {
+ self.search_with_filter(
+ queries,
+ nq,
+ k,
+ nprobe,
+ None,
+ result_distances,
+ result_labels,
+ );
+ }
+
+ pub fn search_with_filter(
+ &self,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+ filter: Option<&dyn RowIdFilter>,
+ result_distances: &mut [f32],
+ result_labels: &mut [i64],
+ ) {
+ let processed_queries = self.preprocess_vectors(queries, nq);
+ let (all_probe_indices, _) = kmeans::find_topk_batch(
+ &processed_queries,
+ nq,
+ &self.quantizer_centroids,
+ self.nlist,
+ self.d,
+ nprobe,
+ );
+
+ for qi in 0..nq {
+ let query = &processed_queries[qi * self.d..(qi + 1) * self.d];
+ let mut heap = FlatTopKHeap::new(k);
+
+ for &list_id in &all_probe_indices[qi] {
+ let ids = &self.ids[list_id];
+ let vectors = &self.vectors[list_id];
+ for (local_idx, &id) in ids.iter().enumerate() {
+ if let Some(f) = filter {
+ if !f.contains(id) {
+ continue;
+ }
+ }
+ let vector = &vectors[local_idx * self.d..(local_idx + 1)
* self.d];
+ heap.push(fvec_distance(query, vector, self.metric), id);
+ }
+ }
+
+ let sorted = heap.into_sorted();
+ let out_base = qi * k;
+ for (i, &(dist, id)) in sorted.iter().enumerate() {
+ result_distances[out_base + i] = dist;
+ result_labels[out_base + i] = id;
+ }
+ for i in sorted.len()..k {
+ result_distances[out_base + i] = f32::MAX;
+ result_labels[out_base + i] = -1;
+ }
+ }
+ }
+
+ pub(crate) fn preprocess_vectors(&self, data: &[f32], n: usize) ->
Vec<f32> {
+ let mut processed = data[..n * self.d].to_vec();
+ if self.metric == MetricType::Cosine {
+ for i in 0..n {
+ fvec_normalize(&mut processed[i * self.d..(i + 1) * self.d]);
+ }
+ }
+ processed
+ }
+}
+
/// Bounded collector that keeps the `k` hits with the smallest distance.
///
/// Backed by a max-heap whose root is the worst hit currently kept, so each
/// `push` costs O(log k) instead of a linear O(k) scan for the maximum.
/// Distances are compared with `f32::total_cmp`, so a NaN distance no longer
/// panics; a (positive) NaN ranks after every finite distance and is evicted
/// first.
struct FlatTopKHeap {
    k: usize,
    heap: std::collections::BinaryHeap<FlatHit>,
}

/// A `(distance, id)` hit totally ordered by distance (`total_cmp`), then by
/// id, as `BinaryHeap` requires.
#[derive(Clone, Copy, Debug)]
struct FlatHit(f32, i64);

impl PartialEq for FlatHit {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for FlatHit {}

impl PartialOrd for FlatHit {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for FlatHit {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0
            .total_cmp(&other.0)
            .then_with(|| self.1.cmp(&other.1))
    }
}

impl FlatTopKHeap {
    fn new(k: usize) -> Self {
        Self {
            k,
            heap: std::collections::BinaryHeap::with_capacity(k),
        }
    }

    /// Offers a hit. While fewer than `k` hits are held it is always kept.
    /// After that it replaces the current worst hit only when its distance is
    /// strictly smaller, so on ties the earlier hit stays.
    fn push(&mut self, dist: f32, id: i64) {
        if self.k == 0 {
            return;
        }
        if self.heap.len() < self.k {
            self.heap.push(FlatHit(dist, id));
            return;
        }
        if let Some(mut worst) = self.heap.peek_mut() {
            if dist.total_cmp(&worst.0).is_lt() {
                *worst = FlatHit(dist, id);
            }
        }
    }

    /// Consumes the collector and returns hits by ascending distance; equal
    /// distances are ordered by ascending id.
    fn into_sorted(self) -> Vec<(f32, i64)> {
        self.heap
            .into_sorted_vec()
            .into_iter()
            .map(|FlatHit(dist, id)| (dist, id))
            .collect()
    }
}
+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::MetricType;

    /// Builds an L2 index over row-major `data`, training on and adding every
    /// row (one row per entry in `ids`).
    fn build_l2_index(data: &[f32], ids: &[i64], d: usize, nlist: usize) -> IVFFlatIndex {
        let rows = ids.len();
        let mut index = IVFFlatIndex::new(d, nlist, MetricType::L2);
        index.train(data, rows);
        index.add(data, ids, rows);
        index
    }

    #[test]
    fn test_ivfflat_add_assigns_all_vectors() {
        let (d, nlist, n) = (4, 2, 16);
        let data: Vec<f32> = (0..n)
            .flat_map(|i| [i as f32, 0.0, i as f32 + 0.5, 1.0])
            .collect();
        let ids: Vec<i64> = (0..n as i64).collect();

        let index = build_l2_index(&data, &ids, d, nlist);

        assert_eq!(index.total_vectors(), n);
        for (list_ids, list_vectors) in index.ids.iter().zip(&index.vectors) {
            assert_eq!(list_vectors.len(), list_ids.len() * d);
        }
    }

    #[test]
    fn test_ivfflat_recalls_query_vector() {
        let (d, nlist, n) = (4, 4, 64);
        let data: Vec<f32> = (0..n)
            .flat_map(|i| {
                let offset = (i % nlist) as f32 * 100.0;
                [offset + i as f32 * 0.01, 1.0, 2.0, 3.0]
            })
            .collect();
        let ids: Vec<i64> = (1000..1000 + n as i64).collect();

        let index = build_l2_index(&data, &ids, d, nlist);

        let query_id = 7;
        let top = 5;
        let mut distances = vec![0.0; top];
        let mut labels = vec![0; top];
        index.search(
            &data[query_id * d..(query_id + 1) * d],
            1,
            top,
            nlist,
            &mut distances,
            &mut labels,
        );

        // The query is itself in the index, so it must come back first at distance 0.
        assert_eq!(labels[0], ids[query_id]);
        assert_eq!(distances[0], 0.0);
        assert!(distances.windows(2).all(|pair| pair[1] >= pair[0]));
    }

    #[test]
    fn test_ivfflat_search_with_filter() {
        use std::collections::HashSet;

        let data = vec![0.0, 0.0, 0.1, 0.0, 10.0, 10.0];
        let ids = vec![10, 11, 12];
        let index = build_l2_index(&data, &ids, 2, 1);

        // Only the far point (id 12) passes, so the second slot stays padded.
        let allowed: HashSet<i64> = std::iter::once(12).collect();
        let mut distances = vec![0.0; 2];
        let mut labels = vec![0; 2];
        index.search_with_filter(
            &[0.0, 0.0],
            1,
            2,
            1,
            Some(&allowed),
            &mut distances,
            &mut labels,
        );

        assert_eq!(labels, vec![12, -1]);
        assert!(distances[0] > 0.0);
        assert_eq!(distances[1], f32::MAX);
    }
}
diff --git a/core/src/ivfflat_io.rs b/core/src/ivfflat_io.rs
new file mode 100644
index 0000000..da398c4
--- /dev/null
+++ b/core/src/ivfflat_io.rs
@@ -0,0 +1,1047 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::distance::{fvec_distance, fvec_normalize, MetricType};
+use crate::io::{SeekRead, SeekWrite};
+use crate::ivfflat::IVFFlatIndex;
+use crate::ivfpq::RowIdFilter;
+use crate::kmeans;
+use roaring::RoaringTreemap;
+use std::io;
+
/// File magic: the ASCII bytes "IVFL" read as a big-endian `u32`. It is
/// written little-endian, so the first four bytes on disk are "LFVI".
pub const IVFFLAT_MAGIC: u32 = 0x4956464C;
/// On-disk format version this module reads and writes.
pub const IVFFLAT_VERSION: u32 = 1;
/// Size in bytes of the fixed header, including its 32 reserved bytes.
pub const IVFFLAT_HEADER_SIZE: usize = 64;

/// Header flag: each list stores a base ID followed by delta-varint IDs.
const FLAG_DELTA_IDS: u32 = 1 << 0;
+
+pub fn write_ivfflat_index(index: &IVFFlatIndex, out: &mut dyn SeekWrite) ->
io::Result<()> {
+ let d = index.d;
+ let nlist = index.nlist;
+ validate_index_shape(index)?;
+ let d_i32 = usize_to_i32(d, "dimension")?;
+ let nlist_i32 = usize_to_i32(nlist, "nlist")?;
+ let total_vectors = index.ids.iter().try_fold(0i64, |sum, ids| {
+ let count = usize_to_i64(ids.len(), "total vector count")?;
+ sum.checked_add(count).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "total vector count exceeds i64 length limit",
+ )
+ })
+ })?;
+
+ let mut sorted_lists: Vec<(Vec<i64>, Vec<u8>, Vec<f32>)> =
Vec::with_capacity(nlist);
+ for list_id in 0..nlist {
+ let count = index.ids[list_id].len();
+ if count == 0 {
+ sorted_lists.push((Vec::new(), Vec::new(), Vec::new()));
+ continue;
+ }
+
+ let mut order: Vec<usize> = (0..count).collect();
+ order.sort_by_key(|&idx| index.ids[list_id][idx]);
+
+ let sorted_ids: Vec<i64> = order.iter().map(|&idx|
index.ids[list_id][idx]).collect();
+ let mut sorted_vectors = Vec::with_capacity(count * d);
+ for idx in order {
+ sorted_vectors.extend_from_slice(&index.vectors[list_id][idx *
d..(idx + 1) * d]);
+ }
+ let (_, id_bytes) = encode_delta_varint_ids(&sorted_ids);
+ sorted_lists.push((sorted_ids, id_bytes, sorted_vectors));
+ }
+
+ write_u32_le(out, IVFFLAT_MAGIC)?;
+ write_u32_le(out, IVFFLAT_VERSION)?;
+ write_i32_le(out, d_i32)?;
+ write_i32_le(out, nlist_i32)?;
+ write_u32_le(out, index.metric as u32)?;
+ write_i64_le(out, total_vectors)?;
+ write_u32_le(out, FLAG_DELTA_IDS)?;
+ out.write_all(&[0u8; 32])?;
+
+ write_f32_slice(out, &index.quantizer_centroids)?;
+
+ let offset_table_size = nlist.checked_mul(16).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "IVF-FLAT offset table size overflow",
+ )
+ })?;
+ let data_start = out
+ .pos()
+ .checked_add(offset_table_size as u64)
+ .ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "IVF-FLAT data start offset overflow",
+ )
+ })?;
+ let mut list_offsets = vec![0i64; nlist];
+ let mut list_counts = vec![0i32; nlist];
+ let mut list_id_bytes_lens = vec![0i32; nlist];
+ let mut current_offset = data_start;
+
+ for list_id in 0..nlist {
+ list_offsets[list_id] = u64_to_i64(current_offset, "list offset")?;
+ let count = sorted_lists[list_id].0.len();
+ list_counts[list_id] = usize_to_i32(count, "list count")?;
+ if count > 0 {
+ let id_bytes_len = sorted_lists[list_id].1.len();
+ list_id_bytes_lens[list_id] = usize_to_i32(id_bytes_len, "delta ID
section")?;
+ let vector_bytes = checked_list_bytes(
+ count,
+ d.checked_mul(4).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "IVF-FLAT bytes per vector overflow",
+ )
+ })?,
+ )?;
+ let list_bytes = 12usize
+ .checked_add(id_bytes_len)
+ .and_then(|len| len.checked_add(vector_bytes))
+ .ok_or_else(|| {
+ io::Error::new(io::ErrorKind::InvalidInput, "IVF-FLAT list
size overflow")
+ })?;
+ current_offset = current_offset
+ .checked_add(list_bytes as u64)
+ .ok_or_else(|| {
+ io::Error::new(io::ErrorKind::InvalidInput, "IVF-FLAT
offset overflow")
+ })?;
+ }
+ }
+
+ for list_id in 0..nlist {
+ write_i64_le(out, list_offsets[list_id])?;
+ write_i32_le(out, list_counts[list_id])?;
+ write_i32_le(out, list_id_bytes_lens[list_id])?;
+ }
+
+ for (sorted_ids, id_bytes, sorted_vectors) in sorted_lists {
+ if sorted_ids.is_empty() {
+ continue;
+ }
+ write_i64_le(out, sorted_ids[0])?;
+ write_i32_le(out, id_bytes.len() as i32)?;
+ out.write_all(&id_bytes)?;
+ write_f32_slice(out, &sorted_vectors)?;
+ }
+
+ Ok(())
+}
+
+pub struct IVFFlatIndexReader<R: SeekRead> {
+ reader: R,
+ pub d: usize,
+ pub nlist: usize,
+ pub metric: MetricType,
+ pub total_vectors: i64,
+ pub quantizer_centroids: Vec<f32>,
+ pub list_offsets: Vec<i64>,
+ pub list_counts: Vec<i32>,
+ pub list_id_bytes_lens: Vec<i32>,
+ delta_ids: bool,
+ loaded: bool,
+}
+
+impl<R: SeekRead> IVFFlatIndexReader<R> {
+ pub fn open(mut reader: R) -> io::Result<Self> {
+ reader.seek(0)?;
+
+ let magic = read_u32_le(&mut reader)?;
+ if magic != IVFFLAT_MAGIC {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("Invalid IVFFLAT magic: 0x{:08X}", magic),
+ ));
+ }
+ let version = read_u32_le(&mut reader)?;
+ if version != IVFFLAT_VERSION {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("Unsupported IVFFLAT version: {}", version),
+ ));
+ }
+
+ let d = validate_positive_i32(read_i32_le(&mut reader)?, "d")? as
usize;
+ let nlist = validate_positive_i32(read_i32_le(&mut reader)?, "nlist")?
as usize;
+ let metric_code = read_u32_le(&mut reader)?;
+ let metric = MetricType::from_code(metric_code).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("Unknown metric type: {}", metric_code),
+ )
+ })?;
+ let total_vectors = read_i64_le(&mut reader)?;
+ let flags = read_u32_le(&mut reader)?;
+ let mut reserved = [0u8; 32];
+ reader.read_exact(&mut reserved)?;
+
+ Ok(Self {
+ reader,
+ d,
+ nlist,
+ metric,
+ total_vectors,
+ quantizer_centroids: Vec::new(),
+ list_offsets: Vec::new(),
+ list_counts: Vec::new(),
+ list_id_bytes_lens: Vec::new(),
+ delta_ids: flags & FLAG_DELTA_IDS != 0,
+ loaded: false,
+ })
+ }
+
+ pub fn ensure_loaded(&mut self) -> io::Result<()> {
+ if self.loaded {
+ return Ok(());
+ }
+
+ self.reader.seek(IVFFLAT_HEADER_SIZE as u64)?;
+ self.quantizer_centroids =
+ read_f32_vec(&mut self.reader, checked_section_size(self.nlist,
self.d)?)?;
+ self.list_offsets = vec![0; self.nlist];
+ self.list_counts = vec![0; self.nlist];
+ self.list_id_bytes_lens = vec![0; self.nlist];
+ for list_id in 0..self.nlist {
+ self.list_offsets[list_id] = read_i64_le(&mut self.reader)?;
+ let count = read_i32_le(&mut self.reader)?;
+ if count < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("negative list count {} at list {}", count,
list_id),
+ ));
+ }
+ self.list_counts[list_id] = count;
+ let id_bytes_len = read_i32_le(&mut self.reader)?;
+ if id_bytes_len < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("negative id_bytes_len {} at list {}",
id_bytes_len, list_id),
+ ));
+ }
+ self.list_id_bytes_lens[list_id] = id_bytes_len;
+ }
+
+ self.loaded = true;
+ Ok(())
+ }
+
+ pub fn read_inverted_list(&mut self, list_id: usize) ->
io::Result<(Vec<i64>, Vec<f32>)> {
+ self.ensure_loaded()?;
+ if list_id >= self.nlist {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("list_id {} out of range (nlist={})", list_id,
self.nlist),
+ ));
+ }
+ let count = self.list_counts[list_id] as usize;
+ if count == 0 {
+ return Ok((Vec::new(), Vec::new()));
+ }
+
+ let offset = checked_list_offset(self.list_offsets[list_id], list_id)?;
+ let vector_bytes = checked_list_bytes(count, self.d * 4)?;
+ if self.delta_ids {
+ let id_bytes_len = self.list_id_bytes_lens[list_id] as usize;
+ let payload_len = 12usize
+ .checked_add(id_bytes_len)
+ .and_then(|len| len.checked_add(vector_bytes))
+ .ok_or_else(|| {
+ io::Error::new(io::ErrorKind::InvalidData, "IVF-FLAT list
payload overflow")
+ })?;
+ let mut payload = vec![0u8; payload_len];
+ self.reader.pread(offset, &mut payload)?;
+ let base_id =
i64::from_le_bytes(payload[0..8].try_into().unwrap());
+ let encoded_len =
i32::from_le_bytes(payload[8..12].try_into().unwrap());
+ if encoded_len < 0 || encoded_len as usize != id_bytes_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "IVF-FLAT id_bytes_len mismatch",
+ ));
+ }
+ let ids = decode_delta_varint_ids(base_id, &payload[12..12 +
id_bytes_len], count)?;
+ let vectors = bytes_to_f32_vec(&payload[12 + id_bytes_len..])?;
+ Ok((ids, vectors))
+ } else {
+ Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "IVF-FLAT reader only supports delta IDs",
+ ))
+ }
+ }
+
+ pub fn search(
+ &mut self,
+ query: &[f32],
+ k: usize,
+ nprobe: usize,
+ ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ self.search_with_filter(query, k, nprobe, None)
+ }
+
+ pub fn search_with_filter(
+ &mut self,
+ query: &[f32],
+ k: usize,
+ nprobe: usize,
+ filter: Option<&dyn RowIdFilter>,
+ ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ self.ensure_loaded()?;
+ if query.len() != self.d {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!(
+ "query length {} does not match index dimension {}",
+ query.len(),
+ self.d
+ ),
+ ));
+ }
+ if k == 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "k must be greater than 0",
+ ));
+ }
+ if nprobe == 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "nprobe must be greater than 0",
+ ));
+ }
+
+ let mut q = query.to_vec();
+ if self.metric == MetricType::Cosine {
+ fvec_normalize(&mut q);
+ }
+
+ let (probe_indices, _) =
+ kmeans::find_topk(&q, &self.quantizer_centroids, self.nlist,
self.d, nprobe);
+ let mut heap = ReaderTopKHeap::new(k);
+
+ for list_id in probe_indices {
+ let (ids, vectors) = self.read_inverted_list(list_id)?;
+ for (local_idx, &id) in ids.iter().enumerate() {
+ if let Some(f) = filter {
+ if !f.contains(id) {
+ continue;
+ }
+ }
+ let vector = &vectors[local_idx * self.d..(local_idx + 1) *
self.d];
+ heap.push(fvec_distance(&q, vector, self.metric), id);
+ }
+ }
+
+ let sorted = heap.into_sorted();
+ let mut labels: Vec<i64> = sorted.iter().map(|&(_, id)| id).collect();
+ let mut distances: Vec<f32> = sorted.iter().map(|&(dist, _)|
dist).collect();
+ labels.resize(k, -1);
+ distances.resize(k, f32::MAX);
+ Ok((labels, distances))
+ }
+
+ pub fn search_with_roaring_filter(
+ &mut self,
+ query: &[f32],
+ k: usize,
+ nprobe: usize,
+ roaring_filter_bytes: &[u8],
+ ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ let filter = decode_roaring_filter(roaring_filter_bytes)?;
+ self.search_with_filter(query, k, nprobe, Some(&filter))
+ }
+}
+
+/// Batch search for IVF-FLAT readers. Each unique probed list is read once and
+/// scanned for all queries that selected it.
+pub fn search_batch_ivfflat_reader<R: SeekRead>(
+ reader: &mut IVFFlatIndexReader<R>,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ search_batch_ivfflat_reader_filter(reader, queries, nq, k, nprobe, None)
+}
+
+pub fn search_batch_ivfflat_reader_filter<R: SeekRead>(
+ reader: &mut IVFFlatIndexReader<R>,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+ filter: Option<&dyn RowIdFilter>,
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ reader.ensure_loaded()?;
+ let d = reader.d;
+ if nq == 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "nq must be greater than 0",
+ ));
+ }
+ let expected_query_len = nq.checked_mul(d).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "nq * dimension overflows usize",
+ )
+ })?;
+ if queries.len() != expected_query_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!(
+ "queries length {} does not match nq * dimension {}",
+ queries.len(),
+ expected_query_len
+ ),
+ ));
+ }
+ if k == 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "k must be greater than 0",
+ ));
+ }
+ if nprobe == 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "nprobe must be greater than 0",
+ ));
+ }
+
+ let mut processed = queries[..expected_query_len].to_vec();
+ if reader.metric == MetricType::Cosine {
+ for qi in 0..nq {
+ fvec_normalize(&mut processed[qi * d..(qi + 1) * d]);
+ }
+ }
+
+ let (all_probe_indices, _) = kmeans::find_topk_batch(
+ &processed,
+ nq,
+ &reader.quantizer_centroids,
+ reader.nlist,
+ d,
+ nprobe,
+ );
+
+ let mut list_to_queries = vec![Vec::new(); reader.nlist];
+ let mut unique_lists = Vec::new();
+ for (qi, probe_indices) in all_probe_indices.iter().enumerate() {
+ for &list_id in probe_indices {
+ if list_to_queries[list_id].is_empty() {
+ unique_lists.push(list_id);
+ }
+ list_to_queries[list_id].push(qi);
+ }
+ }
+
+ let mut heaps: Vec<ReaderTopKHeap> = (0..nq).map(|_|
ReaderTopKHeap::new(k)).collect();
+ for list_id in unique_lists {
+ let count = reader.list_counts[list_id] as usize;
+ if count == 0 {
+ continue;
+ }
+ let (ids, vectors) = reader.read_inverted_list(list_id)?;
+ for &qi in &list_to_queries[list_id] {
+ let query = &processed[qi * d..(qi + 1) * d];
+ for (local_idx, &id) in ids.iter().enumerate() {
+ if let Some(f) = filter {
+ if !f.contains(id) {
+ continue;
+ }
+ }
+ let vector = &vectors[local_idx * d..(local_idx + 1) * d];
+ heaps[qi].push(fvec_distance(query, vector, reader.metric),
id);
+ }
+ }
+ }
+
+ let mut result_ids = vec![-1i64; nq * k];
+ let mut result_dists = vec![f32::MAX; nq * k];
+ for qi in 0..nq {
+ let sorted = std::mem::replace(&mut heaps[qi],
ReaderTopKHeap::new(0)).into_sorted();
+ let base = qi * k;
+ for (i, &(dist, id)) in sorted.iter().enumerate() {
+ result_ids[base + i] = id;
+ result_dists[base + i] = dist;
+ }
+ }
+
+ Ok((result_ids, result_dists))
+}
+
+pub fn search_batch_ivfflat_reader_roaring_filter<R: SeekRead>(
+ reader: &mut IVFFlatIndexReader<R>,
+ queries: &[f32],
+ nq: usize,
+ k: usize,
+ nprobe: usize,
+ roaring_filter_bytes: &[u8],
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ let filter = decode_roaring_filter(roaring_filter_bytes)?;
+ search_batch_ivfflat_reader_filter(reader, queries, nq, k, nprobe,
Some(&filter))
+}
+
/// Bounded collector that keeps the `k` hits with the smallest distance.
///
/// Backed by a max-heap whose root is the worst hit currently kept, so each
/// `push` costs O(log k) instead of a linear O(k) scan for the maximum.
/// Distances are compared with `f32::total_cmp`, so a NaN distance no longer
/// panics; a (positive) NaN ranks after every finite distance and is evicted
/// first.
struct ReaderTopKHeap {
    k: usize,
    heap: std::collections::BinaryHeap<ReaderHeapEntry>,
}

/// A `(distance, id)` hit totally ordered by distance (`total_cmp`), then by
/// id, as `BinaryHeap` requires.
#[derive(Clone, Copy, Debug)]
struct ReaderHeapEntry(f32, i64);

impl PartialEq for ReaderHeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for ReaderHeapEntry {}

impl PartialOrd for ReaderHeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ReaderHeapEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0
            .total_cmp(&other.0)
            .then_with(|| self.1.cmp(&other.1))
    }
}

impl ReaderTopKHeap {
    fn new(k: usize) -> Self {
        Self {
            k,
            heap: std::collections::BinaryHeap::with_capacity(k),
        }
    }

    /// Offers a hit. While fewer than `k` hits are held it is always kept.
    /// After that it replaces the current worst hit only when its distance is
    /// strictly smaller, so on ties the earlier hit stays.
    fn push(&mut self, dist: f32, id: i64) {
        if self.k == 0 {
            return;
        }
        if self.heap.len() < self.k {
            self.heap.push(ReaderHeapEntry(dist, id));
            return;
        }
        if let Some(mut worst) = self.heap.peek_mut() {
            if dist.total_cmp(&worst.0).is_lt() {
                *worst = ReaderHeapEntry(dist, id);
            }
        }
    }

    /// Consumes the collector and returns hits by ascending distance; equal
    /// distances are ordered by ascending id.
    fn into_sorted(self) -> Vec<(f32, i64)> {
        self.heap
            .into_sorted_vec()
            .into_iter()
            .map(|ReaderHeapEntry(dist, id)| (dist, id))
            .collect()
    }
}
+
+fn write_u32_le(out: &mut dyn SeekWrite, v: u32) -> io::Result<()> {
+ out.write_all(&v.to_le_bytes())
+}
+
+fn write_i32_le(out: &mut dyn SeekWrite, v: i32) -> io::Result<()> {
+ out.write_all(&v.to_le_bytes())
+}
+
+fn write_i64_le(out: &mut dyn SeekWrite, v: i64) -> io::Result<()> {
+ out.write_all(&v.to_le_bytes())
+}
+
+fn write_f32_slice(out: &mut dyn SeekWrite, data: &[f32]) -> io::Result<()> {
+ let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
+ out.write_all(&bytes)
+}
+
+fn read_u32_le(reader: &mut dyn SeekRead) -> io::Result<u32> {
+ let mut buf = [0u8; 4];
+ reader.read_exact(&mut buf)?;
+ Ok(u32::from_le_bytes(buf))
+}
+
+fn read_i32_le(reader: &mut dyn SeekRead) -> io::Result<i32> {
+ let mut buf = [0u8; 4];
+ reader.read_exact(&mut buf)?;
+ Ok(i32::from_le_bytes(buf))
+}
+
+fn read_i64_le(reader: &mut dyn SeekRead) -> io::Result<i64> {
+ let mut buf = [0u8; 8];
+ reader.read_exact(&mut buf)?;
+ Ok(i64::from_le_bytes(buf))
+}
+
/// Accepts a header integer only if it is strictly positive.
///
/// # Errors
/// `InvalidData` naming `field` when `val <= 0`.
fn validate_positive_i32(val: i32, field: &str) -> io::Result<i32> {
    if val > 0 {
        Ok(val)
    } else {
        Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("invalid header field {}: {} (must be positive)", field, val),
        ))
    }
}
+
+/// Checks that an in-memory IVF-FLAT index is internally consistent before serialization.
+///
+/// The checks run in this order:
+/// 1. `d > 0`.
+/// 2. `nlist > 0`.
+/// 3. Exactly `nlist` id lists and `nlist` vector lists.
+/// 4. Exactly `nlist * d` centroid floats, where `nlist * d` must pass
+///    `checked_section_size`.
+/// 5. Every list stores exactly `ids.len() * d` vector components.
+///
+/// # Errors
+/// Shape mismatches and overflow return `io::ErrorKind::InvalidInput`. Errors from
+/// `checked_section_size` are propagated unchanged.
+fn validate_index_shape(index: &IVFFlatIndex) -> io::Result<()> {
+    let invalid_input = |msg: String| io::Error::new(io::ErrorKind::InvalidInput, msg);
+
+    if index.d == 0 {
+        return Err(invalid_input("dimension must be greater than 0".to_string()));
+    }
+    if index.nlist == 0 {
+        return Err(invalid_input("nlist must be greater than 0".to_string()));
+    }
+    let lists_match = index.ids.len() == index.nlist && index.vectors.len() == index.nlist;
+    if !lists_match {
+        return Err(invalid_input("IVF-FLAT list storage does not match nlist".to_string()));
+    }
+
+    let centroid_len = checked_section_size(index.nlist, index.d)?;
+    let actual_centroids = index.quantizer_centroids.len();
+    if actual_centroids != centroid_len {
+        return Err(invalid_input(format!(
+            "centroid length {} does not match nlist*d {}",
+            actual_centroids, centroid_len
+        )));
+    }
+
+    // Both collections were just checked to hold exactly `nlist` entries.
+    let lists = index.ids.iter().zip(index.vectors.iter()).enumerate();
+    for (list_id, (list_ids, list_vectors)) in lists {
+        let expected_vector_len = list_ids
+            .len()
+            .checked_mul(index.d)
+            .ok_or_else(|| invalid_input("IVF-FLAT vector length overflow".to_string()))?;
+        if list_vectors.len() != expected_vector_len {
+            return Err(invalid_input(format!(
+                "list {} vector length {} does not match ids*d {}",
+                list_id,
+                list_vectors.len(),
+                expected_vector_len
+            )));
+        }
+    }
+    Ok(())
+}
+
+/// Narrows `value` to `i32`.
+///
+/// # Errors
+/// `io::ErrorKind::InvalidInput` naming `field` when `value > i32::MAX`.
+fn usize_to_i32(value: usize, field: &str) -> io::Result<i32> {
+    i32::try_from(value).map_err(|_| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("{} exceeds i32 length limit: {}", field, value),
+        )
+    })
+}
+
+/// Narrows `value` to `i64`.
+///
+/// # Errors
+/// `io::ErrorKind::InvalidInput` naming `field` when `value > i64::MAX`.
+fn usize_to_i64(value: usize, field: &str) -> io::Result<i64> {
+    i64::try_from(value).map_err(|_| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("{} exceeds i64 length limit: {}", field, value),
+        )
+    })
+}
+
+/// Narrows an unsigned offset to `i64`.
+///
+/// # Errors
+/// `io::ErrorKind::InvalidInput` naming `field` when `value > i64::MAX`.
+fn u64_to_i64(value: u64, field: &str) -> io::Result<i64> {
+    i64::try_from(value).map_err(|_| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("{} exceeds i64 offset limit: {}", field, value),
+        )
+    })
+}
+
+/// Largest element count (2^30) any single section may declare.
+const MAX_SECTION_ELEMENTS: usize = 1 << 30;
+
+/// Computes `a * b` for a section size.
+///
+/// # Errors
+/// Returns `io::ErrorKind::InvalidData` in two cases:
+/// - the product overflows `usize`;
+/// - the product exceeds `MAX_SECTION_ELEMENTS`.
+fn checked_section_size(a: usize, b: usize) -> io::Result<usize> {
+    match a.checked_mul(b) {
+        None => Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "section size overflow in IVF-FLAT header",
+        )),
+        Some(size) if size > MAX_SECTION_ELEMENTS => Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!(
+                "section size {} exceeds maximum {}",
+                size, MAX_SECTION_ELEMENTS
+            ),
+        )),
+        Some(size) => Ok(size),
+    }
+}
+
+/// Converts a stored list offset to `u64`.
+///
+/// # Errors
+/// Returns `io::ErrorKind::InvalidData` when the offset is negative; the message names
+/// the offending `list_id`.
+fn checked_list_offset(offset: i64, list_id: usize) -> io::Result<u64> {
+    u64::try_from(offset).map_err(|_| {
+        io::Error::new(
+            io::ErrorKind::InvalidData,
+            format!("negative list offset {} at list {}", offset, list_id),
+        )
+    })
+}
+
+/// Computes the byte size of `count` entries of `bytes_per_entry` bytes each.
+///
+/// # Errors
+/// `io::ErrorKind::InvalidData` if the product overflows `usize`.
+fn checked_list_bytes(count: usize, bytes_per_entry: usize) -> io::Result<usize> {
+    match count.checked_mul(bytes_per_entry) {
+        Some(total) => Ok(total),
+        None => Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "IVF-FLAT list byte size overflow",
+        )),
+    }
+}
+
+/// Reads `count` little-endian `f32`s from `reader`.
+///
+/// NOTE(review): `count` usually comes from the file header. The whole buffer is
+/// allocated before any bytes are read, so callers should bound `count` first
+/// (e.g. with `checked_section_size`).
+///
+/// # Errors
+/// - `io::ErrorKind::InvalidData` if `count * 4` overflows `usize`. This can happen on
+///   32-bit targets even for counts at the 2^30 section limit.
+/// - Any error from `read_exact`, such as `UnexpectedEof` on truncated input.
+fn read_f32_vec(reader: &mut dyn SeekRead, count: usize) -> io::Result<Vec<f32>> {
+    let byte_len = count
+        .checked_mul(std::mem::size_of::<f32>())
+        .ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                "f32 section byte size overflow",
+            )
+        })?;
+    let mut buf = vec![0u8; byte_len];
+    reader.read_exact(&mut buf)?;
+    bytes_to_f32_vec(&buf)
+}
+
+/// Decodes a byte slice of packed little-endian `f32`s.
+///
+/// # Errors
+/// `io::ErrorKind::InvalidData` if `bytes.len()` is not a multiple of four.
+fn bytes_to_f32_vec(bytes: &[u8]) -> io::Result<Vec<f32>> {
+    if bytes.len().is_multiple_of(4) {
+        let mut floats = Vec::with_capacity(bytes.len() / 4);
+        for chunk in bytes.chunks_exact(4) {
+            floats.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
+        }
+        Ok(floats)
+    } else {
+        Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "f32 byte section is not 4-byte aligned",
+        ))
+    }
+}
+
+/// Appends `val` to `buf` as an unsigned LEB128 varint: 7 payload bits per byte,
+/// least-significant group first, high bit set on every byte except the last.
+fn encode_varint(mut val: u64, buf: &mut Vec<u8>) {
+    loop {
+        let low7 = (val & 0x7F) as u8;
+        val >>= 7;
+        if val == 0 {
+            buf.push(low7);
+            return;
+        }
+        buf.push(low7 | 0x80);
+    }
+}
+
+/// Decodes one unsigned LEB128 varint from `buf` starting at `*pos`.
+///
+/// On success, `*pos` is left just past the last consumed byte.
+///
+/// # Errors
+/// All errors are `io::ErrorKind::InvalidData`:
+/// - "truncated varint": `buf` ends before a terminating byte.
+/// - "varint exceeds u64 range": the 10th byte carries more than one payload bit.
+/// - "varint exceeds 64 bits": the 10th byte still has its continuation bit set.
+fn decode_varint(buf: &[u8], pos: &mut usize) -> io::Result<u64> {
+    let mut result = 0u64;
+    let mut shift = 0u32;
+    loop {
+        let Some(&byte) = buf.get(*pos) else {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "truncated varint",
+            ));
+        };
+        *pos += 1;
+
+        let low7 = u64::from(byte & 0x7F);
+        // At shift 63 only bit 0 still fits in a u64.
+        if shift == 63 && low7 > 1 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "varint exceeds u64 range",
+            ));
+        }
+        result |= low7 << shift;
+
+        if byte & 0x80 == 0 {
+            return Ok(result);
+        }
+        shift += 7;
+        if shift > 63 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "varint exceeds 64 bits",
+            ));
+        }
+    }
+}
+
+/// Delta-encodes `ids` as varints relative to `ids[0]`.
+///
+/// Returns `(base, bytes)`, where `base` is the first id. The first encoded delta is
+/// always 0 (`ids[0] - base`). Deltas use wrapping `u64` subtraction, so unsorted input
+/// still encodes without panicking.
+/// NOTE(review): `decode_delta_varint_ids` rejects decreasing sequences; confirm the
+/// writer sorts each list's ids.
+///
+/// An empty slice yields `(0, [])`.
+fn encode_delta_varint_ids(ids: &[i64]) -> (i64, Vec<u8>) {
+    let Some(&base) = ids.first() else {
+        return (0, Vec::new());
+    };
+    let mut encoded = Vec::with_capacity(ids.len() * 2);
+    let mut previous = base as u64;
+    for current in ids.iter().map(|&id| id as u64) {
+        encode_varint(current.wrapping_sub(previous), &mut encoded);
+        previous = current;
+    }
+    (base, encoded)
+}
+
+/// Decodes `count` delta-varint ids from `buf`, accumulating from `base`.
+///
+/// This is the inverse of `encode_delta_varint_ids`. Deltas are added with wrapping
+/// `u64` arithmetic and reinterpreted as `i64`.
+///
+/// # Errors
+/// Returns `io::ErrorKind::InvalidData` in these cases:
+/// - any varint is malformed or truncated (see `decode_varint`);
+/// - the decoded ids decrease at any point.
+fn decode_delta_varint_ids(base: i64, buf: &[u8], count: usize) -> io::Result<Vec<i64>> {
+    // `count` comes from serialized data. Every varint occupies at least one byte, so a
+    // well-formed list never holds more than `buf.len()` ids. Capping the capacity stops
+    // a corrupt count from forcing a huge up-front allocation. An oversized count still
+    // fails below with "truncated varint".
+    let mut ids = Vec::with_capacity(count.min(buf.len()));
+    let mut pos = 0;
+    let mut current = base as u64;
+    let mut prev_signed = base;
+    for _ in 0..count {
+        let delta = decode_varint(buf, &mut pos)?;
+        current = current.wrapping_add(delta);
+        let id = current as i64;
+        if id < prev_signed {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "decoded ID sequence is not monotonically non-decreasing",
+            ));
+        }
+        prev_signed = id;
+        ids.push(id);
+    }
+    Ok(ids)
+}
+
+/// Deserializes a `RoaringTreemap` id filter from its serialized byte form.
+///
+/// # Errors
+/// `io::ErrorKind::InvalidInput` if the bytes are not a valid serialized treemap.
+fn decode_roaring_filter(bytes: &[u8]) -> io::Result<RoaringTreemap> {
+    match RoaringTreemap::deserialize_from(bytes) {
+        Ok(bitmap) => Ok(bitmap),
+        Err(err) => Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("invalid RoaringTreemap filter: {}", err),
+        )),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::distance::MetricType;
+    use crate::io::PosWriter;
+    use crate::ivfflat::IVFFlatIndex;
+    use std::io::Cursor;
+
+    /// Dimension of the clustered fixture produced by `clustered_data`.
+    const CLUSTERED_D: usize = 4;
+    /// Number of inverted lists (and synthetic clusters) in the clustered fixture.
+    const CLUSTERED_NLIST: usize = 4;
+    /// Number of vectors in the clustered fixture.
+    const CLUSTERED_N: usize = 128;
+
+    /// Generates `CLUSTERED_N` vectors of dimension `CLUSTERED_D` with ids
+    /// `1000..1000 + CLUSTERED_N`. The first component puts vector `i` in cluster
+    /// `i % CLUSTERED_NLIST`; cluster centers are 100 apart.
+    fn clustered_data() -> (Vec<f32>, Vec<i64>) {
+        let data = (0..CLUSTERED_N)
+            .flat_map(|i| {
+                let cluster = (i % CLUSTERED_NLIST) as f32 * 100.0;
+                [cluster + i as f32 * 0.01, 1.0, 2.0, 3.0]
+            })
+            .collect();
+        let ids = (1000..1000 + CLUSTERED_N as i64).collect();
+        (data, ids)
+    }
+
+    /// Serializes `index` into a fresh in-memory buffer via `write_ivfflat_index`.
+    fn serialize(index: &IVFFlatIndex) -> Vec<u8> {
+        let mut buf: Vec<u8> = Vec::new();
+        {
+            // Scope the writer so its mutable borrow of `buf` ends before `buf` is returned.
+            let mut writer = PosWriter::new(&mut buf);
+            write_ivfflat_index(index, &mut writer).unwrap();
+        }
+        buf
+    }
+
+    /// Serialized single-list L2 index over three 2-d points. Ids 10 and 11 sit near the
+    /// origin; id 12 sits at (10, 10).
+    fn three_point_index_bytes() -> Vec<u8> {
+        let data: Vec<f32> = vec![0.0, 0.0, 0.1, 0.0, 10.0, 10.0];
+        let ids: Vec<i64> = vec![10, 11, 12];
+        let mut index = IVFFlatIndex::new(2, 1, MetricType::L2);
+        index.train(&data, 3);
+        index.add(&data, &ids, 3);
+        serialize(&index)
+    }
+
+    /// Serializes a `RoaringTreemap` containing exactly the ids in `allowed`.
+    fn roaring_filter_bytes(allowed: &[u64]) -> Vec<u8> {
+        let mut bitmap = RoaringTreemap::new();
+        for &id in allowed {
+            bitmap.insert(id);
+        }
+        let mut bytes: Vec<u8> = Vec::new();
+        bitmap.serialize_into(&mut bytes).unwrap();
+        bytes
+    }
+
+    #[test]
+    fn test_ivfflat_write_read_search_roundtrip() {
+        let (d, nlist, n) = (CLUSTERED_D, CLUSTERED_NLIST, CLUSTERED_N);
+        let (data, ids) = clustered_data();
+        let mut index = IVFFlatIndex::new(d, nlist, MetricType::L2);
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+
+        let query = &data[7 * d..8 * d];
+        let mut expected_distances = vec![0.0; 5];
+        let mut expected_labels = vec![0; 5];
+        index.search(
+            query,
+            1,
+            5,
+            nlist,
+            &mut expected_distances,
+            &mut expected_labels,
+        );
+
+        let mut reader = IVFFlatIndexReader::open(Cursor::new(serialize(&index))).unwrap();
+        let (labels, distances) = reader.search(query, 5, nlist).unwrap();
+
+        assert_eq!(labels, expected_labels);
+        assert_eq!(distances, expected_distances);
+    }
+
+    #[test]
+    fn test_ivfflat_reader_search_with_filter() {
+        use std::collections::HashSet;
+
+        let filter: HashSet<i64> = [12].into_iter().collect();
+        let mut reader =
+            IVFFlatIndexReader::open(Cursor::new(three_point_index_bytes())).unwrap();
+        let (labels, distances) = reader
+            .search_with_filter(&[0.0, 0.0], 2, 1, Some(&filter))
+            .unwrap();
+
+        assert_eq!(labels, vec![12, -1]);
+        assert_eq!(distances[0], 200.0);
+        assert_eq!(distances[1], f32::MAX);
+    }
+
+    #[test]
+    fn test_ivfflat_reader_search_with_roaring_filter_bytes() {
+        let filter_bytes = roaring_filter_bytes(&[12]);
+        let mut reader =
+            IVFFlatIndexReader::open(Cursor::new(three_point_index_bytes())).unwrap();
+        let (labels, distances) = reader
+            .search_with_roaring_filter(&[0.0, 0.0], 2, 1, &filter_bytes)
+            .unwrap();
+
+        assert_eq!(labels, vec![12, -1]);
+        assert_eq!(distances[0], 200.0);
+        assert_eq!(distances[1], f32::MAX);
+    }
+
+    #[test]
+    fn test_ivfflat_batch_reader_matches_single_reader_search() {
+        let (d, nlist, n) = (CLUSTERED_D, CLUSTERED_NLIST, CLUSTERED_N);
+        let (data, ids) = clustered_data();
+        let mut index = IVFFlatIndex::new(d, nlist, MetricType::L2);
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+        let buf = serialize(&index);
+
+        let queries = [&data[7 * d..8 * d], &data[63 * d..64 * d]].concat();
+        let k = 5;
+        let nprobe = 3;
+        let mut batch_reader = IVFFlatIndexReader::open(Cursor::new(buf.clone())).unwrap();
+        let (batch_labels, batch_distances) =
+            search_batch_ivfflat_reader(&mut batch_reader, &queries, 2, k, nprobe).unwrap();
+
+        for qi in 0..2 {
+            let mut single_reader = IVFFlatIndexReader::open(Cursor::new(buf.clone())).unwrap();
+            let query = &queries[qi * d..(qi + 1) * d];
+            let (single_labels, single_distances) =
+                single_reader.search(query, k, nprobe).unwrap();
+            assert_eq!(&batch_labels[qi * k..(qi + 1) * k], single_labels);
+            assert_eq!(&batch_distances[qi * k..(qi + 1) * k], single_distances);
+        }
+    }
+
+    #[test]
+    fn test_ivfflat_batch_reader_search_with_roaring_filter_bytes() {
+        let filter_bytes = roaring_filter_bytes(&[12]);
+        let mut reader =
+            IVFFlatIndexReader::open(Cursor::new(three_point_index_bytes())).unwrap();
+        let queries = vec![0.0, 0.0, 10.0, 10.0];
+        let (labels, distances) = search_batch_ivfflat_reader_roaring_filter(
+            &mut reader,
+            &queries,
+            2,
+            2,
+            1,
+            &filter_bytes,
+        )
+        .unwrap();
+
+        assert_eq!(labels, vec![12, -1, 12, -1]);
+        assert_eq!(distances, vec![200.0, f32::MAX, 0.0, f32::MAX]);
+    }
+
+    #[test]
+    fn test_ivfflat_reader_validates_inputs() {
+        let data: Vec<f32> = vec![0.0, 0.0, 1.0, 1.0];
+        let ids: Vec<i64> = vec![1, 2];
+        let mut index = IVFFlatIndex::new(2, 1, MetricType::L2);
+        index.train(&data, 2);
+        index.add(&data, &ids, 2);
+        let buf = serialize(&index);
+
+        // Query dimension mismatch.
+        let mut reader = IVFFlatIndexReader::open(Cursor::new(buf.clone())).unwrap();
+        assert!(reader.search(&[0.0], 1, 1).is_err());
+
+        // k == 0.
+        let mut reader = IVFFlatIndexReader::open(Cursor::new(buf.clone())).unwrap();
+        assert!(reader.search(&[0.0, 0.0], 0, 1).is_err());
+
+        // nprobe == 0.
+        let mut reader = IVFFlatIndexReader::open(Cursor::new(buf)).unwrap();
+        assert!(reader.search(&[0.0, 0.0], 1, 0).is_err());
+    }
+
+    #[test]
+    fn test_ivfflat_writer_validates_shape_before_writing() {
+        let mut index = IVFFlatIndex::new(2, 1, MetricType::L2);
+        index.quantizer_centroids = vec![0.0, 0.0];
+        index.ids[0] = vec![1, 2];
+        // Two ids but only one 2-d vector: must be rejected before anything is written.
+        index.vectors[0] = vec![0.0, 0.0];
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        let err = write_ivfflat_index(&index, &mut writer).unwrap_err();
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+        assert!(err.to_string().contains("vector length"));
+    }
+
+    #[test]
+    fn test_ivfflat_reader_rejects_bad_magic() {
+        let mut buf = vec![0u8; IVFFLAT_HEADER_SIZE];
+        buf[0..4].copy_from_slice(&0x12345678u32.to_le_bytes());
+
+        let err = match IVFFlatIndexReader::open(Cursor::new(buf)) {
+            Ok(_) => panic!("bad magic should be rejected"),
+            Err(err) => err,
+        };
+        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
+    }
+}
diff --git a/core/src/lib.rs b/core/src/lib.rs
index 8be3e43..e34cd95 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -22,6 +22,8 @@ pub mod blas;
pub mod distance;
pub mod fastscan;
pub mod io;
+pub mod ivfflat;
+pub mod ivfflat_io;
pub mod ivfpq;
pub mod kmeans;
pub mod opq;
diff --git a/jni/java-test/org/apache/paimon/index/ivfpq/IVFPQJavaApiTest.java
b/jni/java-test/org/apache/paimon/index/ivfpq/IVFPQJavaApiTest.java
index 3b8998e..c5557ef 100644
--- a/jni/java-test/org/apache/paimon/index/ivfpq/IVFPQJavaApiTest.java
+++ b/jni/java-test/org/apache/paimon/index/ivfpq/IVFPQJavaApiTest.java
@@ -26,6 +26,7 @@ public class IVFPQJavaApiTest {
testSingleResultCopiesArrays();
testBatchResultCopiesArraysAndSlicesRows();
testReaderAndWriterApiCompile();
+ testFlatReaderAndWriterApiCompile();
}
private static void testMetricCodes() {
@@ -104,6 +105,31 @@ public class IVFPQJavaApiTest {
}
}
+    private static void testFlatReaderAndWriterApiCompile() {
+        // A wrapper around a null native pointer is already closed, so close() must be an
+        // idempotent no-op. Call it twice on each wrapper.
+        IVFFlatReader closedReader = IVFFlatReader.fromNativePointerForTesting(0L);
+        for (int attempt = 0; attempt < 2; attempt++) {
+            closedReader.close();
+        }
+
+        IVFFlatWriter closedWriter = IVFFlatWriter.fromNativePointerForTesting(0L, 2);
+        for (int attempt = 0; attempt < 2; attempt++) {
+            closedWriter.close();
+        }
+
+        // This branch never runs because the clock is never negative. It exists only so
+        // the compiler type-checks every public reader/writer signature.
+        boolean neverTaken = System.currentTimeMillis() < 0;
+        if (neverTaken) {
+            float[] query = {0.0f, 1.0f};
+            float[] batch = {0.0f, 1.0f, 2.0f, 3.0f};
+            byte[] filter = {1, 2};
+
+            IVFFlatReader reader = new IVFFlatReader(new Object());
+            reader.dimension();
+            reader.totalVectors();
+            reader.search(query, 10, 4);
+            reader.search(query, 10, 4, filter);
+            reader.searchBatch(batch, 2, 10, 4);
+            reader.searchBatch(batch, 2, 10, 4, filter);
+
+            IVFFlatWriter writer = new IVFFlatWriter(2, 4, Metric.L2);
+            writer.train(batch, 2);
+            writer.addVectors(new long[] {1L, 2L}, batch, 2);
+            writer.writeIndex(new Object());
+        }
+    }
+
private static void assertEquals(int expected, int actual) {
if (expected != actual) {
throw new AssertionError("expected " + expected + " but got " +
actual);
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IVFFlatNative.java
b/jni/java/org/apache/paimon/index/ivfpq/IVFFlatNative.java
new file mode 100644
index 0000000..aa20721
--- /dev/null
+++ b/jni/java/org/apache/paimon/index/ivfpq/IVFFlatNative.java
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.paimon.index.ivfpq;
+
+/**
+ * Raw JNI bindings for the IVF-FLAT index.
+ *
+ * <p>Each method binds to an exported
+ * {@code Java_org_apache_paimon_index_ivfpq_IVFFlatNative_*} symbol. Handles are opaque
+ * {@code long} pointers that {@link IVFFlatWriter} and {@link IVFFlatReader} own and
+ * release.
+ */
+final class IVFFlatNative {
+
+    /** Static bindings only; never instantiated. */
+    private IVFFlatNative() {}
+
+    // ---- Writer lifecycle ----
+
+    /** Allocates a native IVF-FLAT writer and returns its handle. */
+    static native long createWriter(int d, int nlist, int metric);
+
+    /** Trains the writer on the first {@code n} vectors of {@code data}. */
+    static native void train(long ptr, float[] data, int n);
+
+    /** Adds the first {@code n} vectors of {@code data} under the matching {@code ids}. */
+    static native void addVectors(long ptr, long[] ids, float[] data, int n);
+
+    /** Serializes the writer's index to {@code streamOutput}. */
+    static native void writeIndex(long ptr, Object streamOutput);
+
+    /** Releases a writer handle; {@code 0} is ignored. */
+    static native void freeWriter(long ptr);
+
+    // ---- Reader lifecycle and search ----
+
+    /** Opens a reader over {@code streamInput} and returns its handle. */
+    static native long openReader(Object streamInput);
+
+    /** Searches the top-{@code k} neighbours of one query, probing {@code nprobe} lists. */
+    static native IVFPQResult search(long ptr, float[] query, int k, int nprobe);
+
+    /** Like {@link #search}, restricted to ids in a serialized RoaringTreemap. */
+    static native IVFPQResult searchWithRoaringFilter(
+            long ptr, float[] query, int k, int nprobe, byte[] roaringFilter);
+
+    /** Searches {@code queryCount} row-major queries in one call. */
+    static native IVFPQBatchResult searchBatch(
+            long ptr, float[] queries, int queryCount, int k, int nprobe);
+
+    /** Batch variant of {@link #searchWithRoaringFilter}. */
+    static native IVFPQBatchResult searchBatchWithRoaringFilter(
+            long ptr,
+            float[] queries,
+            int queryCount,
+            int k,
+            int nprobe,
+            byte[] roaringFilter);
+
+    /** Returns the vector dimension of the opened index. */
+    static native int getDimension(long ptr);
+
+    /** Returns the number of vectors stored in the opened index. */
+    static native long getTotalVectors(long ptr);
+
+    /** Releases a reader handle; {@code 0} is ignored. */
+    static native void freeReader(long ptr);
+}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IVFFlatReader.java
b/jni/java/org/apache/paimon/index/ivfpq/IVFFlatReader.java
new file mode 100644
index 0000000..7c8ec57
--- /dev/null
+++ b/jni/java/org/apache/paimon/index/ivfpq/IVFFlatReader.java
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.paimon.index.ivfpq;
+
+public final class IVFFlatReader implements AutoCloseable {
+
+ private long nativePtr;
+
+ public IVFFlatReader(Object input) {
+ if (input == null) {
+ throw new NullPointerException("input");
+ }
+ this.nativePtr = IVFFlatNative.openReader(input);
+ }
+
+ private IVFFlatReader(long nativePtr) {
+ this.nativePtr = nativePtr;
+ }
+
+ static IVFFlatReader fromNativePointerForTesting(long nativePtr) {
+ return new IVFFlatReader(nativePtr);
+ }
+
+ public int dimension() {
+ return IVFFlatNative.getDimension(requireOpen());
+ }
+
+ public long totalVectors() {
+ return IVFFlatNative.getTotalVectors(requireOpen());
+ }
+
+ public IVFPQResult search(float[] query, int topK, int nprobe) {
+ if (query == null) {
+ throw new NullPointerException("query");
+ }
+ validatePositive(topK, "topK");
+ validatePositive(nprobe, "nprobe");
+ return IVFFlatNative.search(requireOpen(), query, topK, nprobe);
+ }
+
+ public IVFPQResult search(float[] query, int topK, int nprobe, byte[]
roaringFilter) {
+ if (query == null) {
+ throw new NullPointerException("query");
+ }
+ if (roaringFilter == null) {
+ throw new NullPointerException("roaringFilter");
+ }
+ validatePositive(topK, "topK");
+ validatePositive(nprobe, "nprobe");
+ return IVFFlatNative.searchWithRoaringFilter(
+ requireOpen(), query, topK, nprobe, roaringFilter);
+ }
+
+ public IVFPQBatchResult searchBatch(float[] queries, int queryCount, int
topK, int nprobe) {
+ if (queries == null) {
+ throw new NullPointerException("queries");
+ }
+ validatePositive(queryCount, "queryCount");
+ validatePositive(topK, "topK");
+ validatePositive(nprobe, "nprobe");
+ return IVFFlatNative.searchBatch(requireOpen(), queries, queryCount,
topK, nprobe);
+ }
+
+ public IVFPQBatchResult searchBatch(
+ float[] queries, int queryCount, int topK, int nprobe, byte[]
roaringFilter) {
+ if (queries == null) {
+ throw new NullPointerException("queries");
+ }
+ if (roaringFilter == null) {
+ throw new NullPointerException("roaringFilter");
+ }
+ validatePositive(queryCount, "queryCount");
+ validatePositive(topK, "topK");
+ validatePositive(nprobe, "nprobe");
+ return IVFFlatNative.searchBatchWithRoaringFilter(
+ requireOpen(), queries, queryCount, topK, nprobe,
roaringFilter);
+ }
+
+ @Override
+ public void close() {
+ long ptr = nativePtr;
+ nativePtr = 0L;
+ if (ptr != 0L) {
+ IVFFlatNative.freeReader(ptr);
+ }
+ }
+
+ private long requireOpen() {
+ if (nativePtr == 0L) {
+ throw new IllegalStateException("IVFFlatReader is closed");
+ }
+ return nativePtr;
+ }
+
+ private static void validatePositive(int value, String name) {
+ if (value <= 0) {
+ throw new IllegalArgumentException(name + " must be > 0");
+ }
+ }
+}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IVFFlatWriter.java
b/jni/java/org/apache/paimon/index/ivfpq/IVFFlatWriter.java
new file mode 100644
index 0000000..22f484d
--- /dev/null
+++ b/jni/java/org/apache/paimon/index/ivfpq/IVFFlatWriter.java
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.paimon.index.ivfpq;
+
+/**
+ * Java handle to a native IVF-FLAT index writer.
+ *
+ * <p>Inputs are validated in Java: array lengths are checked against
+ * {@code vectorCount * dimension} before native code is called. {@link #close()}
+ * releases the native writer and may be called more than once.
+ */
+public final class IVFFlatWriter implements AutoCloseable {
+
+    /** Vector dimension the writer was created with. */
+    private final int dimension;
+
+    /** Native writer handle, or {@code 0L} once closed. */
+    private long nativePtr;
+
+    /**
+     * Creates a native writer.
+     *
+     * @throws NullPointerException if {@code metric} is null
+     * @throws IllegalArgumentException if {@code dimension} or {@code nlist} is not positive
+     */
+    public IVFFlatWriter(int dimension, int nlist, Metric metric) {
+        this(createNative(dimension, nlist, metric), dimension);
+    }
+
+    private IVFFlatWriter(long nativePtr, int dimension) {
+        this.dimension = dimension;
+        this.nativePtr = nativePtr;
+    }
+
+    static IVFFlatWriter fromNativePointerForTesting(long nativePtr, int dimension) {
+        return new IVFFlatWriter(nativePtr, dimension);
+    }
+
+    /** Validates constructor arguments, in order, then allocates the native writer. */
+    private static long createNative(int dimension, int nlist, Metric metric) {
+        if (metric == null) {
+            throw new NullPointerException("metric");
+        }
+        checkPositive(dimension, "dimension");
+        checkPositive(nlist, "nlist");
+        return IVFFlatNative.createWriter(dimension, nlist, metric.code());
+    }
+
+    /** Returns the vector dimension. */
+    public int dimension() {
+        return dimension;
+    }
+
+    /** Trains on the first {@code vectorCount} row-major vectors in {@code data}. */
+    public void train(float[] data, int vectorCount) {
+        validateVectors(data, vectorCount);
+        long handle = requireOpen();
+        IVFFlatNative.train(handle, data, vectorCount);
+    }
+
+    /** Adds the first {@code vectorCount} vectors of {@code data} under the matching ids. */
+    public void addVectors(long[] ids, float[] data, int vectorCount) {
+        if (ids == null) {
+            throw new NullPointerException("ids");
+        }
+        validateVectors(data, vectorCount);
+        if (ids.length < vectorCount) {
+            throw new IllegalArgumentException(
+                    "ids length " + ids.length + " < vectorCount " + vectorCount);
+        }
+        long handle = requireOpen();
+        IVFFlatNative.addVectors(handle, ids, data, vectorCount);
+    }
+
+    /** Serializes the index to {@code output}, which is passed to native code unchanged. */
+    public void writeIndex(Object output) {
+        if (output == null) {
+            throw new NullPointerException("output");
+        }
+        long handle = requireOpen();
+        IVFFlatNative.writeIndex(handle, output);
+    }
+
+    /** Releases the native writer. Later calls are no-ops. */
+    @Override
+    public void close() {
+        long handle = nativePtr;
+        nativePtr = 0L;
+        if (handle == 0L) {
+            return;
+        }
+        IVFFlatNative.freeWriter(handle);
+    }
+
+    private void validateVectors(float[] data, int vectorCount) {
+        if (data == null) {
+            throw new NullPointerException("data");
+        }
+        checkPositive(vectorCount, "vectorCount");
+        // Multiply in long so an int overflow can be detected and rejected.
+        long required = (long) vectorCount * (long) dimension;
+        if (required > Integer.MAX_VALUE) {
+            throw new IllegalArgumentException("vectorCount * dimension overflows int");
+        }
+        if (data.length < required) {
+            throw new IllegalArgumentException(
+                    "data length " + data.length + " < vectorCount * dimension " + required);
+        }
+    }
+
+    private long requireOpen() {
+        if (nativePtr != 0L) {
+            return nativePtr;
+        }
+        throw new IllegalStateException("IVFFlatWriter is closed");
+    }
+
+    private static void checkPositive(int value, String name) {
+        if (value <= 0) {
+            throw new IllegalArgumentException(name + " must be > 0");
+        }
+    }
+}
diff --git a/jni/src/lib.rs b/jni/src/lib.rs
index e7d3fd5..ca0b079 100644
--- a/jni/src/lib.rs
+++ b/jni/src/lib.rs
@@ -22,6 +22,11 @@ use jni::sys::{jboolean, jint, jlong, jobject};
use jni::JNIEnv;
use paimon_vindex_core::distance::MetricType;
use paimon_vindex_core::io::{write_index, IVFPQIndexReader};
+use paimon_vindex_core::ivfflat::IVFFlatIndex;
+use paimon_vindex_core::ivfflat_io::{
+ search_batch_ivfflat_reader, search_batch_ivfflat_reader_roaring_filter,
write_ivfflat_index,
+ IVFFlatIndexReader,
+};
use paimon_vindex_core::ivfpq::{
search_batch_reader, search_batch_reader_roaring_filter, IVFPQIndex,
};
@@ -48,6 +53,22 @@ fn deref_reader(ptr: jlong) -> Option<&'static mut
IVFPQIndexReader<JniSeekableS
}
}
+/// Turns a Java-held writer handle back into the `IVFFlatIndex` it points to.
+///
+/// Returns `None` for the null handle `0`.
+fn deref_flat_writer(ptr: jlong) -> Option<&'static mut IVFFlatIndex> {
+    // SAFETY: non-zero handles are expected to come from `Box::into_raw` in
+    // `IVFFlatNative_createWriter` and remain valid until `freeWriter`. The Java wrapper
+    // zeroes its copy on close. `as_mut` maps a null pointer to `None`, which is the same
+    // as an explicit zero check.
+    unsafe { (ptr as *mut IVFFlatIndex).as_mut() }
+}
+
+/// Turns a Java-held reader handle back into the `IVFFlatIndexReader` it points to.
+///
+/// Returns `None` for the null handle `0`.
+fn deref_flat_reader(ptr: jlong) -> Option<&'static mut IVFFlatIndexReader<JniSeekableStream>> {
+    // SAFETY: non-zero handles are expected to come from `Box::into_raw` in
+    // `IVFFlatNative_openReader` and remain valid until `freeReader`. The Java wrapper
+    // zeroes its copy on close. `as_mut` maps a null pointer to `None`.
+    unsafe { (ptr as *mut IVFFlatIndexReader<JniSeekableStream>).as_mut() }
+}
+
fn read_byte_array(env: &mut JNIEnv, array: JByteArray) -> Result<Vec<u8>,
String> {
if array.as_raw().is_null() {
return Err("filter byte array is null".to_string());
@@ -592,3 +613,415 @@ pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_IVFPQNative_freeReader
}
}
}
+
+// --- IVF-FLAT Writer API ---
+
+/// JNI binding for `IVFFlatNative.createWriter`.
+///
+/// Validates `d` and `nlist`, resolves the metric code, and returns a heap-allocated
+/// `IVFFlatIndex` as an opaque handle. On invalid input it throws and returns `0`.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_createWriter(
+    mut env: JNIEnv,
+    _class: JClass,
+    d: jint,
+    nlist: jint,
+    metric: jint,
+) -> jlong {
+    if d <= 0 || nlist <= 0 {
+        let msg = format!("invalid parameters: d={}, nlist={}", d, nlist);
+        return throw_and_return(&mut env, &msg);
+    }
+    let Some(metric_type) = MetricType::from_code(metric as u32) else {
+        let msg = format!("Unknown metric: {}", metric);
+        return throw_and_return(&mut env, &msg);
+    };
+    let index = IVFFlatIndex::new(d as usize, nlist as usize, metric_type);
+    Box::into_raw(Box::new(index)) as jlong
+}
+
+/// JNI binding for `IVFFlatNative.train`.
+///
+/// Copies the first `n * d` floats of `data` out of the JVM and passes them to
+/// `IVFFlatIndex::train`. Throws on a freed handle, non-positive `n`, `n * d`
+/// overflow, or a too-short array.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_train(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    data: JFloatArray,
+    n: jint,
+) {
+    let index = match deref_flat_writer(ptr) {
+        Some(i) => i,
+        None => return throw_and_return(&mut env, "null native pointer (writer already freed?)"),
+    };
+    if n <= 0 {
+        return throw_and_return(&mut env, &format!("invalid n: {}", n));
+    }
+    let n = n as usize;
+    // On 32-bit targets `n * d` can overflow `usize`; reject it instead of wrapping.
+    let needed = match n.checked_mul(index.d) {
+        Some(v) => v,
+        None => {
+            return throw_and_return(
+                &mut env,
+                &format!("n*d overflows usize: n={}, d={}", n, index.d),
+            )
+        }
+    };
+    let len = match env.get_array_length(&data) {
+        Ok(l) => l as usize,
+        Err(e) => return throw_and_return(&mut env, &format!("get_array_length: {}", e)),
+    };
+    if len < needed {
+        return throw_and_return(
+            &mut env,
+            &format!("data array too short: {} < n*d={}", len, needed),
+        );
+    }
+    // Copy only the `n * d` floats that training uses. The Java array may be longer,
+    // and copying the tail would waste memory and JNI copy time.
+    let mut buf = vec![0.0f32; needed];
+    if let Err(e) = env.get_float_array_region(&data, 0, &mut buf) {
+        return throw_and_return(&mut env, &format!("get_float_array_region: {}", e));
+    }
+    index.train(&buf, n);
+}
+
+/// JNI binding for `IVFFlatNative.addVectors`.
+///
+/// Copies the first `n` ids and the first `n * d` floats out of the JVM and passes them
+/// to `IVFFlatIndex::add`. Throws on a freed handle, non-positive `n`, `n * d`
+/// overflow, or a too-short array.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_addVectors(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    ids: JLongArray,
+    data: JFloatArray,
+    n: jint,
+) {
+    let index = match deref_flat_writer(ptr) {
+        Some(i) => i,
+        None => return throw_and_return(&mut env, "null native pointer (writer already freed?)"),
+    };
+    if n <= 0 {
+        return throw_and_return(&mut env, &format!("invalid n: {}", n));
+    }
+    let n = n as usize;
+    // On 32-bit targets `n * d` can overflow `usize`; reject it instead of wrapping.
+    let needed = match n.checked_mul(index.d) {
+        Some(v) => v,
+        None => {
+            return throw_and_return(
+                &mut env,
+                &format!("n*d overflows usize: n={}, d={}", n, index.d),
+            )
+        }
+    };
+
+    let id_len = match env.get_array_length(&ids) {
+        Ok(l) => l as usize,
+        Err(e) => return throw_and_return(&mut env, &format!("get_array_length: {}", e)),
+    };
+    if id_len < n {
+        return throw_and_return(
+            &mut env,
+            &format!("ids array too short: {} < n={}", id_len, n),
+        );
+    }
+    let mut id_buf = vec![0i64; n];
+    if let Err(e) = env.get_long_array_region(&ids, 0, &mut id_buf) {
+        return throw_and_return(&mut env, &format!("get_long_array_region: {}", e));
+    }
+
+    let data_len = match env.get_array_length(&data) {
+        Ok(l) => l as usize,
+        Err(e) => return throw_and_return(&mut env, &format!("get_array_length: {}", e)),
+    };
+    if data_len < needed {
+        return throw_and_return(
+            &mut env,
+            &format!("data array too short: {} < n*d={}", data_len, needed),
+        );
+    }
+    // Copy only the `n * d` floats that belong to the `n` added vectors.
+    let mut data_buf = vec![0.0f32; needed];
+    if let Err(e) = env.get_float_array_region(&data, 0, &mut data_buf) {
+        return throw_and_return(&mut env, &format!("get_float_array_region: {}", e));
+    }
+    index.add(&data_buf, &id_buf, n);
+}
+
+/// JNI binding for `IVFFlatNative.writeIndex`.
+///
+/// Wraps `stream_output` in a `JniOutputStream` (backed by a global reference) and
+/// serializes the writer's index into it with `write_ivfflat_index`. Throws on any
+/// failure.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_writeIndex(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    stream_output: JObject,
+) {
+    let Some(index) = deref_flat_writer(ptr) else {
+        return throw_and_return(&mut env, "null native pointer (writer already freed?)");
+    };
+    let vm = match env.get_java_vm() {
+        Ok(vm) => vm,
+        Err(err) => return throw_and_return(&mut env, &format!("get_java_vm: {}", err)),
+    };
+    let output_ref = match env.new_global_ref(stream_output) {
+        Ok(global) => global,
+        Err(err) => return throw_and_return(&mut env, &format!("new_global_ref: {}", err)),
+    };
+    let mut sink = JniOutputStream::new(vm, output_ref);
+    if let Err(err) = write_ivfflat_index(index, &mut sink) {
+        throw_and_return::<()>(&mut env, &format!("write_ivfflat_index: {}", err));
+    }
+}
+
+/// JNI binding for `IVFFlatNative.freeWriter`: drops the boxed `IVFFlatIndex`.
+///
+/// A `0` handle is ignored.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_freeWriter(
+    _env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+) {
+    if ptr == 0 {
+        return;
+    }
+    // SAFETY: non-zero handles are expected to come from `Box::into_raw` in
+    // `createWriter`. The Java wrapper zeroes its handle before calling this, so each
+    // pointer is freed at most once.
+    let boxed = unsafe { Box::from_raw(ptr as *mut IVFFlatIndex) };
+    drop(boxed);
+}
+
+// --- IVF-FLAT Reader API ---
+
+/// JNI binding for `IVFFlatNative.openReader`.
+///
+/// Wraps `stream_input` in a `JniSeekableStream` (backed by a global reference) and opens
+/// an `IVFFlatIndexReader` over it. The boxed reader is returned as an opaque handle; on
+/// failure it throws and returns `0`.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_openReader(
+    mut env: JNIEnv,
+    _class: JClass,
+    stream_input: JObject,
+) -> jlong {
+    let opened = (|| -> Result<IVFFlatIndexReader<JniSeekableStream>, String> {
+        let vm = env
+            .get_java_vm()
+            .map_err(|e| format!("get_java_vm: {}", e))?;
+        let input_ref = env
+            .new_global_ref(stream_input)
+            .map_err(|e| format!("new_global_ref: {}", e))?;
+        IVFFlatIndexReader::open(JniSeekableStream::new(vm, input_ref))
+            .map_err(|e| format!("open: {}", e))
+    })();
+    match opened {
+        Ok(reader) => Box::into_raw(Box::new(reader)) as jlong,
+        Err(msg) => throw_and_return(&mut env, &msg),
+    }
+}
+
+/// JNI binding for `IVFFlatNative.search`.
+///
+/// Validates `k` and `nprobe` and requires `query.length == d`. It then runs
+/// `IVFFlatIndexReader::search` and wraps the ids and distances with `build_result`.
+/// Throws on any failure.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_search(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    query: JFloatArray,
+    k: jint,
+    nprobe: jint,
+) -> jobject {
+    let Some(reader) = deref_flat_reader(ptr) else {
+        return throw_and_return(&mut env, "null native pointer (reader already freed?)");
+    };
+    if k <= 0 || nprobe <= 0 {
+        let msg = format!("invalid parameters: k={}, nprobe={}", k, nprobe);
+        return throw_and_return(&mut env, &msg);
+    }
+    let dim = reader.d;
+    let query_len = match env.get_array_length(&query) {
+        Ok(len) => len as usize,
+        Err(err) => return throw_and_return(&mut env, &format!("get_array_length: {}", err)),
+    };
+    if query_len != dim {
+        let msg = format!("query array length {} != d={}", query_len, dim);
+        return throw_and_return(&mut env, &msg);
+    }
+    let mut query_vec = vec![0.0f32; dim];
+    if let Err(err) = env.get_float_array_region(&query, 0, &mut query_vec) {
+        return throw_and_return(&mut env, &format!("get_float_array_region: {}", err));
+    }
+    match reader.search(&query_vec, k as usize, nprobe as usize) {
+        Ok((ids, dists)) => build_result(&mut env, ids, dists),
+        Err(err) => throw_and_return(&mut env, &format!("search: {}", err)),
+    }
+}
+
+/// JNI binding for `IVFFlatNative.searchWithRoaringFilter`.
+///
+/// Same validation as `search`. It additionally copies a serialized `RoaringTreemap`
+/// filter from the Java `byte[]` and forwards it to
+/// `IVFFlatIndexReader::search_with_roaring_filter`. Throws on any failure.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_searchWithRoaringFilter(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    query: JFloatArray,
+    k: jint,
+    nprobe: jint,
+    roaring_filter: JByteArray,
+) -> jobject {
+    let Some(reader) = deref_flat_reader(ptr) else {
+        return throw_and_return(&mut env, "null native pointer (reader already freed?)");
+    };
+    if k <= 0 || nprobe <= 0 {
+        let msg = format!("invalid parameters: k={}, nprobe={}", k, nprobe);
+        return throw_and_return(&mut env, &msg);
+    }
+    let dim = reader.d;
+    let query_len = match env.get_array_length(&query) {
+        Ok(len) => len as usize,
+        Err(err) => return throw_and_return(&mut env, &format!("get_array_length: {}", err)),
+    };
+    if query_len != dim {
+        let msg = format!("query array length {} != d={}", query_len, dim);
+        return throw_and_return(&mut env, &msg);
+    }
+    let mut query_vec = vec![0.0f32; dim];
+    if let Err(err) = env.get_float_array_region(&query, 0, &mut query_vec) {
+        return throw_and_return(&mut env, &format!("get_float_array_region: {}", err));
+    }
+    let filter = match read_byte_array(&mut env, roaring_filter) {
+        Ok(bytes) => bytes,
+        Err(msg) => return throw_and_return(&mut env, &msg),
+    };
+    let outcome =
+        reader.search_with_roaring_filter(&query_vec, k as usize, nprobe as usize, &filter);
+    match outcome {
+        Ok((ids, dists)) => build_result(&mut env, ids, dists),
+        Err(err) => throw_and_return(&mut env, &format!("search_with_filter: {}", err)),
+    }
+}
+
+/// JNI entry point for `IVFFlatNative.searchBatch`.
+///
+/// `queries` carries `nq` query vectors back to back and must be exactly
+/// `nq * d` floats long. All queries are searched via `search_batch_ivfflat_reader`
+/// and the flattened results are packed by `build_batch_result`. Every failure
+/// (missing reader, non-positive parameters, length mismatch, JNI or search
+/// errors) is reported through `throw_and_return`.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_searchBatch(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    queries: JFloatArray,
+    nq: jint,
+    k: jint,
+    nprobe: jint,
+) -> jobject {
+    let reader = if let Some(found) = deref_flat_reader(ptr) {
+        found
+    } else {
+        return throw_and_return(&mut env, "null native pointer (reader already freed?)");
+    };
+    if nq <= 0 || k <= 0 || nprobe <= 0 {
+        let msg = format!("invalid parameters: nq={}, k={}, nprobe={}", nq, k, nprobe);
+        return throw_and_return(&mut env, &msg);
+    }
+
+    // All three were checked positive above, so the widening casts are lossless.
+    let (nq, k, probes) = (nq as usize, k as usize, nprobe as usize);
+    let dim = reader.d;
+    let actual_len = match env.get_array_length(&queries) {
+        Ok(len) => len as usize,
+        Err(err) => return throw_and_return(&mut env, &format!("get_array_length: {}", err)),
+    };
+    let expected_len = nq * dim;
+    if actual_len != expected_len {
+        let msg = format!("queries array length {} != nq*d={}", actual_len, expected_len);
+        return throw_and_return(&mut env, &msg);
+    }
+
+    let mut query_vec = vec![0.0f32; expected_len];
+    if let Err(err) = env.get_float_array_region(&queries, 0, &mut query_vec) {
+        return throw_and_return(&mut env, &format!("get_float_array_region: {}", err));
+    }
+
+    let outcome = search_batch_ivfflat_reader(reader, &query_vec, nq, k, probes);
+    match outcome {
+        Ok((all_ids, all_dists)) => build_batch_result(&mut env, all_ids, all_dists, nq, k),
+        Err(err) => throw_and_return(&mut env, &format!("search_batch: {}", err)),
+    }
+}
+
+/// JNI entry point for `IVFFlatNative.searchBatchWithRoaringFilter`.
+///
+/// Batch variant of the filtered search: `queries` carries `nq` query vectors
+/// back to back (exactly `nq * d` floats), and every query is restricted by the
+/// serialized roaring filter in `roaring_filter`. Results are packed by
+/// `build_batch_result`; every failure is reported through `throw_and_return`.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_searchBatchWithRoaringFilter(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    queries: JFloatArray,
+    nq: jint,
+    k: jint,
+    nprobe: jint,
+    roaring_filter: JByteArray,
+) -> jobject {
+    let reader = if let Some(found) = deref_flat_reader(ptr) {
+        found
+    } else {
+        return throw_and_return(&mut env, "null native pointer (reader already freed?)");
+    };
+    if nq <= 0 || k <= 0 || nprobe <= 0 {
+        let msg = format!("invalid parameters: nq={}, k={}, nprobe={}", nq, k, nprobe);
+        return throw_and_return(&mut env, &msg);
+    }
+
+    // All three were checked positive above, so the widening casts are lossless.
+    let (nq, k, probes) = (nq as usize, k as usize, nprobe as usize);
+    let dim = reader.d;
+    let actual_len = match env.get_array_length(&queries) {
+        Ok(len) => len as usize,
+        Err(err) => return throw_and_return(&mut env, &format!("get_array_length: {}", err)),
+    };
+    let expected_len = nq * dim;
+    if actual_len != expected_len {
+        let msg = format!("queries array length {} != nq*d={}", actual_len, expected_len);
+        return throw_and_return(&mut env, &msg);
+    }
+
+    let mut query_vec = vec![0.0f32; expected_len];
+    if let Err(err) = env.get_float_array_region(&queries, 0, &mut query_vec) {
+        return throw_and_return(&mut env, &format!("get_float_array_region: {}", err));
+    }
+
+    let filter_bytes = match read_byte_array(&mut env, roaring_filter) {
+        Ok(bytes) => bytes,
+        Err(msg) => return throw_and_return(&mut env, &msg),
+    };
+
+    let outcome = search_batch_ivfflat_reader_roaring_filter(
+        reader,
+        &query_vec,
+        nq,
+        k,
+        probes,
+        &filter_bytes,
+    );
+    match outcome {
+        Ok((all_ids, all_dists)) => build_batch_result(&mut env, all_ids, all_dists, nq, k),
+        Err(err) => throw_and_return(&mut env, &format!("search_batch_with_filter: {}", err)),
+    }
+}
+
+/// JNI entry point for `IVFFlatNative.getDimension`: the vector dimension `d`
+/// of the native reader behind `ptr`.
+///
+/// Reports an error through `throw_and_return` when `deref_flat_reader` yields no
+/// reader for `ptr`, or when `d` cannot be represented as a Java `int`.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_getDimension(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+) -> jint {
+    let reader = match deref_flat_reader(ptr) {
+        Some(r) => r,
+        None => return throw_and_return(&mut env, "null native pointer (reader already freed?)"),
+    };
+    let d = reader.d;
+    // A bare `usize as jint` silently wraps for values above `jint::MAX`; fail
+    // loudly instead of handing Java a negative or truncated dimension.
+    if d > jint::MAX as usize {
+        return throw_and_return(&mut env, &format!("dimension {} exceeds Java int range", d));
+    }
+    d as jint
+}
+
+/// JNI entry point for `IVFFlatNative.getTotalVectors`: reports the
+/// `total_vectors` field of the native reader behind `ptr`, or raises via
+/// `throw_and_return` when `deref_flat_reader` yields no reader.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_getTotalVectors(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+) -> jlong {
+    match deref_flat_reader(ptr) {
+        Some(reader) => reader.total_vectors,
+        None => throw_and_return(&mut env, "null native pointer (reader already freed?)"),
+    }
+}
+
+/// JNI entry point for `IVFFlatNative.freeReader`: drops the native
+/// `IVFFlatIndexReader<JniSeekableStream>` behind `ptr`. A zero handle is ignored.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFFlatNative_freeReader(
+    _env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+) {
+    if ptr == 0 {
+        return;
+    }
+    let raw = ptr as *mut IVFFlatIndexReader<JniSeekableStream>;
+    // SAFETY: assumes `ptr` was produced by `Box::into_raw` on an
+    // `IVFFlatIndexReader<JniSeekableStream>` and is released at most once; the
+    // Java caller must not reuse or double-free the handle (TODO confirm on the Java side).
+    unsafe {
+        drop(Box::from_raw(raw));
+    }
+}
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 694dd0b..9d8c492 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -22,6 +22,11 @@ use numpy::{
};
use paimon_vindex_core::distance::MetricType;
use paimon_vindex_core::io::{write_index, IVFPQIndexReader, SeekRead};
+use paimon_vindex_core::ivfflat::IVFFlatIndex;
+use paimon_vindex_core::ivfflat_io::{
+ search_batch_ivfflat_reader, search_batch_ivfflat_reader_roaring_filter,
write_ivfflat_index,
+ IVFFlatIndexReader,
+};
use paimon_vindex_core::ivfpq::{
search_batch_reader, search_batch_reader_roaring_filter, IVFPQIndex,
};
@@ -183,6 +188,17 @@ struct IVFPQWriter {
dimension: usize,
}
+/// Python-facing reader for a serialized IVF-FLAT index.
+///
+/// Wraps the core `IVFFlatIndexReader`, reading through a `PyFileStream` built
+/// from the Python file-like object passed to the constructor.
+#[pyclass]
+struct IVFFlatReader {
+ inner: IVFFlatIndexReader<PyFileStream>,
+}
+
+/// Python-facing builder that trains, populates and serializes an IVF-FLAT index.
+#[pyclass]
+struct IVFFlatWriter {
+ // The index being built; set to `None` by `close()`, after which every
+ // index operation fails with "IVFFlatWriter is closed".
+ index: Option<IVFFlatIndex>,
+ // Kept outside the `Option` so `dimension` stays readable after `close()`.
+ dimension: usize,
+}
+
#[pymethods]
impl IVFPQReader {
#[new]
@@ -417,8 +433,230 @@ impl IVFPQWriter {
}
}
+#[pymethods]
+impl IVFFlatReader {
+    /// Opens an IVF-FLAT index from the Python file-like object `file`.
+    ///
+    /// Any failure reported by `IVFFlatIndexReader::open` is raised as `IOError`.
+    #[new]
+    fn new(file: PyObject) -> PyResult<Self> {
+        IVFFlatIndexReader::open(PyFileStream { file })
+            .map(|inner| IVFFlatReader { inner })
+            .map_err(|e| PyIOError::new_err(format!("Failed to open IVF-FLAT index: {}", e)))
+    }
+
+    /// Vector dimension `d` of the opened index.
+    #[getter]
+    fn dimension(&self) -> usize {
+        self.inner.d
+    }
+
+    /// The `nlist` value (number of IVF lists) recorded in the opened index.
+    #[getter]
+    fn nlist(&self) -> usize {
+        self.inner.nlist
+    }
+
+    /// The `total_vectors` count recorded in the opened index.
+    #[getter]
+    fn total_vectors(&self) -> i64 {
+        self.inner.total_vectors
+    }
+
+    /// Finds the `top_k` nearest neighbours of one query vector, probing `nprobe` lists.
+    ///
+    /// When `filter_bytes` is given it is decoded with `decode_filter_bytes` and the
+    /// search is restricted by that roaring filter. Returns `(ids, distances)` as two
+    /// 1-D numpy arrays. Raises `ValueError` for a wrong query length or non-positive
+    /// `top_k`/`nprobe`, and `IOError` if the search itself fails.
+    #[allow(clippy::type_complexity)]
+    #[pyo3(signature = (query, top_k, nprobe, filter_bytes=None))]
+    fn search<'py>(
+        &mut self,
+        py: Python<'py>,
+        query: PyReadonlyArray1<f32>,
+        top_k: usize,
+        nprobe: usize,
+        filter_bytes: Option<&Bound<'_, PyAny>>,
+    ) -> PyResult<(Bound<'py, PyArray1<i64>>, Bound<'py, PyArray1<f32>>)> {
+        let vector = query.as_slice()?;
+        let dim = self.inner.d;
+        if vector.len() != dim {
+            return Err(PyValueError::new_err(format!(
+                "query length {} != index dimension {}",
+                vector.len(),
+                dim
+            )));
+        }
+        validate_positive(top_k, "top_k")?;
+        validate_positive(nprobe, "nprobe")?;
+
+        let filter = decode_filter_bytes(filter_bytes)?;
+        let (ids, dists) = match filter {
+            Some(bytes) => self
+                .inner
+                .search_with_roaring_filter(vector, top_k, nprobe, bytes)
+                .map_err(|e| PyIOError::new_err(format!("Search failed: {}", e)))?,
+            None => self
+                .inner
+                .search(vector, top_k, nprobe)
+                .map_err(|e| PyIOError::new_err(format!("Search failed: {}", e)))?,
+        };
+
+        let id_array = PyArray1::from_vec_bound(py, ids);
+        let dist_array = PyArray1::from_vec_bound(py, dists);
+        Ok((id_array, dist_array))
+    }
+
+    /// Batch form of `search`: `queries` is an `(n, d)` float32 matrix, and the
+    /// result is a pair of `(n, top_k)` matrices holding ids and distances.
+    ///
+    /// The matrix must be contiguous; the optional `filter_bytes` applies to every query.
+    #[allow(clippy::type_complexity)]
+    #[pyo3(signature = (queries, top_k, nprobe, filter_bytes=None))]
+    fn search_batch<'py>(
+        &mut self,
+        py: Python<'py>,
+        queries: PyReadonlyArray2<f32>,
+        top_k: usize,
+        nprobe: usize,
+        filter_bytes: Option<&Bound<'_, PyAny>>,
+    ) -> PyResult<(Bound<'py, PyArray2<i64>>, Bound<'py, PyArray2<f32>>)> {
+        let n_queries =
+            validate_matrix_shape(queries.shape(), self.inner.d, "query", "index dimension")?;
+        validate_positive(top_k, "top_k")?;
+        validate_positive(nprobe, "nprobe")?;
+
+        let flat = queries.as_slice().map_err(|_| {
+            PyValueError::new_err("queries must be a contiguous two-dimensional float32 array")
+        })?;
+
+        let filter = decode_filter_bytes(filter_bytes)?;
+        let (ids, dists) = match filter {
+            Some(bytes) => search_batch_ivfflat_reader_roaring_filter(
+                &mut self.inner,
+                flat,
+                n_queries,
+                top_k,
+                nprobe,
+                bytes,
+            )
+            .map_err(|e| PyIOError::new_err(format!("Batch search failed: {}", e)))?,
+            None => search_batch_ivfflat_reader(&mut self.inner, flat, n_queries, top_k, nprobe)
+                .map_err(|e| PyIOError::new_err(format!("Batch search failed: {}", e)))?,
+        };
+
+        let id_matrix = pyarray2_from_flat(py, ids, n_queries, top_k)?;
+        let dist_matrix = pyarray2_from_flat(py, dists, n_queries, top_k)?;
+        Ok((id_matrix, dist_matrix))
+    }
+
+    /// Currently a no-op: the wrapped file object is not closed, and later
+    /// searches are not prevented.
+    fn close(&mut self) -> PyResult<()> {
+        Ok(())
+    }
+
+    /// Context-manager entry; returns the reader itself.
+    fn __enter__(slf: Py<Self>) -> Py<Self> {
+        slf
+    }
+
+    /// Context-manager exit: calls `close()` and returns `False`, so any
+    /// exception raised inside the `with` block propagates.
+    #[pyo3(signature = (_exc_type=None, _exc_val=None, _exc_tb=None))]
+    fn __exit__(
+        &mut self,
+        _exc_type: Option<&Bound<'_, pyo3::types::PyType>>,
+        _exc_val: Option<&Bound<'_, pyo3::types::PyAny>>,
+        _exc_tb: Option<&Bound<'_, pyo3::types::PyAny>>,
+    ) -> PyResult<bool> {
+        self.close().map(|()| false)
+    }
+}
+
+#[pymethods]
+impl IVFFlatWriter {
+    /// Creates a writer for `dimension`-dimensional vectors split into `nlist` lists.
+    ///
+    /// `metric` (default `"l2"`) is parsed by `parse_metric`. Raises `ValueError`
+    /// when `dimension` or `nlist` is not positive.
+    #[new]
+    #[pyo3(signature = (dimension, nlist, metric="l2"))]
+    fn new(dimension: usize, nlist: usize, metric: &str) -> PyResult<Self> {
+        validate_positive(dimension, "dimension")?;
+        validate_positive(nlist, "nlist")?;
+        let metric_type = parse_metric(metric)?;
+        let index = IVFFlatIndex::new(dimension, nlist, metric_type);
+        Ok(IVFFlatWriter {
+            index: Some(index),
+            dimension,
+        })
+    }
+
+    /// Vector dimension this writer was created with (still readable after `close()`).
+    #[getter]
+    fn dimension(&self) -> usize {
+        self.dimension
+    }
+
+    /// Trains the underlying `IVFFlatIndex` on the rows of the `(n, dimension)` matrix `data`.
+    fn train(&mut self, data: PyReadonlyArray2<f32>) -> PyResult<()> {
+        let rows =
+            validate_matrix_shape(data.shape(), self.dimension, "data", "writer dimension")?;
+        let flat = data.as_slice().map_err(|_| {
+            PyValueError::new_err("data must be a contiguous two-dimensional float32 array")
+        })?;
+        let index = self.index_mut()?;
+        index.train(flat, rows);
+        Ok(())
+    }
+
+    /// Adds the rows of `data` to the index, pairing row `i` with `ids[i]`.
+    ///
+    /// Raises `ValueError` when the number of ids differs from the number of rows.
+    fn add_vectors(
+        &mut self,
+        ids: PyReadonlyArray1<i64>,
+        data: PyReadonlyArray2<f32>,
+    ) -> PyResult<()> {
+        let rows =
+            validate_matrix_shape(data.shape(), self.dimension, "data", "writer dimension")?;
+        let id_values = ids.as_slice()?;
+        if id_values.len() != rows {
+            return Err(PyValueError::new_err(format!(
+                "ids length {} != vector count {}",
+                id_values.len(),
+                rows
+            )));
+        }
+        let flat = data.as_slice().map_err(|_| {
+            PyValueError::new_err("data must be a contiguous two-dimensional float32 array")
+        })?;
+        let index = self.index_mut()?;
+        index.add(flat, id_values, rows);
+        Ok(())
+    }
+
+    /// Serializes the index into the Python file-like object `file`.
+    ///
+    /// Write failures are raised as `IOError`.
+    fn write(&mut self, file: PyObject) -> PyResult<()> {
+        let mut sink = PyOutputStream { file, pos: 0 };
+        let index = self.index_ref()?;
+        write_ivfflat_index(index, &mut sink)
+            .map(|_| ())
+            .map_err(|e| PyIOError::new_err(format!("Failed to write IVF-FLAT index: {}", e)))
+    }
+
+    /// Drops the in-memory index; any later index operation raises "IVFFlatWriter is closed".
+    fn close(&mut self) -> PyResult<()> {
+        self.index = None;
+        Ok(())
+    }
+
+    /// Context-manager entry; returns the writer itself.
+    fn __enter__(slf: Py<Self>) -> Py<Self> {
+        slf
+    }
+
+    /// Context-manager exit: calls `close()` and returns `False`, so any
+    /// exception raised inside the `with` block propagates.
+    #[pyo3(signature = (_exc_type=None, _exc_val=None, _exc_tb=None))]
+    fn __exit__(
+        &mut self,
+        _exc_type: Option<&Bound<'_, pyo3::types::PyType>>,
+        _exc_val: Option<&Bound<'_, pyo3::types::PyAny>>,
+        _exc_tb: Option<&Bound<'_, pyo3::types::PyAny>>,
+    ) -> PyResult<bool> {
+        self.close().map(|()| false)
+    }
+}
+
+impl IVFFlatWriter {
+    /// Borrows the live index, or raises `ValueError` once the writer has been closed.
+    fn index_ref(&self) -> PyResult<&IVFFlatIndex> {
+        match &self.index {
+            Some(index) => Ok(index),
+            None => Err(PyValueError::new_err("IVFFlatWriter is closed")),
+        }
+    }
+
+    /// Mutable counterpart of `index_ref`.
+    fn index_mut(&mut self) -> PyResult<&mut IVFFlatIndex> {
+        match &mut self.index {
+            Some(index) => Ok(index),
+            None => Err(PyValueError::new_err("IVFFlatWriter is closed")),
+        }
+    }
+}
+
#[pymodule]
fn paimon_vindex(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
+ m.add_class::<IVFFlatReader>()?;
+ m.add_class::<IVFFlatWriter>()?;
m.add_class::<IVFPQReader>()?;
m.add_class::<IVFPQWriter>()?;
Ok(())
@@ -570,6 +808,119 @@ mod tests {
});
}
+    #[test]
+    fn python_flat_writer_can_build_an_index_for_reader() {
+        // Round trip: train and populate a writer, serialize it into an in-memory
+        // BytesIO, reopen it with the reader and check a stored vector finds itself.
+        Python::with_gil(|py| {
+            let buffer = py
+                .import_bound("io")
+                .unwrap()
+                .getattr("BytesIO")
+                .unwrap()
+                .call0()
+                .unwrap();
+            let mut writer = IVFFlatWriter::new(16, 4, "l2").unwrap();
+            let data = generate_clustered_data(500, 16, 4);
+
+            let rows: Vec<Vec<_>> = data.chunks(16).map(|row| row.to_vec()).collect();
+            let vectors = PyArray::from_vec2_bound(py, &rows).unwrap();
+            let ids = PyArray1::from_vec_bound(py, (0..500).collect::<Vec<i64>>());
+
+            writer.train(vectors.readonly()).unwrap();
+            writer.add_vectors(ids.readonly(), vectors.readonly()).unwrap();
+            writer.write(buffer.as_any().clone().unbind()).unwrap();
+
+            buffer.call_method1("seek", (0,)).unwrap();
+            let mut reader = IVFFlatReader::new(buffer.unbind()).unwrap();
+            let probe = PyArray1::from_vec_bound(py, data[..16].to_vec());
+            let (found, _) = reader.search(py, probe.readonly(), 5, 2, None).unwrap();
+
+            assert_eq!(found.len(), 5);
+            let found = found.readonly();
+            assert_eq!(found.as_slice().unwrap()[0], 0);
+        });
+    }
+
+    #[test]
+    fn python_flat_reader_accepts_roaring_filter_bytes() {
+        // Three 2-D points with ids 10, 11, 12; the filter admits only id 12, so a
+        // top-2 query must return 12 followed by one padding slot (-1 / f32::MAX).
+        Python::with_gil(|py| {
+            let buffer = py
+                .import_bound("io")
+                .unwrap()
+                .getattr("BytesIO")
+                .unwrap()
+                .call0()
+                .unwrap();
+            let mut writer = IVFFlatWriter::new(2, 1, "l2").unwrap();
+            let points = PyArray::from_vec2_bound(
+                py,
+                &[vec![0.0f32, 0.0], vec![0.1, 0.0], vec![10.0, 10.0]],
+            )
+            .unwrap();
+            let ids = PyArray1::from_vec_bound(py, vec![10i64, 11, 12]);
+
+            writer.train(points.readonly()).unwrap();
+            writer.add_vectors(ids.readonly(), points.readonly()).unwrap();
+            writer.write(buffer.as_any().clone().unbind()).unwrap();
+
+            let mut allowed = RoaringTreemap::new();
+            allowed.insert(12);
+            let mut serialized = Vec::new();
+            allowed.serialize_into(&mut serialized).unwrap();
+            let filter = PyBytes::new_bound(py, &serialized);
+
+            buffer.call_method1("seek", (0,)).unwrap();
+            let mut reader = IVFFlatReader::new(buffer.unbind()).unwrap();
+            let query = PyArray1::from_vec_bound(py, vec![0.0f32, 0.0]);
+            let (ids_out, dists_out) = reader
+                .search(py, query.readonly(), 2, 1, Some(filter.as_any()))
+                .unwrap();
+
+            assert_eq!(ids_out.readonly().as_slice().unwrap(), &[12, -1]);
+            assert_eq!(dists_out.readonly().as_slice().unwrap()[1], f32::MAX);
+        });
+    }
+
+    #[test]
+    fn python_flat_batch_search_accepts_roaring_filter_bytes() {
+        // Two queries against three points where only id 12 passes the filter: each
+        // row of the (2, 2) result must be [12, -1], and BOTH padding slots (flat
+        // positions 1 and 3) must carry the f32::MAX sentinel distance.
+        Python::with_gil(|py| {
+            let io = py.import_bound("io").unwrap();
+            let output = io.getattr("BytesIO").unwrap().call0().unwrap();
+            let mut writer = IVFFlatWriter::new(2, 1, "l2").unwrap();
+            let train = PyArray::from_vec2_bound(
+                py,
+                &[vec![0.0f32, 0.0], vec![0.1, 0.0], vec![10.0, 10.0]],
+            )
+            .unwrap();
+            let id_array = PyArray1::from_vec_bound(py, vec![10i64, 11, 12]);
+
+            writer.train(train.readonly()).unwrap();
+            writer
+                .add_vectors(id_array.readonly(), train.readonly())
+                .unwrap();
+            writer.write(output.as_any().clone().unbind()).unwrap();
+
+            let mut allowed = RoaringTreemap::new();
+            allowed.insert(12);
+            let mut filter_bytes = Vec::new();
+            allowed.serialize_into(&mut filter_bytes).unwrap();
+            let filter = PyBytes::new_bound(py, &filter_bytes);
+
+            output.call_method1("seek", (0,)).unwrap();
+            let mut reader = IVFFlatReader::new(output.unbind()).unwrap();
+            let queries =
+                PyArray::from_vec2_bound(py, &[vec![0.0f32, 0.0], vec![10.0, 10.0]]).unwrap();
+            let (result_ids, result_dists) = reader
+                .search_batch(py, queries.readonly(), 2, 1, Some(filter.as_any()))
+                .unwrap();
+
+            assert_eq!(result_ids.shape(), &[2, 2]);
+            assert_eq!(result_dists.shape(), &[2, 2]);
+            assert_eq!(
+                result_ids.readonly().as_slice().unwrap(),
+                &[12, -1, 12, -1]
+            );
+            let dists = result_dists.readonly();
+            let dists = dists.as_slice().unwrap();
+            // Padding slot of the first query AND of the second query.
+            assert_eq!(dists[1], f32::MAX);
+            assert_eq!(dists[3], f32::MAX);
+        });
+    }
+
#[test]
fn python_batch_search_validates_query_shape() {
Python::with_gil(|py| {