This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new ee8c7a4 Add JNI bindings, Python bindings, and PQ4 benchmark (#7)
ee8c7a4 is described below
commit ee8c7a4ad7ec994dbb9a4f160caf31e5fcfaa5b9
Author: Jingsong Lee <[email protected]>
AuthorDate: Mon Jun 8 14:40:37 2026 +0800
Add JNI bindings, Python bindings, and PQ4 benchmark (#7)
JNI module (jni/):
- Writer API: createWriter, train, addVectors, writeIndex, freeWriter — builds IVF-PQ
index and serializes to Java OutputStream via JniOutputStream
- Reader API: openReader, search, searchBatch, getDimension, getTotalVectors,
freeReader — lazy-loading reader
backed by Java SeekableInputStream via JniSeekableStream
- JniSeekableStream: implements SeekRead with lock-protected seek+read,
optional pread via VectoredReadable for concurrent positional reads
- JniOutputStream: implements SeekWrite for streaming index writes
Python module (python/):
- IVFPQReader: open from Python file object, search returns numpy arrays
(ids, distances), context manager support (__enter__/__exit__)
- PyFileStream: implements SeekRead by delegating to Python file.seek/read
Benchmark (core/benches/pq4_bench.rs):
- Comparative benchmark: 8-bit vs 4-bit PQ vs 4-bit FastScan
- Measures query latency, build time, and recall@10 vs brute-force
- Run with: cargo bench --bench pq4_bench
---
Cargo.toml | 3 +-
core/Cargo.toml | 7 +
core/benches/pq4_bench.rs | 189 ++++++++++++++++++
{core => jni}/Cargo.toml | 12 +-
jni/src/lib.rs | 460 ++++++++++++++++++++++++++++++++++++++++++++
jni/src/stream.rs | 275 ++++++++++++++++++++++++++
{core => python}/Cargo.toml | 13 +-
python/src/lib.rs | 162 ++++++++++++++++
8 files changed, 1110 insertions(+), 11 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index ff31345..cd6cc8e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,7 +16,8 @@
# under the License.
[workspace]
-members = ["core"]
+members = ["core", "jni"]
+exclude = ["python"]
resolver = "2"
[profile.release]
diff --git a/core/Cargo.toml b/core/Cargo.toml
index 22ecb07..120be93 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -26,3 +26,10 @@ nalgebra = "0.33"
rand = "0.8"
rayon = "1.10"
matrixmultiply = "0.3"
+
+[dev-dependencies]
+criterion = "0.5"
+
+[[bench]]
+name = "pq4_bench"
+harness = false
diff --git a/core/benches/pq4_bench.rs b/core/benches/pq4_bench.rs
new file mode 100644
index 0000000..3590231
--- /dev/null
+++ b/core/benches/pq4_bench.rs
@@ -0,0 +1,189 @@
+use paimon_vindex_core::distance::MetricType;
+use paimon_vindex_core::fastscan::{fastscan_4bit, pack_codes_block_layout};
+use paimon_vindex_core::ivfpq::IVFPQIndex;
+use std::collections::HashSet;
+use std::time::Instant;
+
+/// Benchmark driver comparing an 8-bit (M=16) and a 4-bit (M=32) IVF-PQ index.
+///
+/// Steps, in order:
+/// 1. generate a synthetic clustered dataset with a fixed-seed LCG;
+/// 2. build both indexes and time each build;
+/// 3. time five rounds of batched search on each index;
+/// 4. time 100 FastScan passes over the largest inverted list of the 4-bit index;
+/// 5. measure recall@k against brute-force squared-L2 for up to 20 queries;
+/// 6. print a summary table.
+fn main() {
+    println!("=== Paimon IVF-PQ Full Benchmark ===\n");
+
+    let d = 128;
+    let nlist = 256;
+    let n = 100_000;
+    let nprobe = 8;
+    let k = 10;
+    let nq = 100;
+
+    println!(
+        "Dataset: {}K vectors, d={}, nlist={}, nprobe={}, k={}",
+        n / 1000,
+        d,
+        nlist,
+        nprobe,
+        k
+    );
+    println!();
+
+    // Generate data
+    let mut rng_state: u64 = 42;
+    let mut next = || -> f32 {
+        rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
+        // Keep the top 32 bits so the quotient covers [0, 1] and the sample covers
+        // [-1, 1]. Shifting by 33 left only 31 bits, which confined every sample
+        // to [-1, 0).
+        ((rng_state >> 32) as f32) / (u32::MAX as f32) * 2.0 - 1.0
+    };
+    let num_clusters = 16;
+    let mut centers = vec![0.0f32; num_clusters * d];
+    for v in centers.iter_mut() {
+        *v = next() * 50.0;
+    }
+    // Each vector is its cluster center (assigned round-robin) plus per-component noise.
+    let mut data = vec![0.0f32; n * d];
+    for i in 0..n {
+        let c = i % num_clusters;
+        for j in 0..d {
+            data[i * d + j] = centers[c * d + j] + next();
+        }
+    }
+    let ids: Vec<i64> = (0..n as i64).collect();
+    // The first `nq` dataset vectors double as the queries.
+    let queries = &data[..nq * d];
+
+    // === 8-bit M=16 ===
+    let m8 = 16;
+    print!("Building 8-bit (M={})...", m8);
+    let start = Instant::now();
+    let mut idx8 = IVFPQIndex::new(d, nlist, m8, MetricType::L2, false);
+    idx8.train(&data, n);
+    idx8.add(&data, &ids, n);
+    idx8.build_search_structures();
+    idx8.build_precomputed_table();
+    let build8 = start.elapsed();
+    println!(" {:.2}s", build8.as_secs_f64());
+
+    // === 4-bit M=32 (same storage) ===
+    let m4 = 32;
+    print!("Building 4-bit (M={})...", m4);
+    let start = Instant::now();
+    let mut idx4 = IVFPQIndex::with_nbits(d, nlist, m4, 4, MetricType::L2, false);
+    idx4.train(&data, n);
+    idx4.add(&data, &ids, n);
+    idx4.build_search_structures();
+    idx4.build_precomputed_table();
+    let build4 = start.elapsed();
+    println!(" {:.2}s", build4.as_secs_f64());
+
+    // === Query: 8-bit ===
+    let mut d8 = vec![0.0f32; nq * k];
+    let mut l8 = vec![0i64; nq * k];
+    let start = Instant::now();
+    for _ in 0..5 {
+        idx8.search(queries, nq, k, nprobe, &mut d8, &mut l8);
+    }
+    let q8 = start.elapsed().as_secs_f64() / 5.0;
+
+    // === Query: 4-bit (standard scan) ===
+    let mut d4 = vec![0.0f32; nq * k];
+    let mut l4 = vec![0i64; nq * k];
+    let start = Instant::now();
+    for _ in 0..5 {
+        idx4.search(queries, nq, k, nprobe, &mut d4, &mut l4);
+    }
+    let q4 = start.elapsed().as_secs_f64() / 5.0;
+
+    // === Query: 4-bit FastScan (block layout) ===
+    // Simulate: pack one list's codes and scan with fastscan
+    let biggest_list = idx4
+        .codes
+        .iter()
+        .enumerate()
+        .max_by_key(|(_, c)| c.len())
+        .map(|(i, _)| i)
+        .expect("index has at least one inverted list");
+    let list_n = idx4.ids[biggest_list].len();
+    let cs4 = idx4.pq.code_size();
+    let packed = pack_codes_block_layout(&idx4.codes[biggest_list], list_n, cs4);
+
+    // Build distance table for benchmark
+    let mut sim_table = vec![0.0f32; m4 * 16];
+    let query0 = &data[0..d];
+    // compute residual
+    let centroid = &idx4.quantizer_centroids[biggest_list * d..(biggest_list + 1) * d];
+    let residual: Vec<f32> = query0.iter().zip(centroid).map(|(q, c)| q - c).collect();
+    idx4.pq
+        .compute_distance_table(&residual, MetricType::L2, &mut sim_table);
+
+    let mut fs_dists = vec![0.0f32; list_n];
+    let start = Instant::now();
+    for _ in 0..100 {
+        fastscan_4bit(&sim_table, &packed, list_n, m4, &mut fs_dists);
+    }
+    let fs_us = start.elapsed().as_micros() as f64 / 100.0;
+
+    // === Recall ===
+    let nq_r = nq.min(20);
+    let mut recall_4 = 0usize;
+    let mut recall_8 = 0usize;
+    for qi in 0..nq_r {
+        let query = &data[qi * d..(qi + 1) * d];
+        // Brute-force squared L2 distance from the query to every dataset vector.
+        let mut bf: Vec<(f32, i64)> = data
+            .chunks_exact(d)
+            .enumerate()
+            .map(|(i, v)| {
+                let dist: f32 = query.iter().zip(v).map(|(a, b)| (a - b) * (a - b)).sum();
+                (dist, i as i64)
+            })
+            .collect();
+        // total_cmp gives a total order, so a NaN distance cannot panic the sort.
+        bf.sort_by(|a, b| a.0.total_cmp(&b.0));
+        let gt: HashSet<i64> = bf[..k].iter().map(|&(_, id)| id).collect();
+        let base = qi * k;
+        // `|&id|` strips one reference so `contains` receives `&i64`.
+        recall_4 += l4[base..base + k]
+            .iter()
+            .filter(|&id| gt.contains(id))
+            .count();
+        recall_8 += l8[base..base + k]
+            .iter()
+            .filter(|&id| gt.contains(id))
+            .count();
+    }
+
+    let codes_4: usize = idx4.codes.iter().map(|c| c.len()).sum();
+    let codes_8: usize = idx8.codes.iter().map(|c| c.len()).sum();
+
+    // === Print results ===
+    println!("\n╔══════════════════════════════════════════════════════════════════╗");
+    println!("║ Paimon IVF-PQ Performance Summary ║");
+    println!("╠══════════════════════════════════════════════════════════════════╣");
+    println!(
+        "║ 8-bit (M={}) 4-bit (M={}) ║",
+        m8, m4
+    );
+    println!("╠══════════════════════════════════════════════════════════════════╣");
+    println!(
+        "║ Storage/vec: {} bytes {} bytes ║",
+        codes_8 / n,
+        codes_4 / n
+    );
+    println!(
+        "║ Build time: {:.2}s {:.2}s ║",
+        build8.as_secs_f64(),
+        build4.as_secs_f64()
+    );
+    println!(
+        "║ Query (nq={}): {:.1}ms ({:.0}μs/q) {:.1}ms ({:.0}μs/q) ║",
+        nq,
+        q8 * 1000.0,
+        q8 * 1e6 / nq as f64,
+        q4 * 1000.0,
+        q4 * 1e6 / nq as f64
+    );
+    println!(
+        "║ Recall@{}: {:.1}% {:.1}% ║",
+        k,
+        recall_8 as f64 / (nq_r * k) as f64 * 100.0,
+        recall_4 as f64 / (nq_r * k) as f64 * 100.0
+    );
+    println!(
+        "║ FastScan (1 list): - {:.0}μs ({} vecs) ║",
+        fs_us, list_n
+    );
+    println!("╚══════════════════════════════════════════════════════════════════╝");
+}
diff --git a/core/Cargo.toml b/jni/Cargo.toml
similarity index 86%
copy from core/Cargo.toml
copy to jni/Cargo.toml
index 22ecb07..e785f52 100644
--- a/core/Cargo.toml
+++ b/jni/Cargo.toml
@@ -16,13 +16,15 @@
# under the License.
[package]
-name = "paimon-vindex-core"
+name = "paimon-vindex-jni"
version = "0.1.0"
edition = "2021"
license = "Apache-2.0"
+[lib]
+name = "paimon_vindex_jni"
+crate-type = ["cdylib"]
+
[dependencies]
-nalgebra = "0.33"
-rand = "0.8"
-rayon = "1.10"
-matrixmultiply = "0.3"
+paimon-vindex-core = { path = "../core" }
+jni = "0.21"
diff --git a/jni/src/lib.rs b/jni/src/lib.rs
new file mode 100644
index 0000000..b55ad31
--- /dev/null
+++ b/jni/src/lib.rs
@@ -0,0 +1,460 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod stream;
+
+use jni::objects::{JClass, JFloatArray, JLongArray, JObject, JValue};
+use jni::sys::{jboolean, jint, jlong, jobject};
+use jni::JNIEnv;
+use paimon_vindex_core::distance::MetricType;
+use paimon_vindex_core::io::{write_index, IVFPQIndexReader};
+use paimon_vindex_core::ivfpq::IVFPQIndex;
+use stream::{JniOutputStream, JniSeekableStream};
+
+/// Raises a `java.lang.RuntimeException` with message `msg` through `env` and yields
+/// `T::default()` so the calling native function has a value to return.
+///
+/// A failure to raise the exception is ignored; the default value is produced either way.
+fn throw_and_return<T: Default>(env: &mut JNIEnv, msg: &str) -> T {
+    drop(env.throw_new("java/lang/RuntimeException", msg));
+    Default::default()
+}
+
+/// Reinterprets a Java-held handle as a mutable reference to the writer index.
+///
+/// Returns `None` for a zero handle.
+fn deref_writer(ptr: jlong) -> Option<&'static mut IVFPQIndex> {
+    let raw = ptr as *mut IVFPQIndex;
+    // SAFETY: relies on the caller passing either 0 or a handle obtained from
+    // `Box::into_raw` in `createWriter` that has not yet been given to `freeWriter`.
+    unsafe { raw.as_mut() }
+}
+
+/// Reinterprets a Java-held handle as a mutable reference to the index reader.
+///
+/// Returns `None` for a zero handle.
+fn deref_reader(ptr: jlong) -> Option<&'static mut IVFPQIndexReader<JniSeekableStream>> {
+    let raw = ptr as *mut IVFPQIndexReader<JniSeekableStream>;
+    // SAFETY: relies on the caller passing either 0 or a handle obtained from
+    // `Box::into_raw` in `openReader` that has not yet been given to `freeReader`.
+    unsafe { raw.as_mut() }
+}
+
+// --- Writer API ---
+
+/// JNI entry point `IVFPQNative.createWriter`: allocates a new `IVFPQIndex` on the heap
+/// and returns its address as an opaque handle.
+///
+/// Throws `RuntimeException` and returns 0 when `d`, `nlist` or `m` is not positive,
+/// when `d` is not a multiple of `m`, or when `metric` is not a recognised metric code.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFPQNative_createWriter(
+    mut env: JNIEnv,
+    _class: JClass,
+    d: jint,
+    nlist: jint,
+    m: jint,
+    metric: jint,
+    use_opq: jboolean,
+) -> jlong {
+    let all_positive = d > 0 && nlist > 0 && m > 0;
+    if !all_positive {
+        let msg = format!("invalid parameters: d={}, nlist={}, m={}", d, nlist, m);
+        return throw_and_return(&mut env, &msg);
+    }
+    if d % m != 0 {
+        let msg = format!("d={} must be divisible by m={}", d, m);
+        return throw_and_return(&mut env, &msg);
+    }
+
+    let Some(metric_type) = MetricType::from_code(metric as u32) else {
+        let msg = format!("Unknown metric: {}", metric);
+        return throw_and_return(&mut env, &msg);
+    };
+
+    let opq_enabled = use_opq != 0;
+    let index = IVFPQIndex::new(
+        d as usize,
+        nlist as usize,
+        m as usize,
+        metric_type,
+        opq_enabled,
+    );
+    // Ownership moves to the Java side; `freeWriter` reclaims it.
+    Box::into_raw(Box::new(index)) as jlong
+}
+
+/// JNI entry point `IVFPQNative.train`: copies the Java float array into native memory
+/// and trains the writer index on the first `n` vectors.
+///
+/// Throws `RuntimeException` when the handle is zero, `n` is not positive, the array is
+/// shorter than `n * d`, or a JNI array call fails.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFPQNative_train(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    data: JFloatArray,
+    n: jint,
+) {
+    let Some(index) = deref_writer(ptr) else {
+        return throw_and_return(&mut env, "null native pointer (writer already freed?)");
+    };
+
+    if n <= 0 {
+        let msg = format!("invalid n: {}", n);
+        return throw_and_return(&mut env, &msg);
+    }
+    let count = n as usize;
+
+    let available = match env.get_array_length(&data) {
+        Ok(length) => length as usize,
+        Err(e) => {
+            let msg = format!("get_array_length: {}", e);
+            return throw_and_return(&mut env, &msg);
+        }
+    };
+
+    let required = count * index.d;
+    if available < required {
+        let msg = format!(
+            "data array too short: {} < n*d={}*{}={}",
+            available, count, index.d, required
+        );
+        return throw_and_return(&mut env, &msg);
+    }
+
+    // The whole Java array is copied, not just the first `n * d` floats.
+    let mut samples = vec![0.0f32; available];
+    if let Err(e) = env.get_float_array_region(&data, 0, &mut samples) {
+        let msg = format!("get_float_array_region: {}", e);
+        return throw_and_return(&mut env, &msg);
+    }
+
+    index.train(&samples, count);
+}
+
+/// JNI entry point `IVFPQNative.addVectors`: copies `n` ids and the vector data out of
+/// the Java arrays and adds them to the writer index.
+///
+/// Throws `RuntimeException` when the handle is zero, `n` is not positive, `ids` holds
+/// fewer than `n` entries, `data` holds fewer than `n * d` floats, or a JNI array call
+/// fails.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFPQNative_addVectors(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    ids: JLongArray,
+    data: JFloatArray,
+    n: jint,
+) {
+    let Some(index) = deref_writer(ptr) else {
+        return throw_and_return(&mut env, "null native pointer (writer already freed?)");
+    };
+
+    if n <= 0 {
+        let msg = format!("invalid n: {}", n);
+        return throw_and_return(&mut env, &msg);
+    }
+    let count = n as usize;
+
+    // --- ids ---
+    let ids_available = match env.get_array_length(&ids) {
+        Ok(length) => length as usize,
+        Err(e) => {
+            let msg = format!("get_array_length: {}", e);
+            return throw_and_return(&mut env, &msg);
+        }
+    };
+    if ids_available < count {
+        let msg = format!("ids array too short: {} < n={}", ids_available, count);
+        return throw_and_return(&mut env, &msg);
+    }
+
+    let mut id_values = vec![0i64; count];
+    if let Err(e) = env.get_long_array_region(&ids, 0, &mut id_values) {
+        let msg = format!("get_long_array_region: {}", e);
+        return throw_and_return(&mut env, &msg);
+    }
+
+    // --- vectors ---
+    let floats_available = match env.get_array_length(&data) {
+        Ok(length) => length as usize,
+        Err(e) => {
+            let msg = format!("get_array_length: {}", e);
+            return throw_and_return(&mut env, &msg);
+        }
+    };
+    let floats_required = count * index.d;
+    if floats_available < floats_required {
+        let msg = format!(
+            "data array too short: {} < n*d={}",
+            floats_available, floats_required
+        );
+        return throw_and_return(&mut env, &msg);
+    }
+
+    // The whole Java array is copied, not just the first `n * d` floats.
+    let mut vectors = vec![0.0f32; floats_available];
+    if let Err(e) = env.get_float_array_region(&data, 0, &mut vectors) {
+        let msg = format!("get_float_array_region: {}", e);
+        return throw_and_return(&mut env, &msg);
+    }
+
+    index.add(&vectors, &id_values, count);
+}
+
+/// JNI entry point `IVFPQNative.writeIndex`: serializes the writer index into the given
+/// Java output stream object through a `JniOutputStream`.
+///
+/// Throws `RuntimeException` when the handle is zero, when the JVM handle or a global
+/// reference cannot be obtained, or when `write_index` reports an error.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFPQNative_writeIndex(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    stream_output: JObject,
+) {
+    let Some(index) = deref_writer(ptr) else {
+        return throw_and_return(&mut env, "null native pointer (writer already freed?)");
+    };
+
+    let jvm = match env.get_java_vm() {
+        Ok(vm) => vm,
+        Err(e) => {
+            let msg = format!("get_java_vm: {}", e);
+            return throw_and_return(&mut env, &msg);
+        }
+    };
+
+    let stream_handle = match env.new_global_ref(stream_output) {
+        Ok(handle) => handle,
+        Err(e) => {
+            let msg = format!("new_global_ref: {}", e);
+            return throw_and_return(&mut env, &msg);
+        }
+    };
+
+    let mut sink = JniOutputStream::new(jvm, stream_handle);
+    match write_index(index, &mut sink) {
+        Ok(_) => {}
+        Err(e) => {
+            let msg = format!("write_index: {}", e);
+            throw_and_return::<()>(&mut env, &msg);
+        }
+    }
+}
+
+/// JNI entry point `IVFPQNative.freeWriter`: releases the writer index behind `ptr`.
+///
+/// A zero handle is a no-op.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFPQNative_freeWriter(
+    _env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+) {
+    if ptr == 0 {
+        return;
+    }
+    // SAFETY: relies on `ptr` being a handle returned by `createWriter` that has not
+    // been freed before; rebuilding the Box hands ownership back for destruction.
+    let owned = unsafe { Box::from_raw(ptr as *mut IVFPQIndex) };
+    drop(owned);
+}
+
+// --- Reader API ---
+
+/// JNI entry point `IVFPQNative.openReader`: wraps the Java input stream object in a
+/// `JniSeekableStream`, opens an `IVFPQIndexReader` on it and returns the reader's
+/// address as an opaque handle.
+///
+/// Throws `RuntimeException` and returns 0 when the JVM handle or a global reference
+/// cannot be obtained, or when opening the index fails.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFPQNative_openReader(
+    mut env: JNIEnv,
+    _class: JClass,
+    stream_input: JObject,
+) -> jlong {
+    let jvm = match env.get_java_vm() {
+        Ok(vm) => vm,
+        Err(e) => {
+            let msg = format!("get_java_vm: {}", e);
+            return throw_and_return(&mut env, &msg);
+        }
+    };
+
+    let stream_handle = match env.new_global_ref(stream_input) {
+        Ok(handle) => handle,
+        Err(e) => {
+            let msg = format!("new_global_ref: {}", e);
+            return throw_and_return(&mut env, &msg);
+        }
+    };
+
+    let source = JniSeekableStream::new(jvm, stream_handle);
+    match IVFPQIndexReader::open(source) {
+        // Ownership moves to the Java side; `freeReader` reclaims it.
+        Ok(reader) => Box::into_raw(Box::new(reader)) as jlong,
+        Err(e) => {
+            let msg = format!("open: {}", e);
+            throw_and_return(&mut env, &msg)
+        }
+    }
+}
+
+/// JNI entry point `IVFPQNative.search`: runs a single-query search and returns a new
+/// `org.apache.paimon.index.ivfpq.IVFPQResult` built from the id and distance arrays.
+///
+/// Throws `RuntimeException` and returns null when the handle is zero, `k` or `nprobe`
+/// is not positive, the query length differs from the index dimension, the search
+/// fails, or any JNI call fails. That includes copying the results into the Java
+/// arrays, which was previously ignored and could hand back zero-filled arrays.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFPQNative_search(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    query: JFloatArray,
+    k: jint,
+    nprobe: jint,
+) -> jobject {
+    let reader = match deref_reader(ptr) {
+        Some(r) => r,
+        None => {
+            return throw_and_return(&mut env, "null native pointer (reader already freed?)")
+        }
+    };
+
+    if k <= 0 || nprobe <= 0 {
+        return throw_and_return(
+            &mut env,
+            &format!("invalid parameters: k={}, nprobe={}", k, nprobe),
+        );
+    }
+
+    let d = reader.d;
+    let query_len = match env.get_array_length(&query) {
+        Ok(l) => l as usize,
+        Err(e) => return throw_and_return(&mut env, &format!("get_array_length: {}", e)),
+    };
+    if query_len != d {
+        return throw_and_return(
+            &mut env,
+            &format!("query array length {} != d={}", query_len, d),
+        );
+    }
+
+    let mut query_buf = vec![0.0f32; d];
+    if let Err(e) = env.get_float_array_region(&query, 0, &mut query_buf) {
+        return throw_and_return(&mut env, &format!("get_float_array_region: {}", e));
+    }
+
+    let (ids, dists) = match reader.search(&query_buf, k as usize, nprobe as usize) {
+        Ok(r) => r,
+        Err(e) => return throw_and_return(&mut env, &format!("search: {}", e)),
+    };
+
+    let id_array = match env.new_long_array(ids.len() as i32) {
+        Ok(a) => a,
+        Err(e) => return throw_and_return(&mut env, &format!("new_long_array: {}", e)),
+    };
+    // A failed copy would otherwise return an array of zeros that looks like real ids.
+    if let Err(e) = env.set_long_array_region(&id_array, 0, &ids) {
+        return throw_and_return(&mut env, &format!("set_long_array_region: {}", e));
+    }
+
+    let dist_array = match env.new_float_array(dists.len() as i32) {
+        Ok(a) => a,
+        Err(e) => return throw_and_return(&mut env, &format!("new_float_array: {}", e)),
+    };
+    if let Err(e) = env.set_float_array_region(&dist_array, 0, &dists) {
+        return throw_and_return(&mut env, &format!("set_float_array_region: {}", e));
+    }
+
+    let result_class = match env.find_class("org/apache/paimon/index/ivfpq/IVFPQResult") {
+        Ok(c) => c,
+        Err(e) => return throw_and_return(&mut env, &format!("find_class: {}", e)),
+    };
+
+    // Constructor signature: IVFPQResult(long[] ids, float[] distances).
+    let result = match env.new_object(
+        result_class,
+        "([J[F)V",
+        &[JValue::Object(&id_array), JValue::Object(&dist_array)],
+    ) {
+        Ok(r) => r,
+        Err(e) => return throw_and_return(&mut env, &format!("new_object: {}", e)),
+    };
+
+    result.into_raw()
+}
+
+// --- Reader metadata ---
+
+/// JNI entry point `IVFPQNative.getDimension`: returns the reader's `d` field as a
+/// Java int.
+///
+/// Throws `RuntimeException` and returns 0 when the handle is zero.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFPQNative_getDimension(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+) -> jint {
+    let Some(reader) = deref_reader(ptr) else {
+        return throw_and_return(&mut env, "null native pointer (reader already freed?)");
+    };
+    reader.d as jint
+}
+
+/// JNI entry point `IVFPQNative.getTotalVectors`: returns the reader's `total_vectors`
+/// field.
+///
+/// Throws `RuntimeException` and returns 0 when the handle is zero.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFPQNative_getTotalVectors(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+) -> jlong {
+    let Some(reader) = deref_reader(ptr) else {
+        return throw_and_return(&mut env, "null native pointer (reader already freed?)");
+    };
+    reader.total_vectors
+}
+
+// --- Batch search ---
+
+/// JNI entry point `IVFPQNative.searchBatch`: searches `nq` queries one after another
+/// and returns a new `org.apache.paimon.index.ivfpq.IVFPQBatchResult` holding flat
+/// `nq * k` id and distance arrays.
+///
+/// Slots a query does not fill keep the padding values `-1` (id) and `f32::MAX`
+/// (distance).
+///
+/// Throws `RuntimeException` and returns null when the handle is zero, any of `nq`,
+/// `k`, `nprobe` is not positive, `nq * k` does not fit a Java array length, the query
+/// array length differs from `nq * d`, a search fails, or any JNI call fails. That
+/// includes copying the results into the Java arrays, which was previously ignored.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFPQNative_searchBatch(
+    mut env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+    queries: JFloatArray,
+    nq: jint,
+    k: jint,
+    nprobe: jint,
+) -> jobject {
+    let reader = match deref_reader(ptr) {
+        Some(r) => r,
+        None => {
+            return throw_and_return(&mut env, "null native pointer (reader already freed?)")
+        }
+    };
+
+    if nq <= 0 || k <= 0 || nprobe <= 0 {
+        return throw_and_return(
+            &mut env,
+            &format!("invalid parameters: nq={}, k={}, nprobe={}", nq, k, nprobe),
+        );
+    }
+
+    let d = reader.d;
+    let nq = nq as usize;
+    let k = k as usize;
+
+    // Java arrays are indexed by int; reject sizes that `as i32` would wrap.
+    let total = match nq.checked_mul(k).filter(|&t| t <= jint::MAX as usize) {
+        Some(t) => t,
+        None => {
+            return throw_and_return(
+                &mut env,
+                &format!(
+                    "result size nq*k={}*{} exceeds the maximum Java array length",
+                    nq, k
+                ),
+            )
+        }
+    };
+
+    let query_len = match env.get_array_length(&queries) {
+        Ok(l) => l as usize,
+        Err(e) => return throw_and_return(&mut env, &format!("get_array_length: {}", e)),
+    };
+    // checked_mul: an overflowing nq*d can never equal a real array length.
+    if nq.checked_mul(d) != Some(query_len) {
+        return throw_and_return(
+            &mut env,
+            &format!(
+                "queries array length {} != nq*d={}",
+                query_len,
+                nq.saturating_mul(d)
+            ),
+        );
+    }
+
+    let mut query_buf = vec![0.0f32; query_len];
+    if let Err(e) = env.get_float_array_region(&queries, 0, &mut query_buf) {
+        return throw_and_return(&mut env, &format!("get_float_array_region: {}", e));
+    }
+
+    let mut all_ids = vec![-1i64; total];
+    let mut all_dists = vec![f32::MAX; total];
+
+    for qi in 0..nq {
+        let query = &query_buf[qi * d..(qi + 1) * d];
+        match reader.search(query, k, nprobe as usize) {
+            Ok((ids, dists)) => {
+                let base = qi * k;
+                // Clamp to k so one query can never spill into the next query's slot.
+                let found = ids.len().min(dists.len()).min(k);
+                all_ids[base..base + found].copy_from_slice(&ids[..found]);
+                all_dists[base..base + found].copy_from_slice(&dists[..found]);
+            }
+            Err(e) => return throw_and_return(&mut env, &format!("search: {}", e)),
+        }
+    }
+
+    let id_array = match env.new_long_array(total as jint) {
+        Ok(a) => a,
+        Err(e) => return throw_and_return(&mut env, &format!("new_long_array: {}", e)),
+    };
+    if let Err(e) = env.set_long_array_region(&id_array, 0, &all_ids) {
+        return throw_and_return(&mut env, &format!("set_long_array_region: {}", e));
+    }
+
+    let dist_array = match env.new_float_array(total as jint) {
+        Ok(a) => a,
+        Err(e) => return throw_and_return(&mut env, &format!("new_float_array: {}", e)),
+    };
+    if let Err(e) = env.set_float_array_region(&dist_array, 0, &all_dists) {
+        return throw_and_return(&mut env, &format!("set_float_array_region: {}", e));
+    }
+
+    let result_class = match env.find_class("org/apache/paimon/index/ivfpq/IVFPQBatchResult") {
+        Ok(c) => c,
+        Err(e) => return throw_and_return(&mut env, &format!("find_class: {}", e)),
+    };
+
+    // Constructor signature: IVFPQBatchResult(long[] ids, float[] distances, int nq, int k).
+    let result = match env.new_object(
+        result_class,
+        "([J[FII)V",
+        &[
+            JValue::Object(&id_array),
+            JValue::Object(&dist_array),
+            JValue::Int(nq as jint),
+            JValue::Int(k as jint),
+        ],
+    ) {
+        Ok(r) => r,
+        Err(e) => return throw_and_return(&mut env, &format!("new_object: {}", e)),
+    };
+
+    result.into_raw()
+}
+
+/// JNI entry point `IVFPQNative.freeReader`: releases the index reader behind `ptr`.
+///
+/// A zero handle is a no-op.
+#[no_mangle]
+pub extern "system" fn Java_org_apache_paimon_index_ivfpq_IVFPQNative_freeReader(
+    _env: JNIEnv,
+    _class: JClass,
+    ptr: jlong,
+) {
+    if ptr == 0 {
+        return;
+    }
+    // SAFETY: relies on `ptr` being a handle returned by `openReader` that has not
+    // been freed before; rebuilding the Box hands ownership back for destruction.
+    let owned = unsafe { Box::from_raw(ptr as *mut IVFPQIndexReader<JniSeekableStream>) };
+    drop(owned);
+}
diff --git a/jni/src/stream.rs b/jni/src/stream.rs
new file mode 100644
index 0000000..1e566a3
--- /dev/null
+++ b/jni/src/stream.rs
@@ -0,0 +1,275 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use jni::objects::GlobalRef;
+use jni::JavaVM;
+use paimon_vindex_core::io::SeekRead;
+use std::io;
+use std::sync::{Arc, Mutex};
+
+/// Seekable input stream whose operations are forwarded to a Java stream object
+/// through JNI.
+///
+/// When the Java object exposes a `pread(long, byte[], int, int)` method, positional
+/// reads go through it instead of `seek` followed by `read`.
+pub struct JniSeekableStream {
+    jvm: Arc<JavaVM>,
+    stream_ref: Arc<GlobalRef>,
+    /// Held around `seek` / `read` calls on the Java stream.
+    stream_lock: Arc<Mutex<()>>,
+    /// True when the Java object was found to have a `pread` method at construction.
+    has_pread: bool,
+}
+
+impl JniSeekableStream {
+    /// Wraps `stream_ref`, probing it once for `pread` support.
+    pub fn new(jvm: JavaVM, stream_ref: GlobalRef) -> Self {
+        let has_pread = check_has_pread(&jvm, &stream_ref);
+        Self {
+            jvm: Arc::new(jvm),
+            stream_ref: Arc::new(stream_ref),
+            stream_lock: Arc::new(Mutex::new(())),
+            has_pread,
+        }
+    }
+}
+
+/// Reports whether the Java object's class has a `pread(long, byte[], int, int)` method
+/// returning `int`.
+///
+/// Returns `false` when the current thread cannot be attached to the JVM, when the
+/// object's class cannot be obtained, or when the method lookup fails.
+///
+/// A failed JNI method lookup raises `NoSuchMethodError` on the calling thread. That
+/// exception is cleared here, because this is only a capability probe and later JNI
+/// calls on the same thread must not run with it still pending.
+fn check_has_pread(jvm: &JavaVM, stream_ref: &GlobalRef) -> bool {
+    let mut env = match jvm.attach_current_thread() {
+        Ok(e) => e,
+        Err(_) => return false,
+    };
+    let has_method = match env.get_object_class(stream_ref.as_obj()) {
+        Ok(class) => env.get_method_id(&class, "pread", "(J[BII)I").is_ok(),
+        Err(_) => false,
+    };
+    if !has_method {
+        // Best effort: clearing is harmless when nothing is pending.
+        let _ = env.exception_clear();
+    }
+    has_method
+}
+
+impl SeekRead for JniSeekableStream {
+ fn seek(&mut self, pos: u64) -> io::Result<()> {
+ let _guard = self
+ .stream_lock
+ .lock()
+ .map_err(|e| io::Error::other(format!("Lock poisoned: {}", e)))?;
+
+ let mut env = self
+ .jvm
+ .attach_current_thread()
+ .map_err(|e| io::Error::other(format!("JNI attach: {}", e)))?;
+
+ env.call_method(
+ self.stream_ref.as_obj(),
+ "seek",
+ "(J)V",
+ &[jni::objects::JValue::Long(pos as i64)],
+ )
+ .map_err(|e| io::Error::other(format!("JNI seek: {}", e)))?;
+
+ Ok(())
+ }
+
+ fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+ let _guard = self
+ .stream_lock
+ .lock()
+ .map_err(|e| io::Error::other(format!("Lock poisoned: {}", e)))?;
+
+ read_bytes_from_stream(&self.jvm, &self.stream_ref, buf)
+ }
+
+ /// Positional read via Java's VectoredReadable.pread(position, buffer,
offset, length).
+ /// Thread-safe: does not change the stream cursor position.
+ fn pread(&mut self, pos: u64, buf: &mut [u8]) -> io::Result<()> {
+ if !self.has_pread {
+ // Fallback: seek + read with lock
+ let _guard = self
+ .stream_lock
+ .lock()
+ .map_err(|e| io::Error::other(format!("Lock poisoned: {}",
e)))?;
+
+ let mut env = self
+ .jvm
+ .attach_current_thread()
+ .map_err(|e| io::Error::other(format!("JNI attach: {}", e)))?;
+
+ env.call_method(
+ self.stream_ref.as_obj(),
+ "seek",
+ "(J)V",
+ &[jni::objects::JValue::Long(pos as i64)],
+ )
+ .map_err(|e| io::Error::other(format!("JNI seek: {}", e)))?;
+
+ drop(env);
+ return read_bytes_from_stream(&self.jvm, &self.stream_ref, buf);
+ }
+
+ // Use pread — no lock needed, thread-safe positional read
+ let mut env = self
+ .jvm
+ .attach_current_thread()
+ .map_err(|e| io::Error::other(format!("JNI attach: {}", e)))?;
+
+ let jbuf = env
+ .new_byte_array(buf.len() as i32)
+ .map_err(|e| io::Error::other(format!("JNI alloc: {}", e)))?;
+
+ let mut total_read = 0i32;
+ let length = buf.len() as i32;
+
+ while total_read < length {
+ let remaining = length - total_read;
+ let n = env
+ .call_method(
+ self.stream_ref.as_obj(),
+ "pread",
+ "(J[BII)I",
+ &[
+ jni::objects::JValue::Long(pos as i64 + total_read as
i64),
+ jni::objects::JValue::Object(&jbuf),
+ jni::objects::JValue::Int(total_read),
+ jni::objects::JValue::Int(remaining),
+ ],
+ )
+ .map_err(|e| io::Error::other(format!("JNI pread: {}", e)))?
+ .i()
+ .map_err(|e| io::Error::other(format!("JNI pread return: {}",
e)))?;
+
+ if n <= 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ format!("pread EOF: read {} of {} bytes", total_read,
length),
+ ));
+ }
+ total_read += n;
+ }
+
+ let mut signed_buf = vec![0i8; buf.len()];
+ env.get_byte_array_region(&jbuf, 0, &mut signed_buf)
+ .map_err(|e| io::Error::other(format!("JNI get_region: {}", e)))?;
+
+ for (i, &b) in signed_buf.iter().enumerate() {
+ buf[i] = b as u8;
+ }
+
+ Ok(())
+ }
+
+ fn supports_concurrent_pread(&self) -> bool {
+ self.has_pread
+ }
+}
+
+/// Fills `buf` by repeatedly calling the Java stream's `read(byte[], int, int)` from
+/// its current position.
+///
+/// Fails with `UnexpectedEof` as soon as one call returns zero or a negative count
+/// before `buf` is full. Callers in this file invoke it while holding the stream lock.
+fn read_bytes_from_stream(jvm: &JavaVM, stream_ref: &GlobalRef, buf: &mut [u8]) -> io::Result<()> {
+    let mut env = match jvm.attach_current_thread() {
+        Ok(attached) => attached,
+        Err(e) => return Err(io::Error::other(format!("JNI attach: {}", e))),
+    };
+
+    let wanted = buf.len() as i32;
+    let java_buf = match env.new_byte_array(wanted) {
+        Ok(array) => array,
+        Err(e) => return Err(io::Error::other(format!("JNI alloc: {}", e))),
+    };
+
+    // One call may deliver fewer bytes than asked for; keep going until full.
+    let mut done = 0i32;
+    while done < wanted {
+        let args = [
+            jni::objects::JValue::Object(&java_buf),
+            jni::objects::JValue::Int(done),
+            jni::objects::JValue::Int(wanted - done),
+        ];
+        let returned = env
+            .call_method(stream_ref.as_obj(), "read", "([BII)I", &args)
+            .map_err(|e| io::Error::other(format!("JNI read: {}", e)))?;
+        let got = returned
+            .i()
+            .map_err(|e| io::Error::other(format!("JNI read return: {}", e)))?;
+
+        if got <= 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::UnexpectedEof,
+                format!("EOF: read {} of {} bytes", done, wanted),
+            ));
+        }
+        done += got;
+    }
+
+    // Java bytes are signed; copy them out and reinterpret each as u8.
+    let mut staged = vec![0i8; buf.len()];
+    if let Err(e) = env.get_byte_array_region(&java_buf, 0, &mut staged) {
+        return Err(io::Error::other(format!("JNI get_region: {}", e)));
+    }
+    for (dst, src) in buf.iter_mut().zip(staged) {
+        *dst = src as u8;
+    }
+
+    Ok(())
+}
+
+/// Output stream whose writes are forwarded to a Java stream object through JNI.
+pub struct JniOutputStream {
+    jvm: Arc<JavaVM>,
+    stream_ref: Arc<GlobalRef>,
+    /// Running count of bytes written through this wrapper.
+    pos: u64,
+}
+
+impl JniOutputStream {
+    /// Wraps `stream_ref` with the byte counter starting at zero.
+    pub fn new(jvm: JavaVM, stream_ref: GlobalRef) -> Self {
+        let jvm = Arc::new(jvm);
+        let stream_ref = Arc::new(stream_ref);
+        Self {
+            jvm,
+            stream_ref,
+            pos: 0,
+        }
+    }
+}
+
+impl paimon_vindex_core::io::SeekWrite for JniOutputStream {
+    /// Copies `buf` into a fresh Java byte array, passes it to the Java stream's
+    /// `write(byte[])`, and advances the byte counter by `buf.len()`.
+    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
+        let mut env = match self.jvm.attach_current_thread() {
+            Ok(attached) => attached,
+            Err(e) => return Err(io::Error::other(format!("JNI attach: {}", e))),
+        };
+
+        let java_buf = match env.new_byte_array(buf.len() as i32) {
+            Ok(array) => array,
+            Err(e) => return Err(io::Error::other(format!("JNI alloc: {}", e))),
+        };
+
+        // Java bytes are signed, so each u8 is reinterpreted as i8 before the copy.
+        let mut staged: Vec<i8> = Vec::with_capacity(buf.len());
+        staged.extend(buf.iter().map(|&byte| byte as i8));
+        if let Err(e) = env.set_byte_array_region(&java_buf, 0, &staged) {
+            return Err(io::Error::other(format!("JNI set_region: {}", e)));
+        }
+
+        let args = [jni::objects::JValue::Object(&java_buf)];
+        if let Err(e) = env.call_method(self.stream_ref.as_obj(), "write", "([B)V", &args) {
+            return Err(io::Error::other(format!("JNI write: {}", e)));
+        }
+
+        self.pos += buf.len() as u64;
+        Ok(())
+    }
+
+    /// Returns the number of bytes written through this wrapper so far.
+    fn pos(&self) -> u64 {
+        self.pos
+    }
+}
diff --git a/core/Cargo.toml b/python/Cargo.toml
similarity index 81%
copy from core/Cargo.toml
copy to python/Cargo.toml
index 22ecb07..e18031d 100644
--- a/core/Cargo.toml
+++ b/python/Cargo.toml
@@ -16,13 +16,16 @@
# under the License.
[package]
-name = "paimon-vindex-core"
+name = "paimon-vindex-python"
version = "0.1.0"
edition = "2021"
license = "Apache-2.0"
+[lib]
+name = "paimon_vindex"
+crate-type = ["cdylib"]
+
[dependencies]
-nalgebra = "0.33"
-rand = "0.8"
-rayon = "1.10"
-matrixmultiply = "0.3"
+paimon-vindex-core = { path = "../core" }
+pyo3 = { version = "0.22", features = ["extension-module"] }
+numpy = "0.22"
diff --git a/python/src/lib.rs b/python/src/lib.rs
new file mode 100644
index 0000000..52204d9
--- /dev/null
+++ b/python/src/lib.rs
@@ -0,0 +1,162 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#![allow(clippy::useless_conversion)]
+
+use numpy::{PyArray1, PyReadonlyArray1};
+use paimon_vindex_core::io::{IVFPQIndexReader, SeekRead};
+use pyo3::exceptions::PyIOError;
+use pyo3::prelude::*;
+use pyo3::types::PyBytes;
+use std::io;
+
+/// Python file object wrapper implementing SeekRead.
+///
+/// The wrapped object must provide `seek(offset)` and `read(size)`, with `read`
+/// returning a `bytes` object.
+struct PyFileStream {
+    file: PyObject,
+}
+
+impl SeekRead for PyFileStream {
+    /// Calls `file.seek(pos)` under the GIL.
+    fn seek(&mut self, pos: u64) -> io::Result<()> {
+        Python::with_gil(|py| {
+            self.file
+                .call_method1(py, "seek", (pos as i64,))
+                .map_err(|e| io::Error::other(format!("seek: {}", e)))?;
+            Ok(())
+        })
+    }
+
+    /// Fills `buf` by calling `file.read(remaining)` until it is full.
+    ///
+    /// Python file objects, raw and unbuffered ones in particular, may return fewer
+    /// bytes than requested without being at end of file, so a short read is retried
+    /// rather than reported as EOF. Only an empty result counts as end of file and
+    /// yields `UnexpectedEof`.
+    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+        Python::with_gil(|py| {
+            let wanted = buf.len();
+            let mut filled = 0usize;
+
+            while filled < wanted {
+                let remaining = wanted - filled;
+                let result = self
+                    .file
+                    .call_method1(py, "read", (remaining,))
+                    .map_err(|e| io::Error::other(format!("read: {}", e)))?;
+
+                let bytes: &Bound<PyBytes> = result
+                    .downcast_bound(py)
+                    .map_err(|e| io::Error::other(format!("downcast: {}", e)))?;
+
+                let data = bytes.as_bytes();
+                if data.is_empty() {
+                    return Err(io::Error::new(
+                        io::ErrorKind::UnexpectedEof,
+                        format!("read {} of {} bytes", filled, wanted),
+                    ));
+                }
+                // A well-behaved read() never returns more than asked; refuse instead
+                // of slicing out of bounds.
+                if data.len() > remaining {
+                    return Err(io::Error::other(format!(
+                        "read returned {} bytes, more than the {} requested",
+                        data.len(),
+                        remaining
+                    )));
+                }
+                buf[filled..filled + data.len()].copy_from_slice(data);
+                filled += data.len();
+            }
+            Ok(())
+        })
+    }
+}
+
+/// Python-visible reader over an IVF-PQ index stored in a Python file object.
+#[pyclass]
+struct IVFPQReader {
+    inner: IVFPQIndexReader<PyFileStream>,
+}
+
+#[pymethods]
+impl IVFPQReader {
+    /// Opens the index found in `file`; raises `IOError` when opening fails.
+    #[new]
+    fn new(file: PyObject) -> PyResult<Self> {
+        match IVFPQIndexReader::open(PyFileStream { file }) {
+            Ok(inner) => Ok(Self { inner }),
+            Err(e) => Err(PyIOError::new_err(format!("Failed to open index: {}", e))),
+        }
+    }
+
+    /// The reader's `d` field.
+    #[getter]
+    fn dimension(&self) -> usize {
+        self.inner.d
+    }
+
+    /// The reader's `nlist` field.
+    #[getter]
+    fn nlist(&self) -> usize {
+        self.inner.nlist
+    }
+
+    /// The reader's `m` field.
+    #[getter]
+    fn m(&self) -> usize {
+        self.inner.m
+    }
+
+    /// The reader's `total_vectors` field.
+    #[getter]
+    fn total_vectors(&self) -> i64 {
+        self.inner.total_vectors
+    }
+
+    /// Searches for `query` and returns `(ids, distances)` as two numpy arrays.
+    ///
+    /// Raises `ValueError` when the query length differs from the index dimension or
+    /// when `top_k` or `nprobe` is zero, and `IOError` when the search itself fails.
+    #[allow(clippy::type_complexity)]
+    fn search<'py>(
+        &mut self,
+        py: Python<'py>,
+        query: PyReadonlyArray1<f32>,
+        top_k: usize,
+        nprobe: usize,
+    ) -> PyResult<(Bound<'py, PyArray1<i64>>, Bound<'py, PyArray1<f32>>)> {
+        let values = query.as_slice()?;
+        let dim = self.inner.d;
+
+        if values.len() != dim {
+            let msg = format!(
+                "query length {} != index dimension {}",
+                values.len(),
+                dim
+            );
+            return Err(pyo3::exceptions::PyValueError::new_err(msg));
+        }
+        if top_k == 0 {
+            return Err(pyo3::exceptions::PyValueError::new_err("top_k must be > 0"));
+        }
+        if nprobe == 0 {
+            return Err(pyo3::exceptions::PyValueError::new_err("nprobe must be > 0"));
+        }
+
+        match self.inner.search(values, top_k, nprobe) {
+            Ok((ids, dists)) => Ok((
+                PyArray1::from_vec_bound(py, ids),
+                PyArray1::from_vec_bound(py, dists),
+            )),
+            Err(e) => Err(PyIOError::new_err(format!("Search failed: {}", e))),
+        }
+    }
+
+    /// Does nothing and always succeeds.
+    fn close(&mut self) -> PyResult<()> {
+        Ok(())
+    }
+
+    /// Context-manager entry: returns the reader itself.
+    fn __enter__(slf: Py<Self>) -> Py<Self> {
+        slf
+    }
+
+    /// Context-manager exit: calls `close` and returns `False`, so an exception raised
+    /// inside the `with` block is not suppressed.
+    #[pyo3(signature = (_exc_type=None, _exc_val=None, _exc_tb=None))]
+    fn __exit__(
+        &mut self,
+        _exc_type: Option<&Bound<'_, pyo3::types::PyType>>,
+        _exc_val: Option<&Bound<'_, pyo3::types::PyAny>>,
+        _exc_tb: Option<&Bound<'_, pyo3::types::PyAny>>,
+    ) -> PyResult<bool> {
+        self.close().map(|()| false)
+    }
+}
+
+/// Module initializer for `paimon_vindex`: registers the `IVFPQReader` class.
+#[pymodule]
+fn paimon_vindex(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
+    m.add_class::<IVFPQReader>()
+}