This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git


The following commit(s) were added to refs/heads/main by this push:
     new eddf3d4  Add Roaring bitmap filter pushdown (#15)
eddf3d4 is described below

commit eddf3d4a8addfce1ade2b200e8ec451fdb47164b
Author: Jingsong Lee <[email protected]>
AuthorDate: Mon Jun 8 20:25:54 2026 +0800

    Add Roaring bitmap filter pushdown (#15)
---
 Cargo.lock        |  17 ++++
 README.md         |  12 +++
 core/Cargo.toml   |   1 +
 core/src/io.rs    |  17 ++++
 core/src/ivfpq.rs | 264 +++++++++++++++++++++++++++++++++++++++++++++++++++---
 jni/src/lib.rs    | 229 +++++++++++++++++++++++++++++++++++++---------
 python/Cargo.lock |  17 ++++
 python/src/lib.rs |  30 ++++---
 8 files changed, 520 insertions(+), 67 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 0248744..d6560e3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -50,6 +50,12 @@ version = "1.25.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec"
 
+[[package]]
+name = "byteorder"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
+
 [[package]]
 name = "bytes"
 version = "1.11.1"
@@ -467,6 +473,7 @@ dependencies = [
  "nalgebra",
  "rand",
  "rayon",
+ "roaring",
 ]
 
 [[package]]
@@ -629,6 +636,16 @@ version = "0.8.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
 
+[[package]]
+name = "roaring"
+version = "0.11.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1dedc5658c6ecb3bdb5ef5f3295bb9253f42dcf3fd1402c03f6b1f7659c3c4a9"
+dependencies = [
+ "bytemuck",
+ "byteorder",
+]
+
 [[package]]
 name = "rustversion"
 version = "1.0.22"
diff --git a/README.md b/README.md
index 9c7b47a..2826912 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,18 @@
 
 Pure Rust IVF-PQ implementation for Apache Paimon. Designed for data lake (S3/HDFS/OSS) with seek-based I/O, supporting both 8-bit and 4-bit PQ with SIMD acceleration.
 
+## Metadata Filter Pushdown
+
+The vector index accepts a serialized 64-bit Roaring bitmap of allowed row IDs during reader search. This lets the Paimon query layer evaluate metadata predicates with table/scalar indexes first, then pass the matching row-id set into IVF-PQ as an ANN prefilter.
+
+Bindings expose the same wire format:
+
+- Rust core: `search_with_reader_roaring_filter` and `search_batch_reader_roaring_filter`
+- JNI: `searchWithRoaringFilter` and `searchBatchWithRoaringFilter` with `byte[]`
+- Python: `IVFPQReader.search(..., filter_bytes=...)`
+
+Row IDs must be non-negative to map directly into `RoaringTreemap`'s `u64` domain.
+
 ## Contributing
 
 Apache Paimon Vector Index is an exciting project currently under active development. Whether you're looking to use it in your projects or contribute to its growth, there are several ways you can get involved:
diff --git a/core/Cargo.toml b/core/Cargo.toml
index 120be93..0b51ead 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -26,6 +26,7 @@ nalgebra = "0.33"
 rand = "0.8"
 rayon = "1.10"
 matrixmultiply = "0.3"
+roaring = "0.11"
 
 [dev-dependencies]
 criterion = "0.5"
diff --git a/core/src/io.rs b/core/src/io.rs
index bf1997d..dc8d8bc 100644
--- a/core/src/io.rs
+++ b/core/src/io.rs
@@ -799,6 +799,23 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
         crate::ivfpq::search_with_reader(self, query, k, nprobe)
     }
 
+    pub fn search_with_roaring_filter(
+        &mut self,
+        query: &[f32],
+        k: usize,
+        nprobe: usize,
+        roaring_filter_bytes: &[u8],
+    ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        self.ensure_loaded()?;
+        crate::ivfpq::search_with_reader_roaring_filter(
+            self,
+            query,
+            k,
+            nprobe,
+            roaring_filter_bytes,
+        )
+    }
+
     pub fn supports_concurrent_pread(&self) -> bool {
         self.reader.supports_concurrent_pread()
     }
diff --git a/core/src/ivfpq.rs b/core/src/ivfpq.rs
index e9c0974..61c5363 100644
--- a/core/src/ivfpq.rs
+++ b/core/src/ivfpq.rs
@@ -23,9 +23,35 @@ use crate::kmeans::{self, KMeansConfig};
 use crate::opq::OPQMatrix;
 use crate::pq::ProductQuantizer;
 use rayon::prelude::*;
+use roaring::RoaringTreemap;
 use std::collections::{HashMap, HashSet};
 use std::io;
 
+pub trait RowIdFilter: Sync {
+    fn contains(&self, id: i64) -> bool;
+}
+
+impl RowIdFilter for HashSet<i64> {
+    fn contains(&self, id: i64) -> bool {
+        HashSet::contains(self, &id)
+    }
+}
+
+impl RowIdFilter for RoaringTreemap {
+    fn contains(&self, id: i64) -> bool {
+        id >= 0 && RoaringTreemap::contains(self, id as u64)
+    }
+}
+
+fn decode_roaring_filter(bytes: &[u8]) -> io::Result<RoaringTreemap> {
+    RoaringTreemap::deserialize_from(bytes).map_err(|e| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!("invalid RoaringTreemap filter: {}", e),
+        )
+    })
+}
+
 /// IVF-PQ index aligned with Faiss's IndexIVFPQ.
 pub struct IVFPQIndex {
     pub d: usize,
@@ -330,7 +356,7 @@ impl IVFPQIndex {
         nq: usize,
         k: usize,
         nprobe: usize,
-        filter: Option<&HashSet<i64>>,
+        filter: Option<&dyn RowIdFilter>,
         result_distances: &mut [f32],
         result_labels: &mut [i64],
     ) {
@@ -407,7 +433,7 @@ impl IVFPQIndex {
                         );
                         for i in 0..count {
                             if let Some(f) = filter {
-                                if !f.contains(&self.ids[list_id][i]) {
+                                if !f.contains(self.ids[list_id][i]) {
                                     continue;
                                 }
                             }
@@ -655,7 +681,7 @@ fn scan_codes_4bit(
     m: usize,
     _ksub: usize,
     dis0: f32,
-    filter: Option<&HashSet<i64>>,
+    filter: Option<&dyn RowIdFilter>,
     heap: &mut TopKHeap,
 ) {
     let mut dists = vec![0.0f32; count];
@@ -663,7 +689,7 @@ fn scan_codes_4bit(
 
     for i in 0..count {
         if let Some(f) = filter {
-            if !f.contains(&ids[i]) {
+            if !f.contains(ids[i]) {
                 continue;
             }
         }
@@ -680,7 +706,7 @@ fn scan_codes_4bit_transposed(
     count: usize,
     m: usize,
     dis0: f32,
-    filter: Option<&HashSet<i64>>,
+    filter: Option<&dyn RowIdFilter>,
     heap: &mut TopKHeap,
 ) {
     let cs = m / 2;
@@ -736,7 +762,7 @@ fn scan_codes_4bit_transposed(
 
     for i in 0..count {
         if let Some(f) = filter {
-            if !f.contains(&ids[i]) {
+            if !f.contains(ids[i]) {
                 continue;
             }
         }
@@ -754,7 +780,7 @@ fn scan_codes_transposed(
     m: usize,
     ksub: usize,
     dis0: f32,
-    filter: Option<&HashSet<i64>>,
+    filter: Option<&dyn RowIdFilter>,
     heap: &mut TopKHeap,
 ) {
     let mut dists = vec![dis0; count];
@@ -768,7 +794,7 @@ fn scan_codes_transposed(
 
     for i in 0..count {
         if let Some(f) = filter {
-            if !f.contains(&ids[i]) {
+            if !f.contains(ids[i]) {
                 continue;
             }
         }
@@ -785,7 +811,7 @@ fn scan_codes_batched(
     m: usize,
     ksub: usize,
     dis0: f32,
-    filter: Option<&HashSet<i64>>,
+    filter: Option<&dyn RowIdFilter>,
     heap: &mut TopKHeap,
 ) {
     let mut i = 0;
@@ -803,7 +829,7 @@ fn scan_codes_batched(
             let idx = i + j;
             let id = ids[idx];
             if let Some(f) = filter {
-                if !f.contains(&id) {
+                if !f.contains(id) {
                     continue;
                 }
             }
@@ -817,7 +843,7 @@ fn scan_codes_batched(
         let dist = dis0 + pq_distance_from_table(sim_table, code, m, ksub);
         let id = ids[i];
         if let Some(f) = filter {
-            if !f.contains(&id) {
+            if !f.contains(id) {
                 i += 1;
                 continue;
             }
@@ -839,7 +865,7 @@ struct ReaderSearchContext<'a> {
     q: &'a [f32],
     ip_table: &'a [f32],
     use_precomputed: bool,
-    filter: Option<&'a HashSet<i64>>,
+    filter: Option<&'a dyn RowIdFilter>,
     d: usize,
     m: usize,
     ksub: usize,
@@ -867,7 +893,7 @@ pub fn search_with_reader_filter<R: SeekRead>(
     query: &[f32],
     k: usize,
     nprobe: usize,
-    filter: Option<&HashSet<i64>>,
+    filter: Option<&dyn RowIdFilter>,
 ) -> io::Result<(Vec<i64>, Vec<f32>)> {
     reader.ensure_loaded()?;
     let d = reader.d;
@@ -1022,6 +1048,18 @@ pub fn search_with_reader_filter<R: SeekRead>(
     Ok((result_ids, result_dists))
 }
 
+/// Search with a cross-language serialized RoaringTreemap row-id filter.
+pub fn search_with_reader_roaring_filter<R: SeekRead>(
+    reader: &mut IVFPQIndexReader<R>,
+    query: &[f32],
+    k: usize,
+    nprobe: usize,
+    roaring_filter_bytes: &[u8],
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+    let filter = decode_roaring_filter(roaring_filter_bytes)?;
+    search_with_reader_filter(reader, query, k, nprobe, Some(&filter))
+}
+
 fn scan_reader_list(entry: &PreReadList, ctx: &ReaderSearchContext<'_>, heap: 
&mut TopKHeap) {
     let d = ctx.d;
     let m = ctx.m;
@@ -1107,6 +1145,18 @@ pub fn search_batch_reader<R: SeekRead>(
     nq: usize,
     k: usize,
     nprobe: usize,
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+    search_batch_reader_filter(reader, queries, nq, k, nprobe, None)
+}
+
+/// Big batch search with an optional row-id filter.
+pub fn search_batch_reader_filter<R: SeekRead>(
+    reader: &mut IVFPQIndexReader<R>,
+    queries: &[f32],
+    nq: usize,
+    k: usize,
+    nprobe: usize,
+    filter: Option<&dyn RowIdFilter>,
 ) -> io::Result<(Vec<i64>, Vec<f32>)> {
     reader.ensure_loaded()?;
     let d = reader.d;
@@ -1234,7 +1284,7 @@ pub fn search_batch_reader<R: SeekRead>(
                     &[]
                 },
                 use_precomputed,
-                filter: None,
+                filter,
                 d,
                 m,
                 ksub,
@@ -1265,6 +1315,19 @@ pub fn search_batch_reader<R: SeekRead>(
     Ok((result_ids, result_dists))
 }
 
+/// Big batch search with a cross-language serialized RoaringTreemap row-id filter.
+pub fn search_batch_reader_roaring_filter<R: SeekRead>(
+    reader: &mut IVFPQIndexReader<R>,
+    queries: &[f32],
+    nq: usize,
+    k: usize,
+    nprobe: usize,
+    roaring_filter_bytes: &[u8],
+) -> io::Result<(Vec<i64>, Vec<f32>)> {
+    let filter = decode_roaring_filter(roaring_filter_bytes)?;
+    search_batch_reader_filter(reader, queries, nq, k, nprobe, Some(&filter))
+}
+
 // --- Top-K Heap ---
 
 struct TopKHeap {
@@ -1987,6 +2050,72 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_reader_search_with_roaring_filter_bytes() {
+        use crate::io::{write_index, IVFPQIndexReader, PosWriter};
+        use roaring::RoaringTreemap;
+
+        let d = 16;
+        let nlist = 4;
+        let m = 4;
+        let n = 500;
+        let k = 5;
+
+        let data = generate_clustered_data(n, d, 4, 789);
+        let ids: Vec<i64> = (0..n as i64).collect();
+
+        let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_index(&index, &mut writer).unwrap();
+
+        let mut allowed = RoaringTreemap::new();
+        for id in (0..n as u64).filter(|id| id % 5 == 0) {
+            allowed.insert(id);
+        }
+        let mut filter_bytes = Vec::new();
+        allowed.serialize_into(&mut filter_bytes).unwrap();
+
+        let mut reader = IVFPQIndexReader::open(Cursor::new(buf)).unwrap();
+        let (result_ids, _) =
+            search_with_reader_roaring_filter(&mut reader, &data[0..d], k, 4, 
&filter_bytes)
+                .unwrap();
+
+        for &id in &result_ids {
+            assert_eq!(id % 5, 0, "Roaring filter violated: got ID {}", id);
+        }
+    }
+
+    #[test]
+    fn test_reader_search_rejects_invalid_roaring_filter_bytes() {
+        use crate::io::{write_index, IVFPQIndexReader, PosWriter};
+
+        let d = 16;
+        let nlist = 4;
+        let m = 4;
+        let n = 500;
+
+        let data = generate_clustered_data(n, d, 4, 789);
+        let ids: Vec<i64> = (0..n as i64).collect();
+
+        let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_index(&index, &mut writer).unwrap();
+
+        let mut reader = IVFPQIndexReader::open(Cursor::new(buf)).unwrap();
+        let err = search_with_reader_roaring_filter(&mut reader, &data[0..d], 
5, 4, b"not roaring")
+            .unwrap_err();
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+    }
+
     #[test]
     fn test_big_batch_search() {
         use crate::io::{write_index, IVFPQIndexReader, PosWriter};
@@ -2069,6 +2198,113 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_batch_reader_search_with_roaring_filter_bytes() {
+        use crate::io::{write_index, IVFPQIndexReader, PosWriter};
+        use roaring::RoaringTreemap;
+        use std::io::Cursor;
+
+        let d = 16;
+        let nlist = 8;
+        let m = 4;
+        let n = 1000;
+        let k = 5;
+        let nq = 12;
+        let nprobe = 3;
+
+        let data = generate_clustered_data(n, d, 8, 42);
+        let ids: Vec<i64> = (0..n as i64).collect();
+
+        let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_index(&index, &mut writer).unwrap();
+
+        let mut allowed = RoaringTreemap::new();
+        for id in (0..n as u64).filter(|id| id % 7 == 0) {
+            allowed.insert(id);
+        }
+        let mut filter_bytes = Vec::new();
+        allowed.serialize_into(&mut filter_bytes).unwrap();
+
+        let queries = &data[..nq * d];
+        let mut batch_reader = 
IVFPQIndexReader::open(Cursor::new(buf.clone())).unwrap();
+        let (batch_ids, batch_dists) = search_batch_reader_roaring_filter(
+            &mut batch_reader,
+            queries,
+            nq,
+            k,
+            nprobe,
+            &filter_bytes,
+        )
+        .unwrap();
+
+        for qi in 0..nq {
+            let base = qi * k;
+            for &id in &batch_ids[base..base + k] {
+                if id >= 0 {
+                    assert_eq!(id % 7, 0, "Roaring filter violated: got ID 
{}", id);
+                }
+            }
+
+            let mut single_reader = 
IVFPQIndexReader::open(Cursor::new(buf.clone())).unwrap();
+            let query = &queries[qi * d..(qi + 1) * d];
+            let (single_ids, single_dists) = search_with_reader_roaring_filter(
+                &mut single_reader,
+                query,
+                k,
+                nprobe,
+                &filter_bytes,
+            )
+            .unwrap();
+
+            assert_eq!(&batch_ids[base..base + k], &single_ids[..]);
+            assert_eq!(&batch_dists[base..base + k], &single_dists[..]);
+        }
+    }
+
+    #[test]
+    fn test_batch_reader_empty_roaring_filter_returns_empty_results() {
+        use crate::io::{write_index, IVFPQIndexReader, PosWriter};
+        use roaring::RoaringTreemap;
+        use std::io::Cursor;
+
+        let d = 16;
+        let nlist = 4;
+        let m = 4;
+        let n = 500;
+        let k = 5;
+        let nq = 4;
+        let nprobe = 2;
+
+        let data = generate_clustered_data(n, d, 4, 42);
+        let ids: Vec<i64> = (0..n as i64).collect();
+
+        let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_index(&index, &mut writer).unwrap();
+
+        let empty = RoaringTreemap::new();
+        let mut filter_bytes = Vec::new();
+        empty.serialize_into(&mut filter_bytes).unwrap();
+
+        let queries = &data[..nq * d];
+        let mut reader = IVFPQIndexReader::open(Cursor::new(buf)).unwrap();
+        let (batch_ids, batch_dists) =
+            search_batch_reader_roaring_filter(&mut reader, queries, nq, k, 
nprobe, &filter_bytes)
+                .unwrap();
+
+        assert!(batch_ids.iter().all(|&id| id == -1));
+        assert!(batch_dists.iter().all(|&dist| dist == f32::MAX));
+    }
+
     #[test]
     fn test_batch_reader_validates_inputs() {
         use crate::io::{write_index, IVFPQIndexReader, PosWriter};
diff --git a/jni/src/lib.rs b/jni/src/lib.rs
index 6dd5190..e7d3fd5 100644
--- a/jni/src/lib.rs
+++ b/jni/src/lib.rs
@@ -17,12 +17,14 @@
 
 mod stream;
 
-use jni::objects::{JClass, JFloatArray, JLongArray, JObject, JValue};
+use jni::objects::{JByteArray, JClass, JFloatArray, JLongArray, JObject, 
JValue};
 use jni::sys::{jboolean, jint, jlong, jobject};
 use jni::JNIEnv;
 use paimon_vindex_core::distance::MetricType;
 use paimon_vindex_core::io::{write_index, IVFPQIndexReader};
-use paimon_vindex_core::ivfpq::{search_batch_reader, IVFPQIndex};
+use paimon_vindex_core::ivfpq::{
+    search_batch_reader, search_batch_reader_roaring_filter, IVFPQIndex,
+};
 use stream::{JniOutputStream, JniSeekableStream};
 
 fn throw_and_return<T: Default>(env: &mut JNIEnv, msg: &str) -> T {
@@ -46,6 +48,86 @@ fn deref_reader(ptr: jlong) -> Option<&'static mut 
IVFPQIndexReader<JniSeekableS
     }
 }
 
+fn read_byte_array(env: &mut JNIEnv, array: JByteArray) -> Result<Vec<u8>, 
String> {
+    if array.as_raw().is_null() {
+        return Err("filter byte array is null".to_string());
+    }
+
+    env.convert_byte_array(array)
+        .map_err(|e| format!("convert_byte_array: {}", e))
+}
+
+fn build_result(env: &mut JNIEnv, ids: Vec<i64>, dists: Vec<f32>) -> jobject {
+    let id_array = match env.new_long_array(ids.len() as i32) {
+        Ok(a) => a,
+        Err(e) => return throw_and_return(env, &format!("new_long_array: {}", 
e)),
+    };
+    let _ = env.set_long_array_region(&id_array, 0, &ids);
+
+    let dist_array = match env.new_float_array(dists.len() as i32) {
+        Ok(a) => a,
+        Err(e) => return throw_and_return(env, &format!("new_float_array: {}", 
e)),
+    };
+    let _ = env.set_float_array_region(&dist_array, 0, &dists);
+
+    let result_class = match 
env.find_class("org/apache/paimon/index/ivfpq/IVFPQResult") {
+        Ok(c) => c,
+        Err(e) => return throw_and_return(env, &format!("find_class: {}", e)),
+    };
+
+    let result = match env.new_object(
+        result_class,
+        "([J[F)V",
+        &[JValue::Object(&id_array), JValue::Object(&dist_array)],
+    ) {
+        Ok(r) => r,
+        Err(e) => return throw_and_return(env, &format!("new_object: {}", e)),
+    };
+
+    result.into_raw()
+}
+
+fn build_batch_result(
+    env: &mut JNIEnv,
+    ids: Vec<i64>,
+    dists: Vec<f32>,
+    nq: usize,
+    k: usize,
+) -> jobject {
+    let id_array = match env.new_long_array((nq * k) as i32) {
+        Ok(a) => a,
+        Err(e) => return throw_and_return(env, &format!("new_long_array: {}", 
e)),
+    };
+    let _ = env.set_long_array_region(&id_array, 0, &ids);
+
+    let dist_array = match env.new_float_array((nq * k) as i32) {
+        Ok(a) => a,
+        Err(e) => return throw_and_return(env, &format!("new_float_array: {}", 
e)),
+    };
+    let _ = env.set_float_array_region(&dist_array, 0, &dists);
+
+    let result_class = match 
env.find_class("org/apache/paimon/index/ivfpq/IVFPQBatchResult") {
+        Ok(c) => c,
+        Err(e) => return throw_and_return(env, &format!("find_class: {}", e)),
+    };
+
+    let result = match env.new_object(
+        result_class,
+        "([J[FII)V",
+        &[
+            JValue::Object(&id_array),
+            JValue::Object(&dist_array),
+            JValue::Int(nq as jint),
+            JValue::Int(k as jint),
+        ],
+    ) {
+        Ok(r) => r,
+        Err(e) => return throw_and_return(env, &format!("new_object: {}", e)),
+    };
+
+    result.into_raw()
+}
+
 // --- Writer API ---
 
 #[no_mangle]
@@ -292,33 +374,64 @@ pub extern "system" fn 
Java_org_apache_paimon_index_ivfpq_IVFPQNative_search(
         Err(e) => return throw_and_return(&mut env, &format!("search: {}", e)),
     };
 
-    let id_array = match env.new_long_array(ids.len() as i32) {
-        Ok(a) => a,
-        Err(e) => return throw_and_return(&mut env, &format!("new_long_array: 
{}", e)),
+    build_result(&mut env, ids, dists)
+}
+
+#[no_mangle]
+pub extern "system" fn 
Java_org_apache_paimon_index_ivfpq_IVFPQNative_searchWithRoaringFilter(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    query: JFloatArray,
+    k: jint,
+    nprobe: jint,
+    roaring_filter: JByteArray,
+) -> jobject {
+    let reader = match deref_reader(ptr) {
+        Some(r) => r,
+        None => return throw_and_return(&mut env, "null native pointer (reader 
already freed?)"),
     };
-    let _ = env.set_long_array_region(&id_array, 0, &ids);
 
-    let dist_array = match env.new_float_array(dists.len() as i32) {
-        Ok(a) => a,
-        Err(e) => return throw_and_return(&mut env, &format!("new_float_array: 
{}", e)),
+    if k <= 0 || nprobe <= 0 {
+        return throw_and_return(
+            &mut env,
+            &format!("invalid parameters: k={}, nprobe={}", k, nprobe),
+        );
+    }
+
+    let d = reader.d;
+    let query_len = match env.get_array_length(&query) {
+        Ok(l) => l as usize,
+        Err(e) => return throw_and_return(&mut env, 
&format!("get_array_length: {}", e)),
     };
-    let _ = env.set_float_array_region(&dist_array, 0, &dists);
+    if query_len != d {
+        return throw_and_return(
+            &mut env,
+            &format!("query array length {} != d={}", query_len, d),
+        );
+    }
 
-    let result_class = match 
env.find_class("org/apache/paimon/index/ivfpq/IVFPQResult") {
-        Ok(c) => c,
-        Err(e) => return throw_and_return(&mut env, &format!("find_class: {}", 
e)),
+    let mut query_buf = vec![0.0f32; d];
+    if let Err(e) = env.get_float_array_region(&query, 0, &mut query_buf) {
+        return throw_and_return(&mut env, &format!("get_float_array_region: 
{}", e));
+    }
+
+    let filter_bytes = match read_byte_array(&mut env, roaring_filter) {
+        Ok(bytes) => bytes,
+        Err(e) => return throw_and_return(&mut env, &e),
     };
 
-    let result = match env.new_object(
-        result_class,
-        "([J[F)V",
-        &[JValue::Object(&id_array), JValue::Object(&dist_array)],
+    let (ids, dists) = match reader.search_with_roaring_filter(
+        &query_buf,
+        k as usize,
+        nprobe as usize,
+        &filter_bytes,
     ) {
         Ok(r) => r,
-        Err(e) => return throw_and_return(&mut env, &format!("new_object: {}", 
e)),
+        Err(e) => return throw_and_return(&mut env, 
&format!("search_with_filter: {}", e)),
     };
 
-    result.into_raw()
+    build_result(&mut env, ids, dists)
 }
 
 // --- Reader metadata ---
@@ -399,38 +512,70 @@ pub extern "system" fn 
Java_org_apache_paimon_index_ivfpq_IVFPQNative_searchBatc
         Err(e) => return throw_and_return(&mut env, &format!("search_batch: 
{}", e)),
     };
 
-    let id_array = match env.new_long_array((nq * k) as i32) {
-        Ok(a) => a,
-        Err(e) => return throw_and_return(&mut env, &format!("new_long_array: 
{}", e)),
+    build_batch_result(&mut env, all_ids, all_dists, nq, k)
+}
+
+#[no_mangle]
+pub extern "system" fn 
Java_org_apache_paimon_index_ivfpq_IVFPQNative_searchBatchWithRoaringFilter(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    queries: JFloatArray,
+    nq: jint,
+    k: jint,
+    nprobe: jint,
+    roaring_filter: JByteArray,
+) -> jobject {
+    let reader = match deref_reader(ptr) {
+        Some(r) => r,
+        None => return throw_and_return(&mut env, "null native pointer (reader 
already freed?)"),
     };
-    let _ = env.set_long_array_region(&id_array, 0, &all_ids);
 
-    let dist_array = match env.new_float_array((nq * k) as i32) {
-        Ok(a) => a,
-        Err(e) => return throw_and_return(&mut env, &format!("new_float_array: 
{}", e)),
+    if nq <= 0 || k <= 0 || nprobe <= 0 {
+        return throw_and_return(
+            &mut env,
+            &format!("invalid parameters: nq={}, k={}, nprobe={}", nq, k, 
nprobe),
+        );
+    }
+
+    let d = reader.d;
+    let nq = nq as usize;
+    let k = k as usize;
+
+    let query_len = match env.get_array_length(&queries) {
+        Ok(l) => l as usize,
+        Err(e) => return throw_and_return(&mut env, 
&format!("get_array_length: {}", e)),
     };
-    let _ = env.set_float_array_region(&dist_array, 0, &all_dists);
+    if query_len != nq * d {
+        return throw_and_return(
+            &mut env,
+            &format!("queries array length {} != nq*d={}", query_len, nq * d),
+        );
+    }
 
-    let result_class = match 
env.find_class("org/apache/paimon/index/ivfpq/IVFPQBatchResult") {
-        Ok(c) => c,
-        Err(e) => return throw_and_return(&mut env, &format!("find_class: {}", 
e)),
+    let mut query_buf = vec![0.0f32; nq * d];
+    if let Err(e) = env.get_float_array_region(&queries, 0, &mut query_buf) {
+        return throw_and_return(&mut env, &format!("get_float_array_region: 
{}", e));
+    }
+
+    let filter_bytes = match read_byte_array(&mut env, roaring_filter) {
+        Ok(bytes) => bytes,
+        Err(e) => return throw_and_return(&mut env, &e),
     };
 
-    let result = match env.new_object(
-        result_class,
-        "([J[FII)V",
-        &[
-            JValue::Object(&id_array),
-            JValue::Object(&dist_array),
-            JValue::Int(nq as jint),
-            JValue::Int(k as jint),
-        ],
+    let (all_ids, all_dists) = match search_batch_reader_roaring_filter(
+        reader,
+        &query_buf,
+        nq,
+        k,
+        nprobe as usize,
+        &filter_bytes,
     ) {
-        Ok(r) => r,
-        Err(e) => return throw_and_return(&mut env, &format!("new_object: {}", 
e)),
+        Ok(result) => result,
+        Err(e) => return throw_and_return(&mut env, 
&format!("search_batch_with_filter: {}", e)),
     };
 
-    result.into_raw()
+    build_batch_result(&mut env, all_ids, all_dists, nq, k)
 }
 
 #[no_mangle]
diff --git a/python/Cargo.lock b/python/Cargo.lock
index eef0727..de3821d 100644
--- a/python/Cargo.lock
+++ b/python/Cargo.lock
@@ -23,6 +23,12 @@ version = "1.25.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec"
 
+[[package]]
+name = "byteorder"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
+
 [[package]]
 name = "cfg-if"
 version = "1.0.4"
@@ -230,6 +236,7 @@ dependencies = [
  "nalgebra",
  "rand",
  "rayon",
+ "roaring",
 ]
 
 [[package]]
@@ -408,6 +415,16 @@ dependencies = [
  "crossbeam-utils",
 ]
 
+[[package]]
+name = "roaring"
+version = "0.11.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1dedc5658c6ecb3bdb5ef5f3295bb9253f42dcf3fd1402c03f6b1f7659c3c4a9"
+dependencies = [
+ "bytemuck",
+ "byteorder",
+]
+
 [[package]]
 name = "rustc-hash"
 version = "1.1.0"
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 52204d9..fa71d64 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -19,9 +19,9 @@
 
 use numpy::{PyArray1, PyReadonlyArray1};
 use paimon_vindex_core::io::{IVFPQIndexReader, SeekRead};
-use pyo3::exceptions::PyIOError;
+use pyo3::exceptions::{PyIOError, PyValueError};
 use pyo3::prelude::*;
-use pyo3::types::PyBytes;
+use pyo3::types::{PyAny, PyBytes};
 use std::io;
 
 /// Python file object wrapper implementing SeekRead.
@@ -99,35 +99,43 @@ impl IVFPQReader {
     }
 
     #[allow(clippy::type_complexity)]
+    #[pyo3(signature = (query, top_k, nprobe, filter_bytes=None))]
     fn search<'py>(
         &mut self,
         py: Python<'py>,
         query: PyReadonlyArray1<f32>,
         top_k: usize,
         nprobe: usize,
+        filter_bytes: Option<&Bound<'_, PyAny>>,
     ) -> PyResult<(Bound<'py, PyArray1<i64>>, Bound<'py, PyArray1<f32>>)> {
         let query_slice = query.as_slice()?;
 
         if query_slice.len() != self.inner.d {
-            return Err(pyo3::exceptions::PyValueError::new_err(format!(
+            return Err(PyValueError::new_err(format!(
                 "query length {} != index dimension {}",
                 query_slice.len(),
                 self.inner.d
             )));
         }
         if top_k == 0 {
-            return Err(pyo3::exceptions::PyValueError::new_err("top_k must be 
> 0"));
+            return Err(PyValueError::new_err("top_k must be > 0"));
         }
         if nprobe == 0 {
-            return Err(pyo3::exceptions::PyValueError::new_err(
-                "nprobe must be > 0",
-            ));
+            return Err(PyValueError::new_err("nprobe must be > 0"));
         }
 
-        let (ids, dists) = self
-            .inner
-            .search(query_slice, top_k, nprobe)
-            .map_err(|e| PyIOError::new_err(format!("Search failed: {}", e)))?;
+        let (ids, dists) = if let Some(filter_obj) = filter_bytes {
+            let bytes: &Bound<PyBytes> = filter_obj
+                .downcast()
+                .map_err(|_| PyValueError::new_err("filter_bytes must be 
bytes"))?;
+            self.inner
+                .search_with_roaring_filter(query_slice, top_k, nprobe, 
bytes.as_bytes())
+                .map_err(|e| PyIOError::new_err(format!("Search failed: {}", 
e)))?
+        } else {
+            self.inner
+                .search(query_slice, top_k, nprobe)
+                .map_err(|e| PyIOError::new_err(format!("Search failed: {}", 
e)))?
+        };
 
         let id_array = PyArray1::from_vec_bound(py, ids);
         let dist_array = PyArray1::from_vec_bound(py, dists);

Reply via email to