This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git


The following commit(s) were added to refs/heads/main by this push:
     new 81963db  Add Java and Python vector index APIs (#16)
81963db is described below

commit 81963db3d39aaa83efc2c1eb52f04d7e5738548d
Author: Jingsong Lee <[email protected]>
AuthorDate: Mon Jun 8 22:16:06 2026 +0800

    Add Java and Python vector index APIs (#16)
---
 .github/workflows/ci.yml                           |  13 +-
 README.md                                          |  16 +-
 .../paimon/index/ivfpq/IVFPQJavaApiTest.java       | 140 ++++++
 .../paimon/index/ivfpq/IVFPQBatchResult.java       |  97 ++++
 .../org/apache/paimon/index/ivfpq/IVFPQNative.java |  52 +++
 .../org/apache/paimon/index/ivfpq/IVFPQReader.java | 114 +++++
 .../org/apache/paimon/index/ivfpq/IVFPQResult.java |  59 +++
 .../org/apache/paimon/index/ivfpq/IVFPQWriter.java | 112 +++++
 jni/java/org/apache/paimon/index/ivfpq/Metric.java |  34 ++
 python/Cargo.lock                                  |   1 +
 python/Cargo.toml                                  |  10 +-
 python/src/lib.rs                                  | 488 ++++++++++++++++++++-
 12 files changed, 1118 insertions(+), 18 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 1a47796..81e421c 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -104,6 +104,12 @@ jobs:
           java-version: '8'
           distribution: 'temurin'
 
+      - name: Test Java API
+        run: |
+          mkdir -p target/java-api-test
+          javac -source 8 -target 8 -d target/java-api-test $(find jni/java 
jni/java-test -name '*.java')
+          java -cp target/java-api-test 
org.apache.paimon.index.ivfpq.IVFPQJavaApiTest
+
       - name: Build JNI library
         run: cargo build -p paimon-vindex-jni --release
 
@@ -128,8 +134,11 @@ jobs:
         with:
           python-version: '3.12'
 
-      - name: Install maturin
-        run: pip install maturin
+      - name: Install Python build/test dependencies
+        run: pip install maturin numpy
+
+      - name: Test Python API
+        run: cargo test --manifest-path python/Cargo.toml 
--no-default-features --features auto-initialize
 
       - name: Build Python module
         working-directory: python
diff --git a/README.md b/README.md
index 2826912..8de9f6b 100644
--- a/README.md
+++ b/README.md
@@ -31,11 +31,23 @@ The vector index accepts a serialized 64-bit Roaring bitmap 
of allowed row IDs d
 Bindings expose the same wire format:
 
 - Rust core: `search_with_reader_roaring_filter` and 
`search_batch_reader_roaring_filter`
-- JNI: `searchWithRoaringFilter` and `searchBatchWithRoaringFilter` with 
`byte[]`
-- Python: `IVFPQReader.search(..., filter_bytes=...)`
+- Java/JNI: `IVFPQReader.search(..., byte[])` and 
`IVFPQReader.searchBatch(..., byte[])`
+- Python: `IVFPQReader.search(..., filter_bytes=...)` and 
`IVFPQReader.search_batch(..., filter_bytes=...)`
 
 Row IDs must be non-negative to map directly into `RoaringTreemap`'s `u64` 
domain.
 
+## Language Bindings
+
+The Java binding provides thin facades over the JNI symbols that validate
+arguments, make `close()` idempotent, and reject use after close:
+`IVFPQWriter` builds and writes an index, `IVFPQReader` opens an index and runs
+single-query or batch search, and the result containers (`IVFPQResult`,
+`IVFPQBatchResult`) return defensive copies of IDs and distances.
+
+The Python binding mirrors that flow with `IVFPQWriter` and `IVFPQReader`.
+`search` takes a single query and returns a pair of one-dimensional NumPy arrays
+(IDs and distances). `search_batch` accepts a contiguous two-dimensional
+`float32` query array and returns a pair of two-dimensional NumPy arrays (`int64`
+IDs and `float32` distances) shaped as `(query_count, top_k)`.
+
 ## Contributing
 
 Apache Paimon Vector Index is an exciting project currently under active 
development. Whether you're looking to use it in your projects or contribute to 
its growth, there are several ways you can get involved:
diff --git a/jni/java-test/org/apache/paimon/index/ivfpq/IVFPQJavaApiTest.java 
b/jni/java-test/org/apache/paimon/index/ivfpq/IVFPQJavaApiTest.java
new file mode 100644
index 0000000..3b8998e
--- /dev/null
+++ b/jni/java-test/org/apache/paimon/index/ivfpq/IVFPQJavaApiTest.java
@@ -0,0 +1,140 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.paimon.index.ivfpq;
+
+import java.util.Arrays;
+
+public class IVFPQJavaApiTest {
+
+    public static void main(String[] args) {
+        testMetricCodes();
+        testSingleResultCopiesArrays();
+        testBatchResultCopiesArraysAndSlicesRows();
+        testReaderAndWriterApiCompile();
+    }
+
+    private static void testMetricCodes() {
+        assertEquals(0, Metric.L2.code());
+        assertEquals(1, Metric.INNER_PRODUCT.code());
+        assertEquals(2, Metric.COSINE.code());
+    }
+
+    private static void testSingleResultCopiesArrays() {
+        long[] ids = new long[] {11L, 7L};
+        float[] distances = new float[] {0.1f, 0.3f};
+
+        IVFPQResult result = new IVFPQResult(ids, distances);
+        ids[0] = 99L;
+        distances[0] = 9.0f;
+
+        assertArrayEquals(new long[] {11L, 7L}, result.ids());
+        assertArrayEquals(new float[] {0.1f, 0.3f}, result.distances());
+
+        long[] resultIds = result.ids();
+        resultIds[0] = 99L;
+        assertArrayEquals(new long[] {11L, 7L}, result.ids());
+    }
+
+    private static void testBatchResultCopiesArraysAndSlicesRows() {
+        long[] ids = new long[] {1L, 2L, 3L, 4L, 5L, 6L};
+        float[] distances = new float[] {0.1f, 0.2f, 0.3f, 1.1f, 1.2f, 1.3f};
+
+        IVFPQBatchResult result = new IVFPQBatchResult(ids, distances, 2, 3);
+        ids[0] = 99L;
+        distances[0] = 9.0f;
+
+        assertEquals(2, result.queryCount());
+        assertEquals(3, result.topK());
+        assertArrayEquals(new long[] {1L, 2L, 3L, 4L, 5L, 6L}, result.ids());
+        assertArrayEquals(new float[] {0.1f, 0.2f, 0.3f, 1.1f, 1.2f, 1.3f}, 
result.distances());
+        assertArrayEquals(new long[] {4L, 5L, 6L}, result.idsForQuery(1));
+        assertArrayEquals(new float[] {1.1f, 1.2f, 1.3f}, 
result.distancesForQuery(1));
+
+        assertThrows(IllegalArgumentException.class, new ThrowingRunnable() {
+            @Override
+            public void run() {
+                new IVFPQBatchResult(new long[] {1L}, new float[] {1.0f}, 2, 
3);
+            }
+        });
+        assertThrows(IndexOutOfBoundsException.class, new ThrowingRunnable() {
+            @Override
+            public void run() {
+                result.idsForQuery(2);
+            }
+        });
+    }
+
+    private static void testReaderAndWriterApiCompile() {
+        IVFPQReader closedReader = IVFPQReader.fromNativePointerForTesting(0L);
+        closedReader.close();
+        closedReader.close();
+
+        IVFPQWriter closedWriter = IVFPQWriter.fromNativePointerForTesting(0L, 
2);
+        closedWriter.close();
+        closedWriter.close();
+
+        if (System.currentTimeMillis() < 0) {
+            IVFPQReader reader = new IVFPQReader(new Object());
+            reader.dimension();
+            reader.totalVectors();
+            reader.search(new float[] {0.0f, 1.0f}, 10, 4);
+            reader.search(new float[] {0.0f, 1.0f}, 10, 4, new byte[] {1, 2});
+            reader.searchBatch(new float[] {0.0f, 1.0f, 2.0f, 3.0f}, 2, 10, 4);
+            reader.searchBatch(new float[] {0.0f, 1.0f, 2.0f, 3.0f}, 2, 10, 4, 
new byte[] {1, 2});
+
+            IVFPQWriter writer = new IVFPQWriter(2, 4, 1, Metric.L2, false);
+            writer.train(new float[] {0.0f, 1.0f, 2.0f, 3.0f}, 2);
+            writer.addVectors(new long[] {1L, 2L}, new float[] {0.0f, 1.0f, 
2.0f, 3.0f}, 2);
+            writer.writeIndex(new Object());
+        }
+    }
+
+    private static void assertEquals(int expected, int actual) {
+        if (expected != actual) {
+            throw new AssertionError("expected " + expected + " but got " + 
actual);
+        }
+    }
+
+    private static void assertArrayEquals(long[] expected, long[] actual) {
+        if (!Arrays.equals(expected, actual)) {
+            throw new AssertionError("expected " + Arrays.toString(expected) + 
" but got " + Arrays.toString(actual));
+        }
+    }
+
+    private static void assertArrayEquals(float[] expected, float[] actual) {
+        if (!Arrays.equals(expected, actual)) {
+            throw new AssertionError("expected " + Arrays.toString(expected) + 
" but got " + Arrays.toString(actual));
+        }
+    }
+
+    private static void assertThrows(Class<? extends Throwable> expected, 
ThrowingRunnable runnable) {
+        try {
+            runnable.run();
+        } catch (Throwable t) {
+            if (expected.isInstance(t)) {
+                return;
+            }
+            throw new AssertionError("expected " + expected.getName() + " but 
got " + t.getClass().getName(), t);
+        }
+        throw new AssertionError("expected " + expected.getName());
+    }
+
+    private interface ThrowingRunnable {
+        void run() throws Throwable;
+    }
+}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IVFPQBatchResult.java 
b/jni/java/org/apache/paimon/index/ivfpq/IVFPQBatchResult.java
new file mode 100644
index 0000000..1c9f985
--- /dev/null
+++ b/jni/java/org/apache/paimon/index/ivfpq/IVFPQBatchResult.java
@@ -0,0 +1,97 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.paimon.index.ivfpq;
+
+import java.util.Arrays;
+
+/**
+ * Results of a batch search laid out row-major: the hits for query {@code q} occupy indices
+ * {@code [q * topK, (q + 1) * topK)} of {@link #ids()} and {@link #distances()}.
+ *
+ * <p>Instances are immutable. The constructor copies its input arrays, and every accessor returns
+ * a fresh copy.
+ */
+public final class IVFPQBatchResult {
+
+    private final long[] ids;
+    private final float[] distances;
+    private final int queryCount;
+    private final int topK;
+
+    /**
+     * Creates a batch result from flattened, row-major arrays.
+     *
+     * @throws NullPointerException if {@code ids} or {@code distances} is null
+     * @throws IllegalArgumentException if a count is negative, {@code queryCount * topK} does not
+     *     fit in an {@code int}, or either array's length differs from {@code queryCount * topK}
+     */
+    public IVFPQBatchResult(long[] ids, float[] distances, int queryCount, int topK) {
+        if (ids == null) {
+            throw new NullPointerException("ids");
+        }
+        if (distances == null) {
+            throw new NullPointerException("distances");
+        }
+        if (queryCount < 0) {
+            throw new IllegalArgumentException("queryCount must be >= 0");
+        }
+        if (topK < 0) {
+            throw new IllegalArgumentException("topK must be >= 0");
+        }
+        int expectedLength = checkedResultLength(queryCount, topK);
+        if (ids.length != expectedLength) {
+            throw new IllegalArgumentException(
+                    "ids length " + ids.length + " != queryCount * topK " + expectedLength);
+        }
+        if (distances.length != expectedLength) {
+            throw new IllegalArgumentException(
+                    "distances length " + distances.length + " != queryCount * topK " + expectedLength);
+        }
+        this.ids = ids.clone();
+        this.distances = distances.clone();
+        this.queryCount = queryCount;
+        this.topK = topK;
+    }
+
+    /** Returns the number of queries (rows) in this batch. */
+    public int queryCount() {
+        return queryCount;
+    }
+
+    /** Returns the number of hits stored per query. */
+    public int topK() {
+        return topK;
+    }
+
+    /** Returns a copy of all IDs in row-major order. */
+    public long[] ids() {
+        return ids.clone();
+    }
+
+    /** Returns a copy of all distances in row-major order. */
+    public float[] distances() {
+        return distances.clone();
+    }
+
+    /**
+     * Returns a copy of the IDs for query {@code queryIndex}.
+     *
+     * @throws IndexOutOfBoundsException if {@code queryIndex} is not in {@code [0, queryCount)}
+     */
+    public long[] idsForQuery(int queryIndex) {
+        int from = rowStart(queryIndex);
+        return Arrays.copyOfRange(ids, from, from + topK);
+    }
+
+    /**
+     * Returns a copy of the distances for query {@code queryIndex}.
+     *
+     * @throws IndexOutOfBoundsException if {@code queryIndex} is not in {@code [0, queryCount)}
+     */
+    public float[] distancesForQuery(int queryIndex) {
+        int from = rowStart(queryIndex);
+        return Arrays.copyOfRange(distances, from, from + topK);
+    }
+
+    @Override
+    public String toString() {
+        return "IVFPQBatchResult{queryCount=" + queryCount
+                + ", topK=" + topK
+                + ", ids=" + Arrays.toString(ids)
+                + ", distances=" + Arrays.toString(distances) + '}';
+    }
+
+    /** Validates {@code queryIndex} and returns the offset of its first element. */
+    private int rowStart(int queryIndex) {
+        if (queryIndex < 0 || queryIndex >= queryCount) {
+            throw new IndexOutOfBoundsException(
+                    "queryIndex " + queryIndex + " out of range [0, " + queryCount + ')');
+        }
+        // Cannot overflow: queryIndex < queryCount and queryCount * topK was checked to fit an int.
+        return queryIndex * topK;
+    }
+
+    /** Returns {@code queryCount * topK}, rejecting products that overflow an {@code int}. */
+    private static int checkedResultLength(int queryCount, int topK) {
+        long length = (long) queryCount * (long) topK;
+        if (length > Integer.MAX_VALUE) {
+            throw new IllegalArgumentException("queryCount * topK overflows int");
+        }
+        return (int) length;
+    }
+}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IVFPQNative.java 
b/jni/java/org/apache/paimon/index/ivfpq/IVFPQNative.java
new file mode 100644
index 0000000..83acca7
--- /dev/null
+++ b/jni/java/org/apache/paimon/index/ivfpq/IVFPQNative.java
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.paimon.index.ivfpq;
+
+/**
+ * Raw JNI entry points used by {@link IVFPQWriter} and {@link IVFPQReader}.
+ *
+ * <p>The facades validate arguments and only pass handles that are non-zero. Parameter names here
+ * are free to change, but method names and parameter types determine the JNI symbols and must
+ * match the native library.
+ */
+final class IVFPQNative {
+
+    private IVFPQNative() {}
+
+    // ---- Writer -----------------------------------------------------------------------------
+
+    /**
+     * Creates a native writer and returns its handle.
+     *
+     * @param metricCode a {@link Metric#code()} value
+     */
+    static native long createWriter(int dimension, int nlist, int m, int metricCode, boolean useOpq);
+
+    static native void train(long writerPtr, float[] data, int vectorCount);
+
+    static native void addVectors(long writerPtr, long[] ids, float[] data, int vectorCount);
+
+    static native void writeIndex(long writerPtr, Object streamOutput);
+
+    static native void freeWriter(long writerPtr);
+
+    // ---- Reader -----------------------------------------------------------------------------
+
+    /** Opens a native reader over {@code streamInput} and returns its handle. */
+    static native long openReader(Object streamInput);
+
+    static native int getDimension(long readerPtr);
+
+    static native long getTotalVectors(long readerPtr);
+
+    static native IVFPQResult search(long readerPtr, float[] query, int k, int nprobe);
+
+    /** {@code roaringFilter} is a serialized 64-bit Roaring bitmap of allowed row IDs. */
+    static native IVFPQResult searchWithRoaringFilter(
+            long readerPtr, float[] query, int k, int nprobe, byte[] roaringFilter);
+
+    static native IVFPQBatchResult searchBatch(
+            long readerPtr, float[] queries, int queryCount, int k, int nprobe);
+
+    /** {@code roaringFilter} is a serialized 64-bit Roaring bitmap of allowed row IDs. */
+    static native IVFPQBatchResult searchBatchWithRoaringFilter(
+            long readerPtr,
+            float[] queries,
+            int queryCount,
+            int k,
+            int nprobe,
+            byte[] roaringFilter);
+
+    static native void freeReader(long readerPtr);
+}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IVFPQReader.java 
b/jni/java/org/apache/paimon/index/ivfpq/IVFPQReader.java
new file mode 100644
index 0000000..9d53bd8
--- /dev/null
+++ b/jni/java/org/apache/paimon/index/ivfpq/IVFPQReader.java
@@ -0,0 +1,114 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.paimon.index.ivfpq;
+
+public final class IVFPQReader implements AutoCloseable {
+
+    private long nativePtr;
+
+    public IVFPQReader(Object input) {
+        if (input == null) {
+            throw new NullPointerException("input");
+        }
+        this.nativePtr = IVFPQNative.openReader(input);
+    }
+
+    private IVFPQReader(long nativePtr) {
+        this.nativePtr = nativePtr;
+    }
+
+    static IVFPQReader fromNativePointerForTesting(long nativePtr) {
+        return new IVFPQReader(nativePtr);
+    }
+
+    public int dimension() {
+        return IVFPQNative.getDimension(requireOpen());
+    }
+
+    public long totalVectors() {
+        return IVFPQNative.getTotalVectors(requireOpen());
+    }
+
+    public IVFPQResult search(float[] query, int topK, int nprobe) {
+        if (query == null) {
+            throw new NullPointerException("query");
+        }
+        validatePositive(topK, "topK");
+        validatePositive(nprobe, "nprobe");
+        return IVFPQNative.search(requireOpen(), query, topK, nprobe);
+    }
+
+    public IVFPQResult search(float[] query, int topK, int nprobe, byte[] 
roaringFilter) {
+        if (query == null) {
+            throw new NullPointerException("query");
+        }
+        if (roaringFilter == null) {
+            throw new NullPointerException("roaringFilter");
+        }
+        validatePositive(topK, "topK");
+        validatePositive(nprobe, "nprobe");
+        return IVFPQNative.searchWithRoaringFilter(requireOpen(), query, topK, 
nprobe, roaringFilter);
+    }
+
+    public IVFPQBatchResult searchBatch(float[] queries, int queryCount, int 
topK, int nprobe) {
+        if (queries == null) {
+            throw new NullPointerException("queries");
+        }
+        validatePositive(queryCount, "queryCount");
+        validatePositive(topK, "topK");
+        validatePositive(nprobe, "nprobe");
+        return IVFPQNative.searchBatch(requireOpen(), queries, queryCount, 
topK, nprobe);
+    }
+
+    public IVFPQBatchResult searchBatch(
+            float[] queries, int queryCount, int topK, int nprobe, byte[] 
roaringFilter) {
+        if (queries == null) {
+            throw new NullPointerException("queries");
+        }
+        if (roaringFilter == null) {
+            throw new NullPointerException("roaringFilter");
+        }
+        validatePositive(queryCount, "queryCount");
+        validatePositive(topK, "topK");
+        validatePositive(nprobe, "nprobe");
+        return IVFPQNative.searchBatchWithRoaringFilter(
+                requireOpen(), queries, queryCount, topK, nprobe, 
roaringFilter);
+    }
+
+    @Override
+    public void close() {
+        long ptr = nativePtr;
+        nativePtr = 0L;
+        if (ptr != 0L) {
+            IVFPQNative.freeReader(ptr);
+        }
+    }
+
+    private long requireOpen() {
+        if (nativePtr == 0L) {
+            throw new IllegalStateException("IVFPQReader is closed");
+        }
+        return nativePtr;
+    }
+
+    private static void validatePositive(int value, String name) {
+        if (value <= 0) {
+            throw new IllegalArgumentException(name + " must be > 0");
+        }
+    }
+}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IVFPQResult.java 
b/jni/java/org/apache/paimon/index/ivfpq/IVFPQResult.java
new file mode 100644
index 0000000..afee4da
--- /dev/null
+++ b/jni/java/org/apache/paimon/index/ivfpq/IVFPQResult.java
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.paimon.index.ivfpq;
+
+import java.util.Arrays;
+
+public final class IVFPQResult {
+
+    private final long[] ids;
+    private final float[] distances;
+
+    public IVFPQResult(long[] ids, float[] distances) {
+        if (ids == null) {
+            throw new NullPointerException("ids");
+        }
+        if (distances == null) {
+            throw new NullPointerException("distances");
+        }
+        if (ids.length != distances.length) {
+            throw new IllegalArgumentException(
+                    "ids length " + ids.length + " != distances length " + 
distances.length);
+        }
+        this.ids = ids.clone();
+        this.distances = distances.clone();
+    }
+
+    public int size() {
+        return ids.length;
+    }
+
+    public long[] ids() {
+        return ids.clone();
+    }
+
+    public float[] distances() {
+        return distances.clone();
+    }
+
+    @Override
+    public String toString() {
+        return "IVFPQResult{ids=" + Arrays.toString(ids)
+                + ", distances=" + Arrays.toString(distances) + '}';
+    }
+}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IVFPQWriter.java 
b/jni/java/org/apache/paimon/index/ivfpq/IVFPQWriter.java
new file mode 100644
index 0000000..f18a75a
--- /dev/null
+++ b/jni/java/org/apache/paimon/index/ivfpq/IVFPQWriter.java
@@ -0,0 +1,112 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.paimon.index.ivfpq;
+
+/**
+ * Java facade over a native IVF-PQ index writer.
+ *
+ * <p>Vector data is passed as flattened arrays holding at least {@code vectorCount * dimension()}
+ * floats. {@link #close()} releases the native writer and is idempotent. Every other operation
+ * after closing throws {@link IllegalStateException}.
+ */
+public final class IVFPQWriter implements AutoCloseable {
+
+    private final int dimension;
+    private long nativePtr;
+
+    /**
+     * Creates a native writer.
+     *
+     * @param dimension vector dimension; must be positive and divisible by {@code m}
+     * @param nlist must be positive (presumably the number of IVF lists)
+     * @param m must be positive and divide {@code dimension} (presumably the PQ sub-vector count)
+     * @param metric distance metric, forwarded as {@link Metric#code()}
+     * @param useOpq forwarded unchanged to the native writer
+     */
+    public IVFPQWriter(int dimension, int nlist, int m, Metric metric, boolean useOpq) {
+        if (metric == null) {
+            throw new NullPointerException("metric");
+        }
+        validatePositive(dimension, "dimension");
+        validatePositive(nlist, "nlist");
+        validatePositive(m, "m");
+        boolean splitsEvenly = dimension % m == 0;
+        if (!splitsEvenly) {
+            throw new IllegalArgumentException("dimension must be divisible by m");
+        }
+        int metricCode = metric.code();
+        this.dimension = dimension;
+        this.nativePtr = IVFPQNative.createWriter(dimension, nlist, m, metricCode, useOpq);
+    }
+
+    private IVFPQWriter(long nativePtr, int dimension) {
+        this.nativePtr = nativePtr;
+        this.dimension = dimension;
+    }
+
+    /** Wraps an existing handle without creating anything; {@code 0L} yields a closed writer. */
+    static IVFPQWriter fromNativePointerForTesting(long nativePtr, int dimension) {
+        return new IVFPQWriter(nativePtr, dimension);
+    }
+
+    /** Returns the vector dimension this writer was created with. */
+    public int dimension() {
+        return dimension;
+    }
+
+    /** Forwards {@code vectorCount} training vectors to the native writer. */
+    public void train(float[] data, int vectorCount) {
+        validateVectors(data, vectorCount);
+        long handle = requireOpen();
+        IVFPQNative.train(handle, data, vectorCount);
+    }
+
+    /** Forwards {@code vectorCount} vectors and their IDs to the native writer. */
+    public void addVectors(long[] ids, float[] data, int vectorCount) {
+        if (ids == null) {
+            throw new NullPointerException("ids");
+        }
+        validateVectors(data, vectorCount);
+        int idCount = ids.length;
+        if (idCount < vectorCount) {
+            throw new IllegalArgumentException(
+                    "ids length " + idCount + " < vectorCount " + vectorCount);
+        }
+        long handle = requireOpen();
+        IVFPQNative.addVectors(handle, ids, data, vectorCount);
+    }
+
+    /** Writes the index to {@code output}, a stream object handed unchanged to native code. */
+    public void writeIndex(Object output) {
+        if (output == null) {
+            throw new NullPointerException("output");
+        }
+        long handle = requireOpen();
+        IVFPQNative.writeIndex(handle, output);
+    }
+
+    /** Releases the native writer. Calling it again is a no-op. */
+    @Override
+    public void close() {
+        long handle = nativePtr;
+        if (handle == 0L) {
+            return;
+        }
+        nativePtr = 0L;
+        IVFPQNative.freeWriter(handle);
+    }
+
+    /** Checks that {@code data} holds at least {@code vectorCount} vectors of this dimension. */
+    private void validateVectors(float[] data, int vectorCount) {
+        if (data == null) {
+            throw new NullPointerException("data");
+        }
+        validatePositive(vectorCount, "vectorCount");
+        long required = (long) vectorCount * dimension;
+        if (required > Integer.MAX_VALUE) {
+            throw new IllegalArgumentException("vectorCount * dimension overflows int");
+        }
+        if (data.length < required) {
+            throw new IllegalArgumentException(
+                    "data length " + data.length + " < vectorCount * dimension " + required);
+        }
+    }
+
+    private long requireOpen() {
+        long handle = nativePtr;
+        if (handle != 0L) {
+            return handle;
+        }
+        throw new IllegalStateException("IVFPQWriter is closed");
+    }
+
+    private static void validatePositive(int value, String name) {
+        if (value > 0) {
+            return;
+        }
+        throw new IllegalArgumentException(name + " must be > 0");
+    }
+}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/Metric.java 
b/jni/java/org/apache/paimon/index/ivfpq/Metric.java
new file mode 100644
index 0000000..c31dbc2
--- /dev/null
+++ b/jni/java/org/apache/paimon/index/ivfpq/Metric.java
@@ -0,0 +1,34 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.paimon.index.ivfpq;
+
+/**
+ * Distance metric of an index. {@link #code()} is the integer that {@code IVFPQWriter} passes to
+ * the native writer, so the codes must stay in sync with the native library.
+ */
+public enum Metric {
+    L2(0),
+    INNER_PRODUCT(1),
+    COSINE(2);
+
+    private final int nativeCode;
+
+    Metric(int nativeCode) {
+        this.nativeCode = nativeCode;
+    }
+
+    /** Returns the integer code understood by the native library. */
+    public int code() {
+        return nativeCode;
+    }
+}
diff --git a/python/Cargo.lock b/python/Cargo.lock
index de3821d..3075ad6 100644
--- a/python/Cargo.lock
+++ b/python/Cargo.lock
@@ -246,6 +246,7 @@ dependencies = [
  "numpy",
  "paimon-vindex-core",
  "pyo3",
+ "roaring",
 ]
 
 [[package]]
diff --git a/python/Cargo.toml b/python/Cargo.toml
index e18031d..9fee3f4 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -25,7 +25,15 @@ license = "Apache-2.0"
 name = "paimon_vindex"
 crate-type = ["cdylib"]
 
+[features]
+default = ["extension-module"]
+extension-module = ["pyo3/extension-module"]
+auto-initialize = ["pyo3/auto-initialize"]
+
 [dependencies]
 paimon-vindex-core = { path = "../core" }
-pyo3 = { version = "0.22", features = ["extension-module"] }
+pyo3 = "0.22"
 numpy = "0.22"
+
+[dev-dependencies]
+roaring = "0.11"
diff --git a/python/src/lib.rs b/python/src/lib.rs
index fa71d64..694dd0b 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -17,8 +17,14 @@
 
 #![allow(clippy::useless_conversion)]
 
-use numpy::{PyArray1, PyReadonlyArray1};
-use paimon_vindex_core::io::{IVFPQIndexReader, SeekRead};
+use numpy::{
+    PyArray, PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2, 
PyUntypedArrayMethods,
+};
+use paimon_vindex_core::distance::MetricType;
+use paimon_vindex_core::io::{write_index, IVFPQIndexReader, SeekRead};
+use paimon_vindex_core::ivfpq::{
+    search_batch_reader, search_batch_reader_roaring_filter, IVFPQIndex,
+};
 use pyo3::exceptions::{PyIOError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::{PyAny, PyBytes};
@@ -63,11 +69,120 @@ impl SeekRead for PyFileStream {
     }
 }
 
+/// Adapter that lets the core index writer emit bytes into a Python binary file-like object.
+///
+/// Bytes are forwarded to the object's `write` method under the GIL. `pos` is advanced by the
+/// number of bytes each `write` call accepts; it is not read back from the Python object.
+struct PyOutputStream {
+    file: PyObject,
+    pos: u64,
+}
+
+impl paimon_vindex_core::io::SeekWrite for PyOutputStream {
+    /// Writes all of `buf`. When the Python object reports a short write, as unbuffered
+    /// `RawIOBase` streams may, `write` is re-issued for the unwritten tail.
+    ///
+    /// # Errors
+    ///
+    /// Fails if `write` raises, or returns something that is not a non-negative integer. This
+    /// includes `None`, which non-blocking raw streams use for "nothing written". It also fails if
+    /// `write` accepts zero bytes or claims more bytes than were offered.
+    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
+        Python::with_gil(|py| {
+            let mut remaining = buf;
+            while !remaining.is_empty() {
+                let chunk = PyBytes::new_bound(py, remaining);
+                let written = self
+                    .file
+                    .call_method1(py, "write", (chunk,))
+                    .map_err(|e| io::Error::other(format!("write: {}", e)))?
+                    .extract::<usize>(py)
+                    .map_err(|e| io::Error::other(format!("write return value: {}", e)))?;
+                if written == 0 {
+                    return Err(io::Error::new(
+                        io::ErrorKind::WriteZero,
+                        format!("write accepted 0 of {} remaining bytes", remaining.len()),
+                    ));
+                }
+                // Guard the slice below against a misbehaving file object.
+                if written > remaining.len() {
+                    return Err(io::Error::other(format!(
+                        "write reported {} bytes written but only {} were offered",
+                        written,
+                        remaining.len()
+                    )));
+                }
+                self.pos += written as u64;
+                remaining = &remaining[written..];
+            }
+            Ok(())
+        })
+    }
+
+    fn pos(&self) -> u64 {
+        self.pos
+    }
+}
+
+/// Maps a case-insensitive metric name (`l2`, `inner_product`/`ip`, `cosine`) to a `MetricType`.
+///
+/// Unknown names produce a `ValueError` listing the accepted spellings.
+fn parse_metric(metric: &str) -> PyResult<MetricType> {
+    let normalized = metric.to_ascii_lowercase();
+    let parsed = match normalized.as_str() {
+        "l2" => Some(MetricType::L2),
+        "inner_product" | "ip" => Some(MetricType::InnerProduct),
+        "cosine" => Some(MetricType::Cosine),
+        _ => None,
+    };
+    parsed.ok_or_else(|| {
+        PyValueError::new_err(format!(
+            "unknown metric '{}'; expected 'l2', 'inner_product', or 'cosine'",
+            metric
+        ))
+    })
+}
+
+/// Returns a `ValueError` naming `name` when `value` is zero; otherwise succeeds.
+fn validate_positive(value: usize, name: &str) -> PyResult<()> {
+    match value {
+        0 => Err(PyValueError::new_err(format!("{} must be > 0", name))),
+        _ => Ok(()),
+    }
+}
+
+/// Borrows the raw bytes of an optional `filter_bytes` argument.
+///
+/// `None` passes through unchanged. Any object other than `bytes` produces a `ValueError`.
+fn decode_filter_bytes<'a>(
+    filter_bytes: Option<&'a Bound<'_, PyAny>>,
+) -> PyResult<Option<&'a [u8]>> {
+    let Some(filter_obj) = filter_bytes else {
+        return Ok(None);
+    };
+    match filter_obj.downcast::<PyBytes>() {
+        Ok(bytes) => Ok(Some(bytes.as_bytes())),
+        Err(_) => Err(PyValueError::new_err("filter_bytes must be bytes")),
+    }
+}
+
+/// Builds a `(rows, cols)` NumPy array from row-major `data` without an intermediate
+/// `Vec<Vec<T>>`.
+///
+/// # Errors
+///
+/// Returns `ValueError` if `data.len() != rows * cols`. The check also runs in release builds,
+/// so a short result can never come back silently with the wrong shape.
+fn pyarray2_from_flat<'py, T: numpy::Element>(
+    py: Python<'py>,
+    data: Vec<T>,
+    rows: usize,
+    cols: usize,
+) -> PyResult<Bound<'py, PyArray2<T>>> {
+    let expected = rows.checked_mul(cols).ok_or_else(|| {
+        PyValueError::new_err(format!(
+            "reshape batch result: {} * {} overflows usize",
+            rows, cols
+        ))
+    })?;
+    if data.len() != expected {
+        return Err(PyValueError::new_err(format!(
+            "reshape batch result: expected {} values for shape ({}, {}), got {}",
+            expected,
+            rows,
+            cols,
+            data.len()
+        )));
+    }
+    let flat = PyArray1::from_vec_bound(py, data);
+    // Fully qualified so this helper needs no extra `use` of `PyArrayMethods`.
+    numpy::PyArrayMethods::reshape(&flat, [rows, cols])
+        .map_err(|e| PyValueError::new_err(format!("reshape batch result: {}", e)))
+}
+
+/// Validates a 2-D `(rows, dimension)` shape and returns the row count.
+///
+/// # Errors
+///
+/// Returns `ValueError` if `shape` is not two-dimensional, if its second axis differs from
+/// `dimension`, or if it has no rows. `value_name` and `dimension_name` only label the messages.
+fn validate_matrix_shape(
+    shape: &[usize],
+    dimension: usize,
+    value_name: &str,
+    dimension_name: &str,
+) -> PyResult<usize> {
+    // Slice pattern instead of `shape[0]`/`shape[1]`, so a bad shape is an error, not a panic.
+    let &[row_count, actual_dimension] = shape else {
+        return Err(PyValueError::new_err(format!(
+            "{} must be two-dimensional, got {} dimension(s)",
+            value_name,
+            shape.len()
+        )));
+    };
+    if actual_dimension != dimension {
+        return Err(PyValueError::new_err(format!(
+            "{} dimension {} != {} {}",
+            value_name, actual_dimension, dimension_name, dimension
+        )));
+    }
+    if row_count == 0 {
+        return Err(PyValueError::new_err(format!(
+            "{} must contain at least one row",
+            value_name
+        )));
+    }
+    Ok(row_count)
+}
+
+/// Python-visible reader: wraps the core `IVFPQIndexReader` over a `PyFileStream`, the
+/// `SeekRead` adapter around a Python file object.
 #[pyclass]
 struct IVFPQReader {
     inner: IVFPQIndexReader<PyFileStream>,
 }
+
+/// Python-visible index builder holding an in-memory core `IVFPQIndex`.
+#[pyclass]
+struct IVFPQWriter {
+    // NOTE(review): `Option` presumably lets the index be taken out or consumed (e.g. on write);
+    // confirm against the `#[pymethods]` impl.
+    index: Option<IVFPQIndex>,
+    // Vector dimension; presumably used to validate incoming arrays. Confirm against the impl.
+    dimension: usize,
+}
+
 #[pymethods]
 impl IVFPQReader {
     #[new]
@@ -117,19 +232,12 @@ impl IVFPQReader {
                 self.inner.d
             )));
         }
-        if top_k == 0 {
-            return Err(PyValueError::new_err("top_k must be > 0"));
-        }
-        if nprobe == 0 {
-            return Err(PyValueError::new_err("nprobe must be > 0"));
-        }
+        validate_positive(top_k, "top_k")?;
+        validate_positive(nprobe, "nprobe")?;
 
-        let (ids, dists) = if let Some(filter_obj) = filter_bytes {
-            let bytes: &Bound<PyBytes> = filter_obj
-                .downcast()
-                .map_err(|_| PyValueError::new_err("filter_bytes must be 
bytes"))?;
+        let (ids, dists) = if let Some(bytes) = 
decode_filter_bytes(filter_bytes)? {
             self.inner
-                .search_with_roaring_filter(query_slice, top_k, nprobe, 
bytes.as_bytes())
+                .search_with_roaring_filter(query_slice, top_k, nprobe, bytes)
                 .map_err(|e| PyIOError::new_err(format!("Search failed: {}", 
e)))?
         } else {
             self.inner
@@ -143,7 +251,139 @@ impl IVFPQReader {
         Ok((id_array, dist_array))
     }
 
+    /// Searches for the `top_k` nearest neighbours of every row of `queries`.
+    ///
+    /// `queries` must be a 2-D float32 array with at least one row and exactly the
+    /// index dimension as its column count. Arrays that are not C-contiguous (Fortran
+    /// ordered, sliced or reversed views) are first copied into row-major order, so
+    /// each row is always interpreted as one query vector. `filter_bytes`, when given,
+    /// must be a `bytes` object holding a serialized roaring bitmap of allowed ids.
+    ///
+    /// Returns `(ids, distances)`, both shaped `(query_count, top_k)`.
+    /// Raises `ValueError` for bad shapes/arguments and `IOError` if the search fails.
+    #[allow(clippy::type_complexity)]
+    #[pyo3(signature = (queries, top_k, nprobe, filter_bytes=None))]
+    fn search_batch<'py>(
+        &mut self,
+        py: Python<'py>,
+        queries: PyReadonlyArray2<f32>,
+        top_k: usize,
+        nprobe: usize,
+        filter_bytes: Option<&Bound<'_, PyAny>>,
+    ) -> PyResult<(Bound<'py, PyArray2<i64>>, Bound<'py, PyArray2<f32>>)> {
+        let query_count =
+            validate_matrix_shape(queries.shape(), self.inner.d, "query", "index dimension")?;
+        validate_positive(top_k, "top_k")?;
+        validate_positive(nprobe, "nprobe")?;
+
+        // NumPy's `as_slice()` also succeeds for Fortran-ordered arrays and returns their
+        // memory column-major, which would scramble the queries. Use the ndarray view
+        // instead: borrow when it is already in standard (row-major) layout, otherwise
+        // copy its elements in logical row-major order.
+        let query_view = queries.as_array();
+        let query_data: std::borrow::Cow<'_, [f32]> = match query_view.as_slice() {
+            Some(slice) => std::borrow::Cow::Borrowed(slice),
+            None => std::borrow::Cow::Owned(query_view.iter().copied().collect()),
+        };
+        let query_slice: &[f32] = &query_data;
+
+        let (ids, dists) = if let Some(bytes) = decode_filter_bytes(filter_bytes)? {
+            search_batch_reader_roaring_filter(
+                &mut self.inner,
+                query_slice,
+                query_count,
+                top_k,
+                nprobe,
+                bytes,
+            )
+            .map_err(|e| PyIOError::new_err(format!("Batch search failed: {}", e)))?
+        } else {
+            search_batch_reader(&mut self.inner, query_slice, query_count, top_k, nprobe)
+                .map_err(|e| PyIOError::new_err(format!("Batch search failed: {}", e)))?
+        };
+
+        Ok((
+            pyarray2_from_flat(py, ids, query_count, top_k)?,
+            pyarray2_from_flat(py, dists, query_count, top_k)?,
+        ))
+    }
+
+    /// Closes the reader. This is currently a no-op that always succeeds.
+    fn close(&mut self) -> PyResult<()> {
+        Ok(())
+    }
+
+    /// Context-manager entry: hands back the reader itself.
+    fn __enter__(slf: Py<Self>) -> Py<Self> {
+        slf
+    }
+
+    /// Context-manager exit: calls `close()` and returns `False`, so an exception
+    /// raised inside the `with` block is never suppressed.
+    #[pyo3(signature = (_exc_type=None, _exc_val=None, _exc_tb=None))]
+    fn __exit__(
+        &mut self,
+        _exc_type: Option<&Bound<'_, pyo3::types::PyType>>,
+        _exc_val: Option<&Bound<'_, pyo3::types::PyAny>>,
+        _exc_tb: Option<&Bound<'_, pyo3::types::PyAny>>,
+    ) -> PyResult<bool> {
+        self.close().map(|()| false)
+    }
+}
+
+#[pymethods]
+impl IVFPQWriter {
+    /// Creates a writer for an IVF-PQ index over `dimension`-dimensional vectors.
+    ///
+    /// `nlist` and `m` are passed to `IVFPQIndex::new` together with the parsed metric and
+    /// `use_opq`. `metric` must be `"l2"`, `"inner_product"` or `"cosine"`.
+    ///
+    /// Raises `ValueError` if `dimension`, `nlist` or `m` is zero, if `dimension` is not
+    /// divisible by `m`, or if the metric name is unknown.
+    #[new]
+    #[pyo3(signature = (dimension, nlist, m, metric="l2", use_opq=false))]
+    fn new(
+        dimension: usize,
+        nlist: usize,
+        m: usize,
+        metric: &str,
+        use_opq: bool,
+    ) -> PyResult<Self> {
+        for (value, name) in [(dimension, "dimension"), (nlist, "nlist"), (m, "m")] {
+            validate_positive(value, name)?;
+        }
+        let divisible = dimension.is_multiple_of(m);
+        if !divisible {
+            let message = format!("dimension {} must be divisible by m {}", dimension, m);
+            return Err(PyValueError::new_err(message));
+        }
+        let metric_type = parse_metric(metric)?;
+        let index = IVFPQIndex::new(dimension, nlist, m, metric_type, use_opq);
+        Ok(IVFPQWriter {
+            dimension,
+            index: Some(index),
+        })
+    }
+
+    /// Number of components every vector passed to this writer must have.
+    #[getter]
+    fn dimension(&self) -> usize {
+        let IVFPQWriter { dimension, .. } = *self;
+        dimension
+    }
+
+    /// Trains the index on `data`, a 2-D float32 array with `dimension` columns and at
+    /// least one row.
+    ///
+    /// Arrays that are not C-contiguous (Fortran ordered, sliced or reversed views) are
+    /// first copied into row-major order, so each row is always one training vector.
+    /// Raises `ValueError` for a shape mismatch or when the writer is closed.
+    fn train(&mut self, data: PyReadonlyArray2<f32>) -> PyResult<()> {
+        let row_count =
+            validate_matrix_shape(data.shape(), self.dimension, "data", "writer dimension")?;
+
+        // NumPy's `as_slice()` also accepts Fortran-ordered arrays and returns them
+        // column-major. Borrow only when the ndarray view is in standard layout;
+        // otherwise copy the elements in logical row-major order.
+        let data_view = data.as_array();
+        let data_values: std::borrow::Cow<'_, [f32]> = match data_view.as_slice() {
+            Some(slice) => std::borrow::Cow::Borrowed(slice),
+            None => std::borrow::Cow::Owned(data_view.iter().copied().collect()),
+        };
+        self.index_mut()?.train(&data_values, row_count);
+        Ok(())
+    }
+
+    /// Adds the rows of `data` to the index, labelled by the matching entries of `ids`.
+    ///
+    /// `ids` is a 1-D int64 array with one id per row. `data` is a 2-D float32 array with
+    /// `dimension` columns. Inputs that are not C-contiguous (sliced, reversed or Fortran
+    /// ordered arrays) are first copied into logical row-major order, so data rows always
+    /// line up with their ids.
+    ///
+    /// Raises `ValueError` for shape/length mismatches or when the writer is closed.
+    fn add_vectors(
+        &mut self,
+        ids: PyReadonlyArray1<i64>,
+        data: PyReadonlyArray2<f32>,
+    ) -> PyResult<()> {
+        let row_count =
+            validate_matrix_shape(data.shape(), self.dimension, "data", "writer dimension")?;
+
+        // Borrow when the ndarray view is already in standard layout; otherwise copy in
+        // logical order. NumPy's own `as_slice()` would hand back Fortran-ordered data
+        // column-major and rejects strided 1-D views outright.
+        let id_view = ids.as_array();
+        let id_values: std::borrow::Cow<'_, [i64]> = match id_view.as_slice() {
+            Some(slice) => std::borrow::Cow::Borrowed(slice),
+            None => std::borrow::Cow::Owned(id_view.iter().copied().collect()),
+        };
+        if id_values.len() != row_count {
+            return Err(PyValueError::new_err(format!(
+                "ids length {} != vector count {}",
+                id_values.len(),
+                row_count
+            )));
+        }
+
+        let data_view = data.as_array();
+        let data_values: std::borrow::Cow<'_, [f32]> = match data_view.as_slice() {
+            Some(slice) => std::borrow::Cow::Borrowed(slice),
+            None => std::borrow::Cow::Owned(data_view.iter().copied().collect()),
+        };
+        self.index_mut()?.add(&data_values, &id_values, row_count);
+        Ok(())
+    }
+
+    /// Serializes the index into `file`, a Python file-like object such as `io.BytesIO`.
+    ///
+    /// Raises `ValueError` if the writer is closed and `IOError` if serialization fails.
+    fn write(&mut self, file: PyObject) -> PyResult<()> {
+        let mut stream = PyOutputStream { file, pos: 0 };
+        let index = self.index_ref()?;
+        match write_index(index, &mut stream) {
+            Ok(_) => Ok(()),
+            Err(e) => Err(PyIOError::new_err(format!("Failed to write index: {}", e))),
+        }
+    }
+
     /// Drops the in-memory index. Afterwards `train`, `add_vectors` and `write` fail
     /// with `ValueError("IVFPQWriter is closed")`.
     fn close(&mut self) -> PyResult<()> {
+        self.index = None;
         Ok(())
     }
 
@@ -163,8 +403,230 @@ impl IVFPQReader {
     }
 }
 
+impl IVFPQWriter {
+    /// Shared access to the index under construction; `ValueError` once closed.
+    fn index_ref(&self) -> PyResult<&IVFPQIndex> {
+        match &self.index {
+            Some(index) => Ok(index),
+            None => Err(PyValueError::new_err("IVFPQWriter is closed")),
+        }
+    }
+
+    /// Exclusive access to the index under construction; `ValueError` once closed.
+    fn index_mut(&mut self) -> PyResult<&mut IVFPQIndex> {
+        match &mut self.index {
+            Some(index) => Ok(index),
+            None => Err(PyValueError::new_err("IVFPQWriter is closed")),
+        }
+    }
+}
+
 /// Python extension module `paimon_vindex`; registers the `IVFPQReader` and
 /// `IVFPQWriter` classes.
 #[pymodule]
 fn paimon_vindex(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
     m.add_class::<IVFPQReader>()?;
+    m.add_class::<IVFPQWriter>()?;
     Ok(())
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use numpy::{PyArray, PyArrayMethods};
+    use paimon_vindex_core::distance::MetricType;
+    use paimon_vindex_core::io::{write_index, PosWriter};
+    use paimon_vindex_core::ivfpq::IVFPQIndex;
+    use pyo3::types::PyBytes;
+    use roaring::RoaringTreemap;
+
+    /// Deterministic synthetic vectors, flattened row-major.
+    ///
+    /// Row `i` belongs to cluster `i % clusters`. Its components are offset by
+    /// `cluster * 10.0` plus small per-column and per-row perturbations.
+    fn generate_clustered_data(n: usize, d: usize, clusters: usize) -> Vec<f32> {
+        let mut data = vec![0.0; n * d];
+        for i in 0..n {
+            let cluster = i % clusters;
+            for j in 0..d {
+                data[i * d + j] = cluster as f32 * 10.0 + j as f32 * 0.01 + i as f32 * 0.0001;
+            }
+        }
+        data
+    }
+
+    /// Serialized L2 index (d=16, nlist=4, m=4) over 500 clustered vectors with ids 0..500.
+    fn build_test_index_bytes() -> Vec<u8> {
+        let d = 16;
+        let nlist = 4;
+        let m = 4;
+        let n = 500;
+        let data = generate_clustered_data(n, d, 4);
+        let ids: Vec<i64> = (0..n as i64).collect();
+
+        let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, false);
+        index.train(&data, n);
+        index.add(&data, &ids, n);
+
+        let mut buf = Vec::new();
+        let mut writer = PosWriter::new(&mut buf);
+        write_index(&index, &mut writer).unwrap();
+        buf
+    }
+
+    /// Opens an `IVFPQReader` over an `io.BytesIO` holding `build_test_index_bytes()`.
+    fn open_test_reader(py: Python<'_>) -> IVFPQReader {
+        let index_bytes = PyBytes::new_bound(py, &build_test_index_bytes());
+        let file = py
+            .import_bound("io")
+            .unwrap()
+            .getattr("BytesIO")
+            .unwrap()
+            .call1((index_bytes,))
+            .unwrap();
+        IVFPQReader::new(file.unbind()).unwrap()
+    }
+
+    /// Converts a row-major flat buffer into a 2-D NumPy array with `cols` columns.
+    fn to_pyarray2<'py>(
+        py: Python<'py>,
+        flat: &[f32],
+        cols: usize,
+    ) -> Bound<'py, PyArray2<f32>> {
+        let rows: Vec<Vec<f32>> = flat.chunks(cols).map(|chunk| chunk.to_vec()).collect();
+        PyArray::from_vec2_bound(py, &rows).unwrap()
+    }
+
+    #[test]
+    fn python_batch_search_returns_2d_numpy_arrays() {
+        Python::with_gil(|py| {
+            let mut reader = open_test_reader(py);
+            let dimension = reader.dimension();
+            let queries = generate_clustered_data(3, dimension, 4);
+            let query_array = to_pyarray2(py, &queries, dimension);
+
+            let (ids, dists) = reader
+                .search_batch(py, query_array.readonly(), 5, 2, None)
+                .unwrap();
+
+            assert_eq!(ids.shape(), &[3, 5]);
+            assert_eq!(dists.shape(), &[3, 5]);
+            assert_eq!(ids.readonly().as_slice().unwrap()[0], 0);
+        });
+    }
+
+    #[test]
+    fn python_batch_search_accepts_roaring_filter_bytes() {
+        Python::with_gil(|py| {
+            let mut reader = open_test_reader(py);
+            let dimension = reader.dimension();
+            let queries = generate_clustered_data(3, dimension, 4);
+            let query_array = to_pyarray2(py, &queries, dimension);
+
+            let mut allowed = RoaringTreemap::new();
+            for id in (0..500u64).filter(|id| id % 7 == 0) {
+                allowed.insert(id);
+            }
+            let mut filter_bytes = Vec::new();
+            allowed.serialize_into(&mut filter_bytes).unwrap();
+            let filter = PyBytes::new_bound(py, &filter_bytes);
+
+            let (ids, _) = reader
+                .search_batch(py, query_array.readonly(), 5, 2, Some(filter.as_any()))
+                .unwrap();
+
+            assert_eq!(ids.shape(), &[3, 5]);
+            // Negative ids are padding for missing results; every real hit must be allowed.
+            for &id in ids.readonly().as_slice().unwrap() {
+                if id >= 0 {
+                    assert_eq!(id % 7, 0);
+                }
+            }
+        });
+    }
+
+    #[test]
+    fn python_writer_can_build_an_index_for_reader() {
+        Python::with_gil(|py| {
+            let io = py.import_bound("io").unwrap();
+            let output = io.getattr("BytesIO").unwrap().call0().unwrap();
+            let mut writer = IVFPQWriter::new(16, 4, 4, "l2", false).unwrap();
+            let data = generate_clustered_data(500, 16, 4);
+            let ids: Vec<i64> = (0..500).collect();
+
+            let train = to_pyarray2(py, &data, 16);
+            let id_array = PyArray1::from_vec_bound(py, ids);
+
+            writer.train(train.readonly()).unwrap();
+            writer
+                .add_vectors(id_array.readonly(), train.readonly())
+                .unwrap();
+            writer.write(output.as_any().clone().unbind()).unwrap();
+
+            output.call_method1("seek", (0,)).unwrap();
+            let mut reader = IVFPQReader::new(output.unbind()).unwrap();
+            let query = PyArray1::from_vec_bound(py, data[0..16].to_vec());
+
+            let (result_ids, _) = reader.search(py, query.readonly(), 5, 2, None).unwrap();
+
+            assert_eq!(result_ids.len(), 5);
+            assert_eq!(result_ids.readonly().as_slice().unwrap()[0], 0);
+        });
+    }
+
+    #[test]
+    fn python_batch_search_validates_query_shape() {
+        Python::with_gil(|py| {
+            let mut reader = open_test_reader(py);
+            let wrong_dim = PyArray::from_vec2_bound(py, &[vec![0.0f32; 15]]).unwrap();
+
+            let err = reader
+                .search_batch(py, wrong_dim.readonly(), 5, 2, None)
+                .unwrap_err();
+
+            assert!(err
+                .to_string()
+                .contains("query dimension 15 != index dimension 16"));
+        });
+    }
+
+    #[test]
+    fn python_writer_rejects_short_writes() {
+        Python::with_gil(|py| {
+            let locals = pyo3::types::PyDict::new_bound(py);
+            py.run_bound(
+                r#"
+class ShortWriter:
+    def write(self, data):
+        return max(0, len(data) - 1)
+"#,
+                None,
+                Some(&locals),
+            )
+            .unwrap();
+            let output = locals
+                .get_item("ShortWriter")
+                .unwrap()
+                .unwrap()
+                .call0()
+                .unwrap();
+            let mut writer = IVFPQWriter::new(16, 4, 4, "l2", false).unwrap();
+            let data = generate_clustered_data(500, 16, 4);
+            let train = to_pyarray2(py, &data, 16);
+
+            writer.train(train.readonly()).unwrap();
+            let err = writer.write(output.unbind()).unwrap_err();
+
+            assert!(err.to_string().contains("write accepted"));
+        });
+    }
+}


Reply via email to