This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new 516fe27 Simplify vector index options API (#31)
516fe27 is described below
commit 516fe27e08be8ab9e0fb6c86bc66311c1062a076
Author: Jingsong Lee <[email protected]>
AuthorDate: Wed Jun 10 22:59:10 2026 +0800
Simplify vector index options API (#31)
---
.github/workflows/ci.yml | 13 +-
README.md | 53 +-
core/src/index.rs | 291 +++++++++-
java/pom.xml | 70 +++
.../org/apache/paimon/index/ivfpq/IndexType.java | 0
.../java/org/apache/paimon/index/ivfpq/Metric.java | 0
.../paimon/index/ivfpq/VectorIndexInput.java | 0
.../paimon/index/ivfpq/VectorIndexMetadata.java | 20 +-
.../paimon/index/ivfpq/VectorIndexNative.java | 13 +-
.../paimon/index/ivfpq/VectorIndexReader.java | 39 +-
.../paimon/index/ivfpq/VectorIndexWriter.java | 66 +--
.../index/ivfpq/VectorSearchBatchResult.java | 0
.../paimon/index/ivfpq/VectorSearchResult.java | 0
.../paimon/index/ivfpq/VectorIndexJavaApiTest.java | 69 +--
.../ivfpq/VectorIndexNativeHandleSafetyTest.java | 14 +-
.../ivfpq/VectorIndexNativePanicBoundaryTest.java | 19 +-
.../ivfpq/VectorIndexNativeValidationTest.java | 185 +++++++
.../org/apache/paimon/index/ivfpq/HnswConfig.java | 54 --
.../apache/paimon/index/ivfpq/IvfFlatConfig.java | 25 -
.../paimon/index/ivfpq/IvfHnswFlatConfig.java | 35 --
.../apache/paimon/index/ivfpq/IvfHnswSqConfig.java | 35 --
.../org/apache/paimon/index/ivfpq/IvfPqConfig.java | 47 --
.../paimon/index/ivfpq/VectorIndexConfig.java | 94 ----
jni/src/lib.rs | 269 ++++-----
python/src/lib.rs | 605 +++++----------------
25 files changed, 918 insertions(+), 1098 deletions(-)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 2293a62..56ffb84 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -105,19 +105,18 @@ jobs:
distribution: 'temurin'
- name: Test Java API
- run: |
- mkdir -p target/java-api-test
- javac -source 8 -target 8 -d target/java-api-test $(find jni/java
jni/java-test -name '*.java')
- java -cp target/java-api-test
org.apache.paimon.index.ivfpq.VectorIndexJavaApiTest
+ run: mvn -f java/pom.xml test
- name: Build JNI library
run: cargo build -p paimon-vindex-jni --release
- - name: Test JNI panic boundary
+ - name: Test JNI native behavior
run: |
- java -cp target/java-api-test
org.apache.paimon.index.ivfpq.VectorIndexNativePanicBoundaryTest \
+ java -cp java/target/test-classes:java/target/classes
org.apache.paimon.index.ivfpq.VectorIndexNativeValidationTest \
+ "$(pwd)/target/release/libpaimon_vindex_jni.so"
+ java -cp java/target/test-classes:java/target/classes
org.apache.paimon.index.ivfpq.VectorIndexNativePanicBoundaryTest \
"$(pwd)/target/release/libpaimon_vindex_jni.so"
- java -cp target/java-api-test
org.apache.paimon.index.ivfpq.VectorIndexNativeHandleSafetyTest \
+ java -cp java/target/test-classes:java/target/classes
org.apache.paimon.index.ivfpq.VectorIndexNativeHandleSafetyTest \
"$(pwd)/target/release/libpaimon_vindex_jni.so"
python-build:
diff --git a/README.md b/README.md
index 74efba9..56c588f 100644
--- a/README.md
+++ b/README.md
@@ -128,19 +128,25 @@ VectorIndexConfig::IvfHnswFlat {
### Java/JNI
```java
-import org.apache.paimon.index.ivfpq.HnswConfig;
-import org.apache.paimon.index.ivfpq.Metric;
-import org.apache.paimon.index.ivfpq.VectorIndexConfig;
+import java.util.HashMap;
+import java.util.Map;
+
import org.apache.paimon.index.ivfpq.VectorIndexInput;
import org.apache.paimon.index.ivfpq.VectorIndexMetadata;
import org.apache.paimon.index.ivfpq.VectorIndexReader;
import org.apache.paimon.index.ivfpq.VectorSearchResult;
import org.apache.paimon.index.ivfpq.VectorIndexWriter;
-VectorIndexConfig config =
- VectorIndexConfig.ivfHnswSq(128, 1024, Metric.L2, HnswConfig.DEFAULT);
+Map<String, String> options = new HashMap<>();
+options.put("index.type", "ivf_hnsw_sq");
+options.put("dimension", "128");
+options.put("nlist", "1024");
+options.put("metric", "l2");
+options.put("hnsw.m", "20");
+options.put("hnsw.ef-construction", "150");
+options.put("hnsw.max-level", "7");
-try (VectorIndexWriter writer = new VectorIndexWriter(config)) {
+try (VectorIndexWriter writer = new VectorIndexWriter(options)) {
writer.train(trainingVectors, trainingCount);
writer.addVectors(rowIds, vectors, vectorCount);
writer.writeIndex(vectorIndexOutput);
@@ -153,18 +159,13 @@ try (VectorIndexReader reader = new
VectorIndexReader(vectorIndexInput)) {
```
The Java package currently remains `org.apache.paimon.index.ivfpq`, but the API
-surface is unified and supports `ivfFlat`, `ivfPq`, `ivfHnswFlat`, and
-`ivfHnswSq` configs.
+surface uses string options so it maps directly to Paimon table/index
+properties. Rust parses and validates the options when the writer is created.
### Python
```python
-from paimon_vindex import (
- HnswConfig,
- IvfHnswSqConfig,
- VectorIndexReader,
- VectorIndexWriter,
-)
+from paimon_vindex import VectorIndexReader, VectorIndexWriter
class VectorIndexInput:
@@ -175,13 +176,16 @@ class VectorIndexInput:
return [self.data[pos : pos + length] for pos, length in ranges]
-config = IvfHnswSqConfig(
- 128,
- 1024,
- metric="l2",
- hnsw=HnswConfig(m=20, ef_construction=150, max_level=7),
-)
-writer = VectorIndexWriter(config)
+options = {
+ "index.type": "ivf_hnsw_sq",
+ "dimension": "128",
+ "nlist": "1024",
+ "metric": "l2",
+ "hnsw.m": "20",
+ "hnsw.ef-construction": "150",
+ "hnsw.max-level": "7",
+}
+writer = VectorIndexWriter(options)
writer.train(training_vectors)
writer.add_vectors(row_ids, vectors)
writer.write(output)
@@ -190,7 +194,6 @@ reader = VectorIndexReader(VectorIndexInput(index_bytes))
ids, distances = reader.search(query, top_k=10, nprobe=16, ef_search=80)
```
-Python also exposes `IvfFlatConfig`, `IvfPqConfig`, and `IvfHnswFlatConfig`.
`search` returns one-dimensional NumPy arrays for a single query, while
`search_batch` accepts a two-dimensional query array and returns arrays shaped
as `(query_count, top_k)`.
@@ -245,6 +248,12 @@ cargo test --workspace
cargo clippy --workspace --all-targets
```
+Java API tests are run from the JNI Java module:
+
+```bash
+mvn -f java/pom.xml test
+```
+
Python extension tests are run from the `python` package:
```bash
diff --git a/core/src/index.rs b/core/src/index.rs
index 4b9aae6..6fcaa00 100644
--- a/core/src/index.rs
+++ b/core/src/index.rs
@@ -37,6 +37,7 @@ use crate::ivfpq::{
search_batch_reader, search_batch_reader_roaring_filter,
search_with_reader,
search_with_reader_roaring_filter, IVFPQIndex,
};
+use std::collections::{HashMap, HashSet};
use std::io;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -98,6 +99,51 @@ pub enum VectorIndexConfig {
}
impl VectorIndexConfig {
+ pub fn from_options(options: &HashMap<String, String>) -> io::Result<Self>
{
+ let mut options = ConfigOptions::new(options)?;
+ let index_type =
parse_index_type_option(&options.required("index.type")?)?;
+ let dimension = parse_usize_option("dimension",
&options.required("dimension")?);
+ let nlist = parse_usize_option("nlist", &options.required("nlist")?);
+ let metric = match options.optional("metric") {
+ Some(metric) => parse_metric_option(&metric)?,
+ None => MetricType::L2,
+ };
+
+ let config = match index_type {
+ IndexType::IvfFlat => Self::IvfFlat {
+ dimension: dimension?,
+ nlist: nlist?,
+ metric,
+ },
+ IndexType::IvfPq => Self::IvfPq {
+ dimension: dimension?,
+ nlist: nlist?,
+ m: parse_usize_option("pq.m", &options.required("pq.m")?)?,
+ metric,
+ use_opq: match options.optional("use-opq") {
+ Some(use_opq) => parse_bool_option("use-opq", &use_opq)?,
+ None => false,
+ },
+ },
+ IndexType::IvfHnswFlat => Self::IvfHnswFlat {
+ dimension: dimension?,
+ nlist: nlist?,
+ metric,
+ hnsw: parse_hnsw_options(&mut options)?,
+ },
+ IndexType::IvfHnswSq => Self::IvfHnswSq {
+ dimension: dimension?,
+ nlist: nlist?,
+ metric,
+ hnsw: parse_hnsw_options(&mut options)?,
+ },
+ };
+
+ options.reject_unknown()?;
+ validate_config(&config)?;
+ Ok(config)
+ }
+
pub fn index_type(&self) -> IndexType {
match self {
Self::IvfFlat { .. } => IndexType::IvfFlat,
@@ -126,6 +172,123 @@ impl VectorIndexConfig {
}
}
+struct ConfigOptions {
+ values: HashMap<String, String>,
+ used: HashSet<String>,
+}
+
+impl ConfigOptions {
+ fn new(options: &HashMap<String, String>) -> io::Result<Self> {
+ let mut values = HashMap::new();
+ for (key, value) in options {
+ let key = key.trim().to_string();
+ if key.is_empty() {
+ return Err(invalid_input("option key must not be empty"));
+ }
+ if values.insert(key.clone(), value.clone()).is_some() {
+ return Err(invalid_input(format!("duplicate option key '{}'",
key)));
+ }
+ }
+ Ok(Self {
+ values,
+ used: HashSet::new(),
+ })
+ }
+
+ fn required(&mut self, key: &str) -> io::Result<String> {
+ self.optional(key)
+ .ok_or_else(|| invalid_input(format!("missing required option
'{}'", key)))
+ }
+
+ fn optional(&mut self, key: &str) -> Option<String> {
+ if let Some(value) = self.values.get(key) {
+ self.used.insert(key.to_string());
+ Some(value.clone())
+ } else {
+ None
+ }
+ }
+
+ fn reject_unknown(&self) -> io::Result<()> {
+ let mut unknown = self
+ .values
+ .keys()
+ .filter(|key| !self.used.contains(*key))
+ .cloned()
+ .collect::<Vec<_>>();
+ if unknown.is_empty() {
+ Ok(())
+ } else {
+ unknown.sort();
+ Err(invalid_input(format!(
+ "unknown vector index option(s): {}",
+ unknown.join(", ")
+ )))
+ }
+ }
+}
+
+fn parse_hnsw_options(options: &mut ConfigOptions) ->
io::Result<HnswBuildParams> {
+ let defaults = HnswBuildParams::default();
+ Ok(HnswBuildParams {
+ m: match options.optional("hnsw.m") {
+ Some(value) => parse_usize_option("hnsw.m", &value)?,
+ None => defaults.m,
+ },
+ ef_construction: match options.optional("hnsw.ef-construction") {
+ Some(value) => parse_usize_option("hnsw.ef-construction", &value)?,
+ None => defaults.ef_construction,
+ },
+ max_level: match options.optional("hnsw.max-level") {
+ Some(value) => parse_usize_option("hnsw.max-level", &value)?,
+ None => defaults.max_level,
+ },
+ })
+}
+
+fn parse_index_type_option(value: &str) -> io::Result<IndexType> {
+ match value.trim() {
+ "ivf_flat" => Ok(IndexType::IvfFlat),
+ "ivf_pq" => Ok(IndexType::IvfPq),
+ "ivf_hnsw_flat" => Ok(IndexType::IvfHnswFlat),
+ "ivf_hnsw_sq" => Ok(IndexType::IvfHnswSq),
+ _ => Err(invalid_input(format!(
+ "unknown index.type '{}'; expected ivf_flat, ivf_pq,
ivf_hnsw_flat, or ivf_hnsw_sq",
+ value
+ ))),
+ }
+}
+
+fn parse_metric_option(value: &str) -> io::Result<MetricType> {
+ match value.trim() {
+ "l2" => Ok(MetricType::L2),
+ "inner_product" => Ok(MetricType::InnerProduct),
+ "cosine" => Ok(MetricType::Cosine),
+ _ => Err(invalid_input(format!(
+ "unknown metric '{}'; expected l2, inner_product, or cosine",
+ value
+ ))),
+ }
+}
+
+fn parse_usize_option(name: &str, value: &str) -> io::Result<usize> {
+ value
+ .trim()
+ .parse::<usize>()
+ .map_err(|_| invalid_input(format!("option '{}' must be a positive
integer", name)))
+}
+
+fn parse_bool_option(name: &str, value: &str) -> io::Result<bool> {
+ match value.trim() {
+ "true" => Ok(true),
+ "false" => Ok(false),
+ _ => Err(invalid_input(format!(
+ "option '{}' must be true or false",
+ name
+ ))),
+ }
+}
+
#[derive(Debug, Clone, Copy)]
pub struct VectorSearchParams {
pub top_k: usize,
@@ -249,14 +412,10 @@ impl VectorIndexWriter {
pub fn add_vectors(&mut self, ids: &[i64], data: &[f32], n: usize) ->
io::Result<()> {
validate_vectors(data, n, self.dimension(), "vector data")?;
- if ids.len() < n {
+ if ids.len() != n {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
- format!(
- "ids length {} is shorter than vector count {}",
- ids.len(),
- n
- ),
+ format!("ids length {} does not match vector count {}",
ids.len(), n),
));
}
match self {
@@ -532,15 +691,16 @@ fn validate_hnsw_params(params: HnswBuildParams) ->
io::Result<()> {
fn validate_positive(value: usize, name: &str) -> io::Result<()> {
if value == 0 {
- Err(io::Error::new(
- io::ErrorKind::InvalidInput,
- format!("{} must be greater than 0", name),
- ))
+ Err(invalid_input(format!("{} must be greater than 0", name)))
} else {
Ok(())
}
}
/// Wraps `message` in an `io::Error` whose kind is `InvalidInput`.
fn invalid_input(message: impl Into<String>) -> io::Error {
    let text: String = message.into();
    io::Error::new(io::ErrorKind::InvalidInput, text)
}
+
fn validate_vectors(data: &[f32], n: usize, dimension: usize, value_name:
&str) -> io::Result<()> {
validate_positive(n, "vector count")?;
let expected_len = n.checked_mul(dimension).ok_or_else(|| {
@@ -549,11 +709,11 @@ fn validate_vectors(data: &[f32], n: usize, dimension:
usize, value_name: &str)
"vector count * dimension overflows usize",
)
})?;
- if data.len() < expected_len {
+ if data.len() != expected_len {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!(
- "{} length {} is shorter than vector count * dimension {}",
+ "{} length {} does not match vector count * dimension {}",
value_name,
data.len(),
expected_len
@@ -660,4 +820,111 @@ mod tests {
};
assert!(err.to_string().contains("must be divisible"));
}
+
/// Builds an owned option map from borrowed key/value pairs. When a key is
/// repeated, the later pair wins.
fn options(values: &[(&str, &str)]) -> HashMap<String, String> {
    let mut map = HashMap::with_capacity(values.len());
    for &(key, value) in values {
        map.insert(key.to_owned(), value.to_owned());
    }
    map
}
+
/// `from_options` maps every supported `index.type` to its config variant and
/// fills in defaults for omitted optional keys (`metric`, `use-opq`, `hnsw.*`).
#[test]
fn config_from_options_parses_all_index_types() {
    assert_eq!(
        VectorIndexConfig::from_options(&options(&[
            ("index.type", "ivf_flat"),
            ("dimension", "8"),
            ("nlist", "4"),
            ("metric", "l2"),
        ]))
        .unwrap()
        .index_type(),
        IndexType::IvfFlat
    );

    match VectorIndexConfig::from_options(&options(&[
        ("index.type", "ivf_pq"),
        ("dimension", "16"),
        ("nlist", "4"),
        ("pq.m", "4"),
        ("use-opq", "true"),
    ]))
    .unwrap()
    {
        VectorIndexConfig::IvfPq { m, use_opq, .. } => {
            assert_eq!(m, 4);
            assert!(use_opq);
        }
        _ => panic!("expected IVF PQ config"),
    }

    // `use-opq` is optional and defaults to false.
    match VectorIndexConfig::from_options(&options(&[
        ("index.type", "ivf_pq"),
        ("dimension", "16"),
        ("nlist", "4"),
        ("pq.m", "4"),
    ]))
    .unwrap()
    {
        VectorIndexConfig::IvfPq { use_opq, .. } => assert!(!use_opq),
        _ => panic!("expected IVF PQ config"),
    }

    // IVF-HNSW-Flat with neither `metric` nor any `hnsw.*` key: L2 and the
    // default HNSW build parameters are used.
    match VectorIndexConfig::from_options(&options(&[
        ("index.type", "ivf_hnsw_flat"),
        ("dimension", "8"),
        ("nlist", "4"),
    ]))
    .unwrap()
    {
        VectorIndexConfig::IvfHnswFlat { metric, hnsw, .. } => {
            assert!(matches!(metric, MetricType::L2));
            let defaults = HnswBuildParams::default();
            assert_eq!(hnsw.m, defaults.m);
            assert_eq!(hnsw.ef_construction, defaults.ef_construction);
            assert_eq!(hnsw.max_level, defaults.max_level);
        }
        _ => panic!("expected IVF HNSW FLAT config"),
    }

    match VectorIndexConfig::from_options(&options(&[
        ("index.type", "ivf_hnsw_sq"),
        ("dimension", "8"),
        ("nlist", "4"),
        ("hnsw.m", "12"),
        ("hnsw.ef-construction", "64"),
        ("hnsw.max-level", "5"),
    ]))
    .unwrap()
    {
        VectorIndexConfig::IvfHnswSq { hnsw, .. } => {
            assert_eq!(hnsw.m, 12);
            assert_eq!(hnsw.ef_construction, 64);
            assert_eq!(hnsw.max_level, 5);
        }
        _ => panic!("expected IVF HNSW SQ config"),
    }
}
+
/// A key that the selected index type never reads is reported as unknown.
#[test]
fn config_from_options_rejects_unknown_options() {
    let input = options(&[
        ("index.type", "ivf_flat"),
        ("dimension", "8"),
        ("nlist", "4"),
        ("unused", "value"),
    ]);
    let message = VectorIndexConfig::from_options(&input)
        .unwrap_err()
        .to_string();
    assert!(message.contains("unknown vector index option"));
}
+
/// Only the canonical key names and value spellings are accepted. Aliases and
/// alternative spellings are rejected with a descriptive error.
#[test]
fn config_from_options_rejects_alias_keys_and_values() {
    // Asserts that `pairs` is rejected with an error mentioning `expected`.
    fn assert_rejected(pairs: &[(&str, &str)], expected: &str) {
        let message = VectorIndexConfig::from_options(&options(pairs))
            .unwrap_err()
            .to_string();
        assert!(message.contains(expected));
    }

    // `type` is not an alias for `index.type`.
    assert_rejected(
        &[("type", "ivf_flat"), ("dimension", "8"), ("nlist", "4")],
        "missing required option 'index.type'",
    );
    // Hyphenated index type names are not accepted.
    assert_rejected(
        &[("index.type", "ivf-flat"), ("dimension", "8"), ("nlist", "4")],
        "unknown index.type",
    );
    // Index type matching is case-sensitive.
    assert_rejected(
        &[("index.type", "IVF_FLAT"), ("dimension", "8"), ("nlist", "4")],
        "unknown index.type",
    );
    // `ip` is not an alias for `inner_product`.
    assert_rejected(
        &[
            ("index.type", "ivf_flat"),
            ("dimension", "8"),
            ("nlist", "4"),
            ("metric", "ip"),
        ],
        "unknown metric",
    );
}
}
diff --git a/java/pom.xml b/java/pom.xml
new file mode 100644
index 0000000..f433ef2
--- /dev/null
+++ b/java/pom.xml
@@ -0,0 +1,70 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
https://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+
+ <groupId>org.apache.paimon</groupId>
+ <artifactId>paimon-vector-index-java</artifactId>
+ <version>0.1.0-SNAPSHOT</version>
+ <packaging>jar</packaging>
+
+ <name>Apache Paimon Vector Index Java</name>
+
+ <properties>
+ <maven.compiler.source>8</maven.compiler.source>
+ <maven.compiler.target>8</maven.compiler.target>
+ <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ </properties>
+
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <version>3.13.0</version>
+ <configuration>
+ <source>${maven.compiler.source}</source>
+ <target>${maven.compiler.target}</target>
+ <encoding>${project.build.sourceEncoding}</encoding>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>exec-maven-plugin</artifactId>
+ <version>3.5.0</version>
+ <executions>
+ <execution>
+ <id>java-api-test</id>
+ <phase>test</phase>
+ <goals>
+ <goal>java</goal>
+ </goals>
+ <configuration>
+
<mainClass>org.apache.paimon.index.ivfpq.VectorIndexJavaApiTest</mainClass>
+ <classpathScope>test</classpathScope>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IndexType.java
b/java/src/main/java/org/apache/paimon/index/ivfpq/IndexType.java
similarity index 100%
rename from jni/java/org/apache/paimon/index/ivfpq/IndexType.java
rename to java/src/main/java/org/apache/paimon/index/ivfpq/IndexType.java
diff --git a/jni/java/org/apache/paimon/index/ivfpq/Metric.java
b/java/src/main/java/org/apache/paimon/index/ivfpq/Metric.java
similarity index 100%
rename from jni/java/org/apache/paimon/index/ivfpq/Metric.java
rename to java/src/main/java/org/apache/paimon/index/ivfpq/Metric.java
diff --git a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexInput.java
b/java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexInput.java
similarity index 100%
rename from jni/java/org/apache/paimon/index/ivfpq/VectorIndexInput.java
rename to java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexInput.java
diff --git a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexMetadata.java
b/java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexMetadata.java
similarity index 84%
rename from jni/java/org/apache/paimon/index/ivfpq/VectorIndexMetadata.java
rename to
java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexMetadata.java
index 931a5e4..ec52f11 100644
--- a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexMetadata.java
+++ b/java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexMetadata.java
@@ -25,7 +25,9 @@ public final class VectorIndexMetadata {
private final Metric metric;
private final long totalVectors;
private final int pqM;
- private final HnswConfig hnsw;
+ private final int hnswM;
+ private final int hnswEfConstruction;
+ private final int hnswMaxLevel;
public VectorIndexMetadata(
int indexType,
@@ -43,7 +45,9 @@ public final class VectorIndexMetadata {
this.metric = metricFromCode(metric);
this.totalVectors = totalVectors;
this.pqM = pqM;
- this.hnsw = hnswM > 0 ? new HnswConfig(hnswM, efConstruction,
maxLevel) : null;
+ this.hnswM = hnswM;
+ this.hnswEfConstruction = efConstruction;
+ this.hnswMaxLevel = maxLevel;
}
public IndexType indexType() {
@@ -70,8 +74,16 @@ public final class VectorIndexMetadata {
return pqM;
}
- public HnswConfig hnsw() {
- return hnsw;
+ public int hnswM() {
+ return hnswM;
+ }
+
+ public int hnswEfConstruction() {
+ return hnswEfConstruction;
+ }
+
+ public int hnswMaxLevel() {
+ return hnswMaxLevel;
}
private static Metric metricFromCode(int code) {
diff --git a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexNative.java
b/java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexNative.java
similarity index 88%
rename from jni/java/org/apache/paimon/index/ivfpq/VectorIndexNative.java
rename to
java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexNative.java
index b9c771d..194e61b 100644
--- a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexNative.java
+++ b/java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexNative.java
@@ -21,16 +21,9 @@ final class VectorIndexNative {
private VectorIndexNative() {}
- static native long createWriter(
- int indexType,
- int dimension,
- int nlist,
- int pqM,
- int metric,
- boolean useOpq,
- int hnswM,
- int efConstruction,
- int maxLevel);
+ static native long createWriter(String[] optionKeys, String[]
optionValues);
+
+ static native int writerDimension(long ptr);
static native void train(long ptr, float[] data, int n);
diff --git a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexReader.java
b/java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexReader.java
similarity index 82%
rename from jni/java/org/apache/paimon/index/ivfpq/VectorIndexReader.java
rename to
java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexReader.java
index f6a5214..b10fb52 100644
--- a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexReader.java
+++ b/java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexReader.java
@@ -72,7 +72,6 @@ public final class VectorIndexReader implements AutoCloseable
{
public VectorSearchResult search(float[] query, int topK, int nprobe, int
efSearch) {
validateQuery(query);
- validateSearchParams(topK, nprobe, efSearch);
synchronized (nativeHandleLock) {
enterNativeHandle();
try {
@@ -93,7 +92,6 @@ public final class VectorIndexReader implements AutoCloseable
{
if (roaringFilter == null) {
throw new NullPointerException("roaringFilter");
}
- validateSearchParams(topK, nprobe, efSearch);
synchronized (nativeHandleLock) {
enterNativeHandle();
try {
@@ -112,8 +110,9 @@ public final class VectorIndexReader implements
AutoCloseable {
public VectorSearchBatchResult searchBatch(
float[] queries, int queryCount, int topK, int nprobe, int
efSearch) {
- validateQueries(queries, queryCount);
- validateSearchParams(topK, nprobe, efSearch);
+ if (queries == null) {
+ throw new NullPointerException("queries");
+ }
synchronized (nativeHandleLock) {
enterNativeHandle();
try {
@@ -137,11 +136,12 @@ public final class VectorIndexReader implements
AutoCloseable {
int nprobe,
int efSearch,
byte[] roaringFilter) {
- validateQueries(queries, queryCount);
+ if (queries == null) {
+ throw new NullPointerException("queries");
+ }
if (roaringFilter == null) {
throw new NullPointerException("roaringFilter");
}
- validateSearchParams(topK, nprobe, efSearch);
synchronized (nativeHandleLock) {
enterNativeHandle();
try {
@@ -173,33 +173,6 @@ public final class VectorIndexReader implements
AutoCloseable {
if (query == null) {
throw new NullPointerException("query");
}
- if (query.length != dimension()) {
- throw new IllegalArgumentException(
- "query length " + query.length + " != index dimension " +
dimension());
- }
- }
-
- private void validateQueries(float[] queries, int queryCount) {
- if (queries == null) {
- throw new NullPointerException("queries");
- }
- VectorIndexConfig.validatePositive(queryCount, "queryCount");
- long expected = (long) queryCount * (long) dimension();
- if (expected > Integer.MAX_VALUE) {
- throw new IllegalArgumentException("queryCount * dimension
overflows int");
- }
- if (queries.length != expected) {
- throw new IllegalArgumentException(
- "queries length " + queries.length + " != queryCount *
dimension " + expected);
- }
- }
-
- private static void validateSearchParams(int topK, int nprobe, int
efSearch) {
- VectorIndexConfig.validatePositive(topK, "topK");
- VectorIndexConfig.validatePositive(nprobe, "nprobe");
- if (efSearch < 0) {
- throw new IllegalArgumentException("efSearch must be >= 0");
- }
}
private long requireOpen() {
diff --git a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexWriter.java
b/java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexWriter.java
similarity index 64%
rename from jni/java/org/apache/paimon/index/ivfpq/VectorIndexWriter.java
rename to
java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexWriter.java
index 5bebe50..a82950d 100644
--- a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexWriter.java
+++ b/java/src/main/java/org/apache/paimon/index/ivfpq/VectorIndexWriter.java
@@ -17,51 +17,42 @@
package org.apache.paimon.index.ivfpq;
+import java.util.Map;
+
public final class VectorIndexWriter implements AutoCloseable {
- private final VectorIndexConfig config;
private final Object nativeHandleLock = new Object();
private long nativePtr;
private Thread nativeHandleOwner;
- public VectorIndexWriter(VectorIndexConfig config) {
- if (config == null) {
- throw new NullPointerException("config");
+ public VectorIndexWriter(Map<String, String> options) {
+ String[] keys = new String[options.size()];
+ String[] values = new String[options.size()];
+ int index = 0;
+ for (Map.Entry<String, String> entry : options.entrySet()) {
+ keys[index] = entry.getKey();
+ values[index] = entry.getValue();
+ index++;
}
- this.config = config;
- HnswConfig hnsw = config.hnsw();
- this.nativePtr =
- VectorIndexNative.createWriter(
- config.indexType().code(),
- config.dimension(),
- config.nlist(),
- config.pqM(),
- config.metric().code(),
- config.useOpq(),
- hnsw.m(),
- hnsw.efConstruction(),
- hnsw.maxLevel());
+ this.nativePtr = VectorIndexNative.createWriter(keys, values);
}
- private VectorIndexWriter(long nativePtr, VectorIndexConfig config) {
+ private VectorIndexWriter(long nativePtr) {
this.nativePtr = nativePtr;
- this.config = config;
- }
-
- static VectorIndexWriter fromNativePointerForTesting(long nativePtr,
VectorIndexConfig config) {
- return new VectorIndexWriter(nativePtr, config);
}
- public VectorIndexConfig config() {
- return config;
+ static VectorIndexWriter fromNativePointerForTesting(long nativePtr) {
+ return new VectorIndexWriter(nativePtr);
}
public int dimension() {
- return config.dimension();
+ return VectorIndexNative.writerDimension(requireOpen());
}
public void train(float[] data, int vectorCount) {
- validateVectors(data, vectorCount);
+ if (data == null) {
+ throw new NullPointerException("data");
+ }
synchronized (nativeHandleLock) {
enterNativeHandle();
try {
@@ -76,10 +67,8 @@ public final class VectorIndexWriter implements
AutoCloseable {
if (ids == null) {
throw new NullPointerException("ids");
}
- validateVectors(data, vectorCount);
- if (ids.length < vectorCount) {
- throw new IllegalArgumentException(
- "ids length " + ids.length + " < vectorCount " +
vectorCount);
+ if (data == null) {
+ throw new NullPointerException("data");
}
synchronized (nativeHandleLock) {
enterNativeHandle();
@@ -121,21 +110,6 @@ public final class VectorIndexWriter implements
AutoCloseable {
}
}
- private void validateVectors(float[] data, int vectorCount) {
- if (data == null) {
- throw new NullPointerException("data");
- }
- VectorIndexConfig.validatePositive(vectorCount, "vectorCount");
- long expected = (long) vectorCount * (long) config.dimension();
- if (expected > Integer.MAX_VALUE) {
- throw new IllegalArgumentException("vectorCount * dimension
overflows int");
- }
- if (data.length < expected) {
- throw new IllegalArgumentException(
- "data length " + data.length + " < vectorCount * dimension
" + expected);
- }
- }
-
private long requireOpen() {
if (nativePtr == 0L) {
throw new IllegalStateException("VectorIndexWriter is closed");
diff --git
a/jni/java/org/apache/paimon/index/ivfpq/VectorSearchBatchResult.java
b/java/src/main/java/org/apache/paimon/index/ivfpq/VectorSearchBatchResult.java
similarity index 100%
rename from jni/java/org/apache/paimon/index/ivfpq/VectorSearchBatchResult.java
rename to
java/src/main/java/org/apache/paimon/index/ivfpq/VectorSearchBatchResult.java
diff --git a/jni/java/org/apache/paimon/index/ivfpq/VectorSearchResult.java
b/java/src/main/java/org/apache/paimon/index/ivfpq/VectorSearchResult.java
similarity index 100%
rename from jni/java/org/apache/paimon/index/ivfpq/VectorSearchResult.java
rename to
java/src/main/java/org/apache/paimon/index/ivfpq/VectorSearchResult.java
diff --git
a/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexJavaApiTest.java
b/java/src/test/java/org/apache/paimon/index/ivfpq/VectorIndexJavaApiTest.java
similarity index 85%
rename from
jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexJavaApiTest.java
rename to
java/src/test/java/org/apache/paimon/index/ivfpq/VectorIndexJavaApiTest.java
index d24badc..85fe918 100644
--- a/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexJavaApiTest.java
+++
b/java/src/test/java/org/apache/paimon/index/ivfpq/VectorIndexJavaApiTest.java
@@ -18,6 +18,8 @@
package org.apache.paimon.index.ivfpq;
import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
public class VectorIndexJavaApiTest {
@@ -26,7 +28,6 @@ public class VectorIndexJavaApiTest {
testIndexTypeCodes();
testSingleResultCopiesArrays();
testBatchResultCopiesArraysAndSlicesRows();
- testConfigValidation();
testMetadata();
testClosedReaderRejectsOperations();
testClosedWriterRejectsOperations();
@@ -91,36 +92,6 @@ public class VectorIndexJavaApiTest {
});
}
- private static void testConfigValidation() {
- VectorIndexConfig flat = VectorIndexConfig.ivfFlat(16, 4, Metric.L2);
- assertEquals(IndexType.IVF_FLAT, flat.indexType());
- assertEquals(16, flat.dimension());
- assertEquals(4, flat.nlist());
-
- IvfPqConfig pq = new IvfPqConfig(16, 4, 4, Metric.L2, true);
- assertEquals(4, pq.m());
- assertTrue(pq.useOpq());
-
- HnswConfig hnsw = new HnswConfig(12, 64, 5);
- assertEquals(12, hnsw.m());
- assertEquals(64, hnsw.efConstruction());
- assertEquals(5, hnsw.maxLevel());
-
- VectorIndexConfig hnswFlat = VectorIndexConfig.ivfHnswFlat(16, 4,
Metric.L2, hnsw);
- assertEquals(IndexType.IVF_HNSW_FLAT, hnswFlat.indexType());
- assertEquals(12, hnswFlat.hnsw().m());
-
- VectorIndexConfig hnswSq = VectorIndexConfig.ivfHnswSq(16, 4,
Metric.L2, hnsw);
- assertEquals(IndexType.IVF_HNSW_SQ, hnswSq.indexType());
-
- assertThrows(IllegalArgumentException.class, new ThrowingRunnable() {
- @Override
- public void run() {
- new IvfPqConfig(10, 4, 3, Metric.L2, false);
- }
- });
- }
-
private static void testMetadata() {
VectorIndexMetadata metadata =
new VectorIndexMetadata(
@@ -138,7 +109,9 @@ public class VectorIndexJavaApiTest {
assertEquals(4, metadata.nlist());
assertEquals(Metric.COSINE, metadata.metric());
assertEquals(123L, metadata.totalVectors());
- assertEquals(20, metadata.hnsw().m());
+ assertEquals(20, metadata.hnswM());
+ assertEquals(150, metadata.hnswEfConstruction());
+ assertEquals(7, metadata.hnswMaxLevel());
}
private static void testClosedReaderRejectsOperations() {
@@ -185,8 +158,7 @@ public class VectorIndexJavaApiTest {
}
private static void testClosedWriterRejectsOperations() {
- VectorIndexConfig config = VectorIndexConfig.ivfPq(2, 4, 1, Metric.L2,
false);
- final VectorIndexWriter writer =
VectorIndexWriter.fromNativePointerForTesting(0L, config);
+ final VectorIndexWriter writer =
VectorIndexWriter.fromNativePointerForTesting(0L);
writer.close();
writer.close();
@@ -211,12 +183,12 @@ public class VectorIndexJavaApiTest {
}
private static void testReaderAndWriterApiCompile() {
- VectorIndexConfig config = VectorIndexConfig.ivfPq(2, 4, 1, Metric.L2,
false);
+ Map<String, String> options = ivfPqOptions(2, 4, 1);
VectorIndexReader closedReader =
VectorIndexReader.fromNativePointerForTesting(0L);
closedReader.close();
closedReader.close();
- VectorIndexWriter closedWriter =
VectorIndexWriter.fromNativePointerForTesting(0L, config);
+ VectorIndexWriter closedWriter =
VectorIndexWriter.fromNativePointerForTesting(0L);
closedWriter.close();
closedWriter.close();
@@ -236,13 +208,30 @@ public class VectorIndexJavaApiTest {
reader.searchBatch(
new float[] {0.0f, 1.0f, 2.0f, 3.0f}, 2, 10, 4, 32, new
byte[] {1, 2});
- VectorIndexWriter writer = new VectorIndexWriter(config);
+ VectorIndexWriter writer = new VectorIndexWriter(options);
writer.train(new float[] {0.0f, 1.0f, 2.0f, 3.0f}, 2);
writer.addVectors(new long[] {1L, 2L}, new float[] {0.0f, 1.0f,
2.0f, 3.0f}, 2);
writer.writeIndex(new Object());
}
}
+ private static Map<String, String> ivfFlatOptions(int dimension, int
nlist) {
+ Map<String, String> options = new HashMap<String, String>();
+ options.put("index.type", "ivf_flat");
+ options.put("dimension", Integer.toString(dimension));
+ options.put("nlist", Integer.toString(nlist));
+ options.put("metric", "l2");
+ return options;
+ }
+
+ private static Map<String, String> ivfPqOptions(int dimension, int nlist,
int m) {
+ Map<String, String> options = ivfFlatOptions(dimension, nlist);
+ options.put("index.type", "ivf_pq");
+ options.put("pq.m", Integer.toString(m));
+ options.put("use-opq", "false");
+ return options;
+ }
+
private static void assertEquals(int expected, int actual) {
if (expected != actual) {
throw new AssertionError("expected " + expected + " but got " +
actual);
@@ -261,12 +250,6 @@ public class VectorIndexJavaApiTest {
}
}
- private static void assertTrue(boolean value) {
- if (!value) {
- throw new AssertionError("expected true");
- }
- }
-
private static void assertArrayEquals(long[] expected, long[] actual) {
if (!Arrays.equals(expected, actual)) {
throw new AssertionError(
diff --git
a/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexNativeHandleSafetyTest.java
b/java/src/test/java/org/apache/paimon/index/ivfpq/VectorIndexNativeHandleSafetyTest.java
similarity index 93%
rename from
jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexNativeHandleSafetyTest.java
rename to
java/src/test/java/org/apache/paimon/index/ivfpq/VectorIndexNativeHandleSafetyTest.java
index 7dc93f4..8d6dd2a 100644
---
a/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexNativeHandleSafetyTest.java
+++
b/java/src/test/java/org/apache/paimon/index/ivfpq/VectorIndexNativeHandleSafetyTest.java
@@ -18,6 +18,8 @@
package org.apache.paimon.index.ivfpq;
import java.io.ByteArrayOutputStream;
+import java.util.HashMap;
+import java.util.Map;
public class VectorIndexNativeHandleSafetyTest {
@@ -73,13 +75,21 @@ public class VectorIndexNativeHandleSafetyTest {
}
private static VectorIndexWriter newPopulatedWriter() {
- VectorIndexWriter writer =
- new VectorIndexWriter(VectorIndexConfig.ivfFlat(1, 1,
Metric.L2));
+ VectorIndexWriter writer = new VectorIndexWriter(ivfFlatOptions());
writer.train(new float[] {0.0f, 1.0f}, 2);
writer.addVectors(new long[] {1L, 2L}, new float[] {0.0f, 1.0f}, 2);
return writer;
}
+ private static Map<String, String> ivfFlatOptions() {
+ Map<String, String> options = new HashMap<String, String>();
+ options.put("index.type", "ivf_flat");
+ options.put("dimension", "1");
+ options.put("nlist", "1");
+ options.put("metric", "l2");
+ return options;
+ }
+
private static void assertEquals(int expected, int actual) {
if (expected != actual) {
throw new AssertionError("expected " + expected + " but got " +
actual);
diff --git
a/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexNativePanicBoundaryTest.java
b/java/src/test/java/org/apache/paimon/index/ivfpq/VectorIndexNativePanicBoundaryTest.java
similarity index 89%
rename from
jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexNativePanicBoundaryTest.java
rename to
java/src/test/java/org/apache/paimon/index/ivfpq/VectorIndexNativePanicBoundaryTest.java
index 045eec3..f260aeb 100644
---
a/jni/java-test/org/apache/paimon/index/ivfpq/VectorIndexNativePanicBoundaryTest.java
+++
b/java/src/test/java/org/apache/paimon/index/ivfpq/VectorIndexNativePanicBoundaryTest.java
@@ -18,6 +18,8 @@
package org.apache.paimon.index.ivfpq;
import java.io.ByteArrayOutputStream;
+import java.util.HashMap;
+import java.util.Map;
public class VectorIndexNativePanicBoundaryTest {
@@ -31,13 +33,12 @@ public class VectorIndexNativePanicBoundaryTest {
testVoidEntrypointPanicBecomesRuntimeException();
testObjectEntrypointPanicBecomesRuntimeException();
- VectorIndexWriter survivor = new
VectorIndexWriter(VectorIndexConfig.ivfFlat(1, 1, Metric.L2));
+ VectorIndexWriter survivor = new VectorIndexWriter(ivfFlatOptions());
survivor.close();
}
private static void testVoidEntrypointPanicBecomesRuntimeException() {
- final VectorIndexWriter writer =
- new VectorIndexWriter(VectorIndexConfig.ivfFlat(1, 1,
Metric.L2));
+ final VectorIndexWriter writer = new
VectorIndexWriter(ivfFlatOptions());
try {
assertThrows(RuntimeException.class, new ThrowingRunnable() {
@Override
@@ -52,8 +53,7 @@ public class VectorIndexNativePanicBoundaryTest {
private static void testObjectEntrypointPanicBecomesRuntimeException() {
ByteArrayPositionOutputStream output = new
ByteArrayPositionOutputStream();
- VectorIndexWriter writer =
- new VectorIndexWriter(VectorIndexConfig.ivfFlat(1, 1,
Metric.L2));
+ VectorIndexWriter writer = new VectorIndexWriter(ivfFlatOptions());
try {
writer.train(new float[] {0.0f, 1.0f}, 2);
writer.addVectors(new long[] {1L, 2L}, new float[] {Float.NaN,
1.0f}, 2);
@@ -78,6 +78,15 @@ public class VectorIndexNativePanicBoundaryTest {
}
}
+ private static Map<String, String> ivfFlatOptions() {
+ Map<String, String> options = new HashMap<String, String>();
+ options.put("index.type", "ivf_flat");
+ options.put("dimension", "1");
+ options.put("nlist", "1");
+ options.put("metric", "l2");
+ return options;
+ }
+
private static void assertEquals(int expected, int actual) {
if (expected != actual) {
throw new AssertionError("expected " + expected + " but got " +
actual);
diff --git
a/java/src/test/java/org/apache/paimon/index/ivfpq/VectorIndexNativeValidationTest.java
b/java/src/test/java/org/apache/paimon/index/ivfpq/VectorIndexNativeValidationTest.java
new file mode 100644
index 0000000..55b8b2c
--- /dev/null
+++
b/java/src/test/java/org/apache/paimon/index/ivfpq/VectorIndexNativeValidationTest.java
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.paimon.index.ivfpq;
+
+import java.io.ByteArrayOutputStream;
+import java.util.HashMap;
+import java.util.Map;
+
+public class VectorIndexNativeValidationTest {
+
+ public static void main(String[] args) {
+ if (args.length != 1) {
+ throw new IllegalArgumentException("native library path is
required");
+ }
+
+ System.load(args[0]);
+
+ testWriterValidationComesFromCore();
+ testReaderValidationComesFromCore();
+ }
+
+ private static void testWriterValidationComesFromCore() {
+ final VectorIndexWriter writer = new
VectorIndexWriter(ivfFlatOptions());
+ try {
+ assertThrowsMessage(
+ RuntimeException.class,
+ "training data length 2 does not match vector count *
dimension 1",
+ new ThrowingRunnable() {
+ @Override
+ public void run() {
+ writer.train(new float[] {0.0f, 1.0f}, 1);
+ }
+ });
+ assertThrowsMessage(
+ RuntimeException.class,
+ "ids length 2 does not match vector count 1",
+ new ThrowingRunnable() {
+ @Override
+ public void run() {
+ writer.addVectors(new long[] {1L, 2L}, new float[]
{0.0f}, 1);
+ }
+ });
+ assertThrowsMessage(
+ RuntimeException.class,
+ "vector count must be greater than 0",
+ new ThrowingRunnable() {
+ @Override
+ public void run() {
+ writer.train(new float[0], 0);
+ }
+ });
+ } finally {
+ writer.close();
+ }
+ }
+
+ private static void testReaderValidationComesFromCore() {
+ VectorIndexReader reader = new VectorIndexReader(new
ByteArraySeekableInputStream(buildIndexBytes()));
+ try {
+ assertThrowsMessage(
+ RuntimeException.class,
+ "query length 2 does not match index dimension 1",
+ new ThrowingRunnable() {
+ @Override
+ public void run() {
+ reader.search(new float[] {0.0f, 1.0f}, 1, 1);
+ }
+ });
+ assertThrowsMessage(
+ RuntimeException.class,
+ "k must be greater than 0",
+ new ThrowingRunnable() {
+ @Override
+ public void run() {
+ reader.search(new float[] {0.0f}, 0, 1);
+ }
+ });
+ assertThrowsMessage(
+ RuntimeException.class,
+ "queries length 2 does not match nq * dimension 1",
+ new ThrowingRunnable() {
+ @Override
+ public void run() {
+ reader.searchBatch(new float[] {0.0f, 1.0f}, 1, 1,
1);
+ }
+ });
+ } finally {
+ reader.close();
+ }
+ }
+
+ private static byte[] buildIndexBytes() {
+ VectorIndexWriter writer = new VectorIndexWriter(ivfFlatOptions());
+ ByteArrayPositionOutputStream output = new
ByteArrayPositionOutputStream();
+ try {
+ writer.train(new float[] {0.0f, 1.0f}, 2);
+ writer.addVectors(new long[] {1L, 2L}, new float[] {0.0f, 1.0f},
2);
+ writer.writeIndex(output);
+ return output.toByteArray();
+ } finally {
+ writer.close();
+ }
+ }
+
+ private static Map<String, String> ivfFlatOptions() {
+ Map<String, String> options = new HashMap<String, String>();
+ options.put("index.type", "ivf_flat");
+ options.put("dimension", "1");
+ options.put("nlist", "1");
+ options.put("metric", "l2");
+ return options;
+ }
+
+ private static void assertThrowsMessage(
+ Class<? extends Throwable> expected, String expectedMessage,
ThrowingRunnable runnable) {
+ try {
+ runnable.run();
+ } catch (Throwable t) {
+ if (!expected.isInstance(t)) {
+ throw new AssertionError(
+ "expected " + expected.getName() + " but got " +
t.getClass().getName(), t);
+ }
+ String message = t.getMessage();
+ if (message == null || !message.contains(expectedMessage)) {
+ throw new AssertionError("unexpected exception message: " +
message, t);
+ }
+ return;
+ }
+ throw new AssertionError("expected " + expected.getName());
+ }
+
+ private interface ThrowingRunnable {
+ void run() throws Throwable;
+ }
+
+ public static final class ByteArrayPositionOutputStream {
+ private final ByteArrayOutputStream out = new ByteArrayOutputStream();
+
+ public void write(byte[] bytes) {
+ out.write(bytes, 0, bytes.length);
+ }
+
+ public byte[] toByteArray() {
+ return out.toByteArray();
+ }
+ }
+
+ public static final class ByteArraySeekableInputStream implements
VectorIndexInput {
+ private final byte[] data;
+
+ ByteArraySeekableInputStream(byte[] data) {
+ this.data = data.clone();
+ }
+
+ @Override
+ public void pread(long[] positions, byte[][] buffers) {
+ if (positions.length != buffers.length) {
+ throw new IllegalArgumentException("positions and buffers
length mismatch");
+ }
+ for (int i = 0; i < positions.length; i++) {
+ long readPosition = positions[i];
+ byte[] buffer = buffers[i];
+ if (readPosition < 0 || readPosition + buffer.length >
data.length) {
+ throw new IllegalArgumentException("read out of range: " +
readPosition);
+ }
+ System.arraycopy(data, (int) readPosition, buffer, 0,
buffer.length);
+ }
+ }
+ }
+}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/HnswConfig.java
b/jni/java/org/apache/paimon/index/ivfpq/HnswConfig.java
deleted file mode 100644
index eb0bb13..0000000
--- a/jni/java/org/apache/paimon/index/ivfpq/HnswConfig.java
+++ /dev/null
@@ -1,54 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.paimon.index.ivfpq;
-
-public final class HnswConfig {
-
- public static final HnswConfig DEFAULT = new HnswConfig(20, 150, 7);
-
- private final int m;
- private final int efConstruction;
- private final int maxLevel;
-
- public HnswConfig(int m, int efConstruction, int maxLevel) {
- validatePositive(m, "m");
- validatePositive(efConstruction, "efConstruction");
- validatePositive(maxLevel, "maxLevel");
- this.m = m;
- this.efConstruction = efConstruction;
- this.maxLevel = maxLevel;
- }
-
- public int m() {
- return m;
- }
-
- public int efConstruction() {
- return efConstruction;
- }
-
- public int maxLevel() {
- return maxLevel;
- }
-
- private static void validatePositive(int value, String name) {
- if (value <= 0) {
- throw new IllegalArgumentException(name + " must be > 0");
- }
- }
-}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IvfFlatConfig.java
b/jni/java/org/apache/paimon/index/ivfpq/IvfFlatConfig.java
deleted file mode 100644
index 285afc1..0000000
--- a/jni/java/org/apache/paimon/index/ivfpq/IvfFlatConfig.java
+++ /dev/null
@@ -1,25 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.paimon.index.ivfpq;
-
-public final class IvfFlatConfig extends VectorIndexConfig {
-
- public IvfFlatConfig(int dimension, int nlist, Metric metric) {
- super(IndexType.IVF_FLAT, dimension, nlist, metric);
- }
-}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IvfHnswFlatConfig.java
b/jni/java/org/apache/paimon/index/ivfpq/IvfHnswFlatConfig.java
deleted file mode 100644
index 3b17586..0000000
--- a/jni/java/org/apache/paimon/index/ivfpq/IvfHnswFlatConfig.java
+++ /dev/null
@@ -1,35 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.paimon.index.ivfpq;
-
-public final class IvfHnswFlatConfig extends VectorIndexConfig {
-
- private final HnswConfig hnsw;
-
- public IvfHnswFlatConfig(int dimension, int nlist, Metric metric,
HnswConfig hnsw) {
- super(IndexType.IVF_HNSW_FLAT, dimension, nlist, metric);
- if (hnsw == null) {
- throw new NullPointerException("hnsw");
- }
- this.hnsw = hnsw;
- }
-
- public HnswConfig hnsw() {
- return hnsw;
- }
-}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IvfHnswSqConfig.java
b/jni/java/org/apache/paimon/index/ivfpq/IvfHnswSqConfig.java
deleted file mode 100644
index 80fe09a..0000000
--- a/jni/java/org/apache/paimon/index/ivfpq/IvfHnswSqConfig.java
+++ /dev/null
@@ -1,35 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.paimon.index.ivfpq;
-
-public final class IvfHnswSqConfig extends VectorIndexConfig {
-
- private final HnswConfig hnsw;
-
- public IvfHnswSqConfig(int dimension, int nlist, Metric metric, HnswConfig
hnsw) {
- super(IndexType.IVF_HNSW_SQ, dimension, nlist, metric);
- if (hnsw == null) {
- throw new NullPointerException("hnsw");
- }
- this.hnsw = hnsw;
- }
-
- public HnswConfig hnsw() {
- return hnsw;
- }
-}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/IvfPqConfig.java
b/jni/java/org/apache/paimon/index/ivfpq/IvfPqConfig.java
deleted file mode 100644
index 751f560..0000000
--- a/jni/java/org/apache/paimon/index/ivfpq/IvfPqConfig.java
+++ /dev/null
@@ -1,47 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.paimon.index.ivfpq;
-
-public final class IvfPqConfig extends VectorIndexConfig {
-
- private final int m;
- private final boolean useOpq;
-
- public IvfPqConfig(int dimension, int nlist, int m, Metric metric, boolean
useOpq) {
- super(IndexType.IVF_PQ, dimension, nlist, metric);
- validatePositive(m, "m");
- if (dimension % m != 0) {
- throw new IllegalArgumentException("dimension must be divisible by
m");
- }
- this.m = m;
- this.useOpq = useOpq;
- }
-
- public int m() {
- return m;
- }
-
- public boolean useOpq() {
- return useOpq;
- }
-
- @Override
- int pqM() {
- return m;
- }
-}
diff --git a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexConfig.java
b/jni/java/org/apache/paimon/index/ivfpq/VectorIndexConfig.java
deleted file mode 100644
index 0152492..0000000
--- a/jni/java/org/apache/paimon/index/ivfpq/VectorIndexConfig.java
+++ /dev/null
@@ -1,94 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.paimon.index.ivfpq;
-
-public abstract class VectorIndexConfig {
-
- private final IndexType indexType;
- private final int dimension;
- private final int nlist;
- private final Metric metric;
-
- VectorIndexConfig(IndexType indexType, int dimension, int nlist, Metric
metric) {
- if (indexType == null) {
- throw new NullPointerException("indexType");
- }
- if (metric == null) {
- throw new NullPointerException("metric");
- }
- validatePositive(dimension, "dimension");
- validatePositive(nlist, "nlist");
- this.indexType = indexType;
- this.dimension = dimension;
- this.nlist = nlist;
- this.metric = metric;
- }
-
- public static VectorIndexConfig ivfFlat(int dimension, int nlist, Metric
metric) {
- return new IvfFlatConfig(dimension, nlist, metric);
- }
-
- public static VectorIndexConfig ivfPq(
- int dimension, int nlist, int m, Metric metric, boolean useOpq) {
- return new IvfPqConfig(dimension, nlist, m, metric, useOpq);
- }
-
- public static VectorIndexConfig ivfHnswFlat(
- int dimension, int nlist, Metric metric, HnswConfig hnsw) {
- return new IvfHnswFlatConfig(dimension, nlist, metric, hnsw);
- }
-
- public static VectorIndexConfig ivfHnswSq(
- int dimension, int nlist, Metric metric, HnswConfig hnsw) {
- return new IvfHnswSqConfig(dimension, nlist, metric, hnsw);
- }
-
- public IndexType indexType() {
- return indexType;
- }
-
- public int dimension() {
- return dimension;
- }
-
- public int nlist() {
- return nlist;
- }
-
- public Metric metric() {
- return metric;
- }
-
- int pqM() {
- return 0;
- }
-
- boolean useOpq() {
- return false;
- }
-
- HnswConfig hnsw() {
- return HnswConfig.DEFAULT;
- }
-
- static void validatePositive(int value, String name) {
- if (value <= 0) {
- throw new IllegalArgumentException(name + " must be > 0");
- }
- }
-}
diff --git a/jni/src/lib.rs b/jni/src/lib.rs
index 096d11c..7a07757 100644
--- a/jni/src/lib.rs
+++ b/jni/src/lib.rs
@@ -18,15 +18,14 @@
mod stream;
use jni::objects::{JByteArray, JClass, JFloatArray, JLongArray, JObject,
JValue};
-use jni::sys::{jboolean, jint, jlong, jobject};
+use jni::sys::{jint, jlong, jobject, jobjectArray};
use jni::JNIEnv;
-use paimon_vindex_core::distance::MetricType;
-use paimon_vindex_core::hnsw::HnswBuildParams;
use paimon_vindex_core::index::{
- IndexType, VectorIndexConfig, VectorIndexMetadata, VectorIndexReader,
VectorIndexWriter,
+ VectorIndexConfig, VectorIndexMetadata, VectorIndexReader,
VectorIndexWriter,
VectorSearchParams,
};
use std::any::Any;
+use std::collections::HashMap;
use std::panic::{catch_unwind, AssertUnwindSafe};
use stream::{JniOutputStream, JniSeekableStream};
@@ -80,109 +79,85 @@ fn deref_reader(ptr: jlong) -> Option<&'static mut
VectorIndexReader<JniSeekable
}
}
-fn parse_metric(env: &mut JNIEnv, metric: jint) -> Option<MetricType> {
- match MetricType::from_code(metric as u32) {
- Some(metric) => Some(metric),
- None => {
- throw_and_return::<()>(env, &format!("Unknown metric: {}",
metric));
- None
- }
- }
-}
-
-fn parse_index_type(env: &mut JNIEnv, index_type: jint) -> Option<IndexType> {
- match IndexType::from_code(index_type as u32) {
- Some(index_type) => Some(index_type),
- None => {
- throw_and_return::<()>(env, &format!("Unknown index type: {}",
index_type));
- None
- }
- }
-}
-
-#[allow(clippy::too_many_arguments)]
-fn build_config(
+fn build_config_from_options(
env: &mut JNIEnv,
- index_type: jint,
- dimension: jint,
- nlist: jint,
- pq_m: jint,
- metric: jint,
- use_opq: jboolean,
- hnsw_m: jint,
- ef_construction: jint,
- max_level: jint,
+ keys: jobjectArray,
+ values: jobjectArray,
) -> Option<VectorIndexConfig> {
- if dimension <= 0 || nlist <= 0 {
+ let keys = unsafe { jni::objects::JObjectArray::from_raw(keys) };
+ let values = unsafe { jni::objects::JObjectArray::from_raw(values) };
+ let key_len = match env.get_array_length(&keys) {
+ Ok(len) => len,
+ Err(e) => {
+ throw_and_return::<()>(env, &format!("get_array_length(keys): {}",
e));
+ return None;
+ }
+ };
+ let value_len = match env.get_array_length(&values) {
+ Ok(len) => len,
+ Err(e) => {
+ throw_and_return::<()>(env, &format!("get_array_length(values):
{}", e));
+ return None;
+ }
+ };
+ if key_len != value_len {
throw_and_return::<()>(
env,
&format!(
- "invalid parameters: dimension={}, nlist={}",
- dimension, nlist
+ "options key/value array length mismatch: {} != {}",
+ key_len, value_len
),
);
return None;
}
- let index_type = parse_index_type(env, index_type)?;
- let metric = parse_metric(env, metric)?;
- let dimension = dimension as usize;
- let nlist = nlist as usize;
-
- Some(match index_type {
- IndexType::IvfFlat => VectorIndexConfig::IvfFlat {
- dimension,
- nlist,
- metric,
- },
- IndexType::IvfPq => {
- if pq_m <= 0 {
- throw_and_return::<()>(env, &format!("invalid pq m: {}",
pq_m));
+
+ let mut options = HashMap::with_capacity(key_len as usize);
+ for idx in 0..key_len {
+ let key = match env.get_object_array_element(&keys, idx) {
+ Ok(key) => key,
+ Err(e) => {
+ throw_and_return::<()>(env, &format!("get options key {}: {}",
idx, e));
return None;
}
- VectorIndexConfig::IvfPq {
- dimension,
- nlist,
- m: pq_m as usize,
- metric,
- use_opq: use_opq != 0,
+ };
+ let value = match env.get_object_array_element(&values, idx) {
+ Ok(value) => value,
+ Err(e) => {
+ throw_and_return::<()>(env, &format!("get options value {}:
{}", idx, e));
+ return None;
+ }
+ };
+ let key = match java_string(env, key) {
+ Ok(key) => key,
+ Err(e) => {
+ throw_and_return::<()>(env, &format!("read options key {}:
{}", idx, e));
+ return None;
+ }
+ };
+ let value = match java_string(env, value) {
+ Ok(value) => value,
+ Err(e) => {
+ throw_and_return::<()>(env, &format!("read options value {}:
{}", idx, e));
+ return None;
}
+ };
+ options.insert(key, value);
+ }
+
+ match VectorIndexConfig::from_options(&options) {
+ Ok(config) => Some(config),
+ Err(e) => {
+ throw_and_return::<()>(env, &format!("invalid vector index
options: {}", e));
+ None
}
- IndexType::IvfHnswFlat => VectorIndexConfig::IvfHnswFlat {
- dimension,
- nlist,
- metric,
- hnsw: build_hnsw_params(env, hnsw_m, ef_construction, max_level)?,
- },
- IndexType::IvfHnswSq => VectorIndexConfig::IvfHnswSq {
- dimension,
- nlist,
- metric,
- hnsw: build_hnsw_params(env, hnsw_m, ef_construction, max_level)?,
- },
- })
+ }
}
-fn build_hnsw_params(
- env: &mut JNIEnv,
- hnsw_m: jint,
- ef_construction: jint,
- max_level: jint,
-) -> Option<HnswBuildParams> {
- if hnsw_m <= 0 || ef_construction <= 0 || max_level <= 0 {
- throw_and_return::<()>(
- env,
- &format!(
- "invalid HNSW parameters: m={}, efConstruction={},
maxLevel={}",
- hnsw_m, ef_construction, max_level
- ),
- );
- return None;
- }
- Some(HnswBuildParams {
- m: hnsw_m as usize,
- ef_construction: ef_construction as usize,
- max_level: max_level as usize,
- })
+fn java_string(env: &mut JNIEnv, object: JObject) -> Result<String, String> {
+ let string = jni::objects::JString::from(object);
+ env.get_string(&string)
+ .map(|value| value.into())
+ .map_err(|e| format!("get_string: {}", e))
}
fn read_byte_array(env: &mut JNIEnv, array: JByteArray) -> Result<Vec<u8>,
String> {
@@ -194,43 +169,21 @@ fn read_byte_array(env: &mut JNIEnv, array: JByteArray)
-> Result<Vec<u8>, Strin
.map_err(|e| format!("convert_byte_array: {}", e))
}
-fn read_float_array(
- env: &mut JNIEnv,
- array: &JFloatArray,
- expected_len: usize,
- name: &str,
-) -> Result<Vec<f32>, String> {
+fn read_float_array(env: &mut JNIEnv, array: &JFloatArray, name: &str) ->
Result<Vec<f32>, String> {
let len = env
.get_array_length(array)
.map_err(|e| format!("get_array_length({}): {}", name, e))? as usize;
- if len < expected_len {
- return Err(format!(
- "{} array too short: {} < {}",
- name, len, expected_len
- ));
- }
- let mut buf = vec![0.0f32; expected_len];
+ let mut buf = vec![0.0f32; len];
env.get_float_array_region(array, 0, &mut buf)
.map_err(|e| format!("get_float_array_region({}): {}", name, e))?;
Ok(buf)
}
-fn read_long_array(
- env: &mut JNIEnv,
- array: &JLongArray,
- expected_len: usize,
- name: &str,
-) -> Result<Vec<i64>, String> {
+fn read_long_array(env: &mut JNIEnv, array: &JLongArray, name: &str) ->
Result<Vec<i64>, String> {
let len = env
.get_array_length(array)
.map_err(|e| format!("get_array_length({}): {}", name, e))? as usize;
- if len < expected_len {
- return Err(format!(
- "{} array too short: {} < {}",
- name, len, expected_len
- ));
- }
- let mut buf = vec![0i64; expected_len];
+ let mut buf = vec![0i64; len];
env.get_long_array_region(array, 0, &mut buf)
.map_err(|e| format!("get_long_array_region({}): {}", name, e))?;
Ok(buf)
@@ -339,7 +292,7 @@ fn build_metadata(env: &mut JNIEnv, metadata:
VectorIndexMetadata) -> jobject {
}
fn search_params(k: jint, nprobe: jint, ef_search: jint) ->
Option<VectorSearchParams> {
- if k <= 0 || nprobe <= 0 || ef_search < 0 {
+ if k < 0 || nprobe < 0 || ef_search < 0 {
None
} else {
Some(VectorSearchParams::with_ef_search(
@@ -353,33 +306,14 @@ fn search_params(k: jint, nprobe: jint, ef_search: jint)
-> Option<VectorSearchP
// --- Unified Writer API ---
#[no_mangle]
-#[allow(clippy::too_many_arguments)]
pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_VectorIndexNative_createWriter(
env: JNIEnv,
_class: JClass,
- index_type: jint,
- dimension: jint,
- nlist: jint,
- pq_m: jint,
- metric: jint,
- use_opq: jboolean,
- hnsw_m: jint,
- ef_construction: jint,
- max_level: jint,
+ keys: jobjectArray,
+ values: jobjectArray,
) -> jlong {
jni_call(env, |env| {
- let config = match build_config(
- env,
- index_type,
- dimension,
- nlist,
- pq_m,
- metric,
- use_opq,
- hnsw_m,
- ef_construction,
- max_level,
- ) {
+ let config = match build_config_from_options(env, keys, values) {
Some(config) => config,
None => return 0,
};
@@ -405,15 +339,11 @@ pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_VectorIndexNative_trai
Some(writer) => writer,
None => return throw_and_return(env, "null native pointer (writer
already freed?)"),
};
- if n <= 0 {
+ if n < 0 {
return throw_and_return(env, &format!("invalid vector count: {}",
n));
}
let n = n as usize;
- let expected_len = match n.checked_mul(writer.dimension()) {
- Some(len) => len,
- None => return throw_and_return(env, "vector count * dimension
overflow"),
- };
- let data_buf = match read_float_array(env, &data, expected_len,
"data") {
+ let data_buf = match read_float_array(env, &data, "data") {
Ok(buf) => buf,
Err(e) => return throw_and_return(env, &e),
};
@@ -423,6 +353,21 @@ pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_VectorIndexNative_trai
})
}
+#[no_mangle]
+pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_VectorIndexNative_writerDimension(
+ env: JNIEnv,
+ _class: JClass,
+ ptr: jlong,
+) -> jint {
+ jni_call(env, |env| {
+ let writer = match deref_writer(ptr) {
+ Some(writer) => writer,
+ None => return throw_and_return(env, "null native pointer (writer
already freed?)"),
+ };
+ writer.dimension() as jint
+ })
+}
+
#[no_mangle]
pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_VectorIndexNative_addVectors(
env: JNIEnv,
@@ -437,19 +382,15 @@ pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_VectorIndexNative_addV
Some(writer) => writer,
None => return throw_and_return(env, "null native pointer (writer
already freed?)"),
};
- if n <= 0 {
+ if n < 0 {
return throw_and_return(env, &format!("invalid vector count: {}",
n));
}
let n = n as usize;
- let expected_len = match n.checked_mul(writer.dimension()) {
- Some(len) => len,
- None => return throw_and_return(env, "vector count * dimension
overflow"),
- };
- let id_buf = match read_long_array(env, &ids, n, "ids") {
+ let id_buf = match read_long_array(env, &ids, "ids") {
Ok(buf) => buf,
Err(e) => return throw_and_return(env, &e),
};
- let data_buf = match read_float_array(env, &data, expected_len,
"data") {
+ let data_buf = match read_float_array(env, &data, "data") {
Ok(buf) => buf,
Err(e) => return throw_and_return(env, &e),
};
@@ -572,7 +513,7 @@ pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_VectorIndexNative_sear
)
}
};
- let query_buf = match read_float_array(env, &query,
reader.dimension(), "query") {
+ let query_buf = match read_float_array(env, &query, "query") {
Ok(buf) => buf,
Err(e) => return throw_and_return(env, &e),
};
@@ -612,7 +553,7 @@ pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_VectorIndexNative_sear
)
}
};
- let query_buf = match read_float_array(env, &query,
reader.dimension(), "query") {
+ let query_buf = match read_float_array(env, &query, "query") {
Ok(buf) => buf,
Err(e) => return throw_and_return(env, &e),
};
@@ -645,7 +586,7 @@ pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_VectorIndexNative_sear
Some(reader) => reader,
None => return throw_and_return(env, "null native pointer (reader
already freed?)"),
};
- if query_count <= 0 {
+ if query_count < 0 {
return throw_and_return(env, &format!("invalid query count: {}",
query_count));
}
let params = match search_params(k, nprobe, ef_search) {
@@ -661,11 +602,7 @@ pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_VectorIndexNative_sear
}
};
let nq = query_count as usize;
- let expected_len = match nq.checked_mul(reader.dimension()) {
- Some(len) => len,
- None => return throw_and_return(env, "query count * dimension
overflow"),
- };
- let query_buf = match read_float_array(env, &queries, expected_len,
"queries") {
+ let query_buf = match read_float_array(env, &queries, "queries") {
Ok(buf) => buf,
Err(e) => return throw_and_return(env, &e),
};
@@ -694,7 +631,7 @@ pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_VectorIndexNative_sear
Some(reader) => reader,
None => return throw_and_return(env, "null native pointer (reader
already freed?)"),
};
- if query_count <= 0 {
+ if query_count < 0 {
return throw_and_return(env, &format!("invalid query count: {}",
query_count));
}
let params = match search_params(k, nprobe, ef_search) {
@@ -710,11 +647,7 @@ pub extern "system" fn
Java_org_apache_paimon_index_ivfpq_VectorIndexNative_sear
}
};
let nq = query_count as usize;
- let expected_len = match nq.checked_mul(reader.dimension()) {
- Some(len) => len,
- None => return throw_and_return(env, "query count * dimension
overflow"),
- };
- let query_buf = match read_float_array(env, &queries, expected_len,
"queries") {
+ let query_buf = match read_float_array(env, &queries, "queries") {
Ok(buf) => buf,
Err(e) => return throw_and_return(env, &e),
};
diff --git a/python/src/lib.rs b/python/src/lib.rs
index aaeb90b..901bcc5 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -21,7 +21,6 @@ use numpy::{
PyArray, PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2,
PyUntypedArrayMethods,
};
use paimon_vindex_core::distance::MetricType;
-use paimon_vindex_core::hnsw::HnswBuildParams;
use paimon_vindex_core::index::{
IndexType, VectorIndexConfig, VectorIndexReader as CoreVectorIndexReader,
VectorIndexWriter as CoreVectorIndexWriter, VectorSearchParams,
@@ -29,7 +28,8 @@ use paimon_vindex_core::index::{
use paimon_vindex_core::io::{ReadRequest, SeekRead, SeekWrite};
use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
-use pyo3::types::{PyAny, PyBytes, PyList};
+use pyo3::types::{PyAny, PyBytes, PyDict, PyList};
+use std::collections::HashMap;
use std::io;
struct PyVectorIndexInput {
@@ -131,18 +131,6 @@ impl SeekWrite for PyOutputStream {
}
}
-fn parse_metric(metric: &str) -> PyResult<MetricType> {
- match metric.to_ascii_lowercase().as_str() {
- "l2" => Ok(MetricType::L2),
- "inner_product" | "ip" => Ok(MetricType::InnerProduct),
- "cosine" => Ok(MetricType::Cosine),
- _ => Err(PyValueError::new_err(format!(
- "unknown metric '{}'; expected 'l2', 'inner_product', or 'cosine'",
- metric
- ))),
- }
-}
-
fn metric_name(metric: MetricType) -> &'static str {
match metric {
MetricType::L2 => "l2",
@@ -155,14 +143,6 @@ fn index_type_name(index_type: IndexType) -> &'static str {
index_type.as_str()
}
-fn validate_positive(value: usize, name: &str) -> PyResult<()> {
- if value == 0 {
- Err(PyValueError::new_err(format!("{} must be > 0", name)))
- } else {
- Ok(())
- }
-}
-
fn decode_filter_bytes<'a>(
filter_bytes: Option<&'a Bound<'_, PyAny>>,
) -> PyResult<Option<&'a [u8]>> {
@@ -191,230 +171,6 @@ fn pyarray2_from_flat<'py, T: numpy::Element + Clone>(
.map_err(|e| PyValueError::new_err(format!("reshape batch result: {}",
e)))
}
-fn validate_matrix_shape(
- shape: &[usize],
- dimension: usize,
- value_name: &str,
- dimension_name: &str,
-) -> PyResult<usize> {
- let row_count = shape[0];
- let actual_dimension = shape[1];
- if actual_dimension != dimension {
- return Err(PyValueError::new_err(format!(
- "{} dimension {} != {} {}",
- value_name, actual_dimension, dimension_name, dimension
- )));
- }
- if row_count == 0 {
- return Err(PyValueError::new_err(format!(
- "{} must contain at least one row",
- value_name
- )));
- }
- Ok(row_count)
-}
-
-fn hnsw_params(hnsw: Option<&HnswConfig>) -> HnswBuildParams {
- hnsw.map(|h| h.to_core())
- .unwrap_or_else(HnswBuildParams::default)
-}
-
-#[pyclass]
-#[derive(Clone)]
-struct HnswConfig {
- #[pyo3(get)]
- m: usize,
- #[pyo3(get)]
- ef_construction: usize,
- #[pyo3(get)]
- max_level: usize,
-}
-
-#[pymethods]
-impl HnswConfig {
- #[new]
- #[pyo3(signature = (m=20, ef_construction=150, max_level=7))]
- fn new(m: usize, ef_construction: usize, max_level: usize) ->
PyResult<Self> {
- validate_positive(m, "m")?;
- validate_positive(ef_construction, "ef_construction")?;
- validate_positive(max_level, "max_level")?;
- Ok(Self {
- m,
- ef_construction,
- max_level,
- })
- }
-}
-
-impl HnswConfig {
- fn to_core(&self) -> HnswBuildParams {
- HnswBuildParams {
- m: self.m,
- ef_construction: self.ef_construction,
- max_level: self.max_level,
- }
- }
-}
-
-#[pyclass]
-#[derive(Clone)]
-struct IvfFlatConfig {
- #[pyo3(get)]
- dimension: usize,
- #[pyo3(get)]
- nlist: usize,
- #[pyo3(get)]
- metric: String,
-}
-
-#[pymethods]
-impl IvfFlatConfig {
- #[new]
- #[pyo3(signature = (dimension, nlist, metric="l2"))]
- fn new(dimension: usize, nlist: usize, metric: &str) -> PyResult<Self> {
- validate_positive(dimension, "dimension")?;
- validate_positive(nlist, "nlist")?;
- parse_metric(metric)?;
- Ok(Self {
- dimension,
- nlist,
- metric: metric.to_string(),
- })
- }
-}
-
-#[pyclass]
-#[derive(Clone)]
-struct IvfPqConfig {
- #[pyo3(get)]
- dimension: usize,
- #[pyo3(get)]
- nlist: usize,
- #[pyo3(get)]
- m: usize,
- #[pyo3(get)]
- metric: String,
- #[pyo3(get)]
- use_opq: bool,
-}
-
-#[pymethods]
-impl IvfPqConfig {
- #[new]
- #[pyo3(signature = (dimension, nlist, m, metric="l2", use_opq=false))]
- fn new(
- dimension: usize,
- nlist: usize,
- m: usize,
- metric: &str,
- use_opq: bool,
- ) -> PyResult<Self> {
- validate_positive(dimension, "dimension")?;
- validate_positive(nlist, "nlist")?;
- validate_positive(m, "m")?;
- if !dimension.is_multiple_of(m) {
- return Err(PyValueError::new_err(format!(
- "dimension {} must be divisible by m {}",
- dimension, m
- )));
- }
- parse_metric(metric)?;
- Ok(Self {
- dimension,
- nlist,
- m,
- metric: metric.to_string(),
- use_opq,
- })
- }
-}
-
-#[pyclass]
-#[derive(Clone)]
-struct IvfHnswFlatConfig {
- #[pyo3(get)]
- dimension: usize,
- #[pyo3(get)]
- nlist: usize,
- #[pyo3(get)]
- metric: String,
- hnsw: HnswConfig,
-}
-
-#[pymethods]
-impl IvfHnswFlatConfig {
- #[new]
- #[pyo3(signature = (dimension, nlist, metric="l2", hnsw=None))]
- fn new(
- dimension: usize,
- nlist: usize,
- metric: &str,
- hnsw: Option<&HnswConfig>,
- ) -> PyResult<Self> {
- validate_positive(dimension, "dimension")?;
- validate_positive(nlist, "nlist")?;
- parse_metric(metric)?;
- Ok(Self {
- dimension,
- nlist,
- metric: metric.to_string(),
- hnsw: hnsw.cloned().unwrap_or_else(|| HnswConfig {
- m: 20,
- ef_construction: 150,
- max_level: 7,
- }),
- })
- }
-
- #[getter]
- fn hnsw(&self) -> HnswConfig {
- self.hnsw.clone()
- }
-}
-
-#[pyclass]
-#[derive(Clone)]
-struct IvfHnswSqConfig {
- #[pyo3(get)]
- dimension: usize,
- #[pyo3(get)]
- nlist: usize,
- #[pyo3(get)]
- metric: String,
- hnsw: HnswConfig,
-}
-
-#[pymethods]
-impl IvfHnswSqConfig {
- #[new]
- #[pyo3(signature = (dimension, nlist, metric="l2", hnsw=None))]
- fn new(
- dimension: usize,
- nlist: usize,
- metric: &str,
- hnsw: Option<&HnswConfig>,
- ) -> PyResult<Self> {
- validate_positive(dimension, "dimension")?;
- validate_positive(nlist, "nlist")?;
- parse_metric(metric)?;
- Ok(Self {
- dimension,
- nlist,
- metric: metric.to_string(),
- hnsw: hnsw.cloned().unwrap_or_else(|| HnswConfig {
- m: 20,
- ef_construction: 150,
- max_level: 7,
- }),
- })
- }
-
- #[getter]
- fn hnsw(&self) -> HnswConfig {
- self.hnsw.clone()
- }
-}
-
#[pyclass]
struct VectorIndexMetadata {
#[pyo3(get)]
@@ -429,53 +185,34 @@ struct VectorIndexMetadata {
total_vectors: i64,
#[pyo3(get)]
pq_m: Option<usize>,
- hnsw: Option<HnswConfig>,
+ #[pyo3(get)]
+ hnsw_m: Option<usize>,
+ #[pyo3(get)]
+ hnsw_ef_construction: Option<usize>,
+ #[pyo3(get)]
+ hnsw_max_level: Option<usize>,
}
-#[pymethods]
-impl VectorIndexMetadata {
- #[getter]
- fn hnsw(&self) -> Option<HnswConfig> {
- self.hnsw.clone()
+fn options_from_py(options: &Bound<'_, PyAny>) -> PyResult<HashMap<String,
String>> {
+ let dict: &Bound<PyDict> = options
+ .downcast()
+ .map_err(|_| PyValueError::new_err("options must be a dict[str,
str]"))?;
+ let mut result = HashMap::with_capacity(dict.len());
+ for (key, value) in dict.iter() {
+ let key = key
+ .extract::<String>()
+ .map_err(|_| PyValueError::new_err("option keys must be
strings"))?;
+ let value = value
+ .extract::<String>()
+ .map_err(|_| PyValueError::new_err("option values must be
strings"))?;
+ result.insert(key, value);
}
+ Ok(result)
}
-fn config_from_py(config: &Bound<'_, PyAny>) -> PyResult<VectorIndexConfig> {
- if let Ok(config) = config.extract::<PyRef<'_, IvfFlatConfig>>() {
- return Ok(VectorIndexConfig::IvfFlat {
- dimension: config.dimension,
- nlist: config.nlist,
- metric: parse_metric(&config.metric)?,
- });
- }
- if let Ok(config) = config.extract::<PyRef<'_, IvfPqConfig>>() {
- return Ok(VectorIndexConfig::IvfPq {
- dimension: config.dimension,
- nlist: config.nlist,
- m: config.m,
- metric: parse_metric(&config.metric)?,
- use_opq: config.use_opq,
- });
- }
- if let Ok(config) = config.extract::<PyRef<'_, IvfHnswFlatConfig>>() {
- return Ok(VectorIndexConfig::IvfHnswFlat {
- dimension: config.dimension,
- nlist: config.nlist,
- metric: parse_metric(&config.metric)?,
- hnsw: hnsw_params(Some(&config.hnsw)),
- });
- }
- if let Ok(config) = config.extract::<PyRef<'_, IvfHnswSqConfig>>() {
- return Ok(VectorIndexConfig::IvfHnswSq {
- dimension: config.dimension,
- nlist: config.nlist,
- metric: parse_metric(&config.metric)?,
- hnsw: hnsw_params(Some(&config.hnsw)),
- });
- }
- Err(PyValueError::new_err(
- "config must be IvfFlatConfig, IvfPqConfig, IvfHnswFlatConfig, or
IvfHnswSqConfig",
- ))
+fn config_from_options(options: &Bound<'_, PyAny>) ->
PyResult<VectorIndexConfig> {
+ VectorIndexConfig::from_options(&options_from_py(options)?)
+ .map_err(|e| PyValueError::new_err(format!("invalid vector index
options: {}", e)))
}
#[pyclass]
@@ -487,8 +224,8 @@ struct VectorIndexWriter {
#[pymethods]
impl VectorIndexWriter {
#[new]
- fn new(config: &Bound<'_, PyAny>) -> PyResult<Self> {
- let config = config_from_py(config)?;
+ fn new(options: &Bound<'_, PyAny>) -> PyResult<Self> {
+ let config = config_from_options(options)?;
let dimension = config.dimension();
let index = CoreVectorIndexWriter::new(config)
.map_err(|e| PyValueError::new_err(format!("failed to create
writer: {}", e)))?;
@@ -498,99 +235,13 @@ impl VectorIndexWriter {
})
}
- #[staticmethod]
- #[pyo3(signature = (dimension, nlist, metric="l2"))]
- fn ivf_flat(dimension: usize, nlist: usize, metric: &str) ->
PyResult<Self> {
- let config = VectorIndexConfig::IvfFlat {
- dimension,
- nlist,
- metric: parse_metric(metric)?,
- };
- let index = CoreVectorIndexWriter::new(config)
- .map_err(|e| PyValueError::new_err(format!("failed to create
writer: {}", e)))?;
- Ok(Self {
- index: Some(index),
- dimension,
- })
- }
-
- #[staticmethod]
- #[pyo3(signature = (dimension, nlist, m, metric="l2", use_opq=false))]
- fn ivf_pq(
- dimension: usize,
- nlist: usize,
- m: usize,
- metric: &str,
- use_opq: bool,
- ) -> PyResult<Self> {
- let config = IvfPqConfig::new(dimension, nlist, m, metric, use_opq)?;
- let core = VectorIndexConfig::IvfPq {
- dimension: config.dimension,
- nlist: config.nlist,
- m: config.m,
- metric: parse_metric(&config.metric)?,
- use_opq: config.use_opq,
- };
- let index = CoreVectorIndexWriter::new(core)
- .map_err(|e| PyValueError::new_err(format!("failed to create
writer: {}", e)))?;
- Ok(Self {
- index: Some(index),
- dimension,
- })
- }
-
- #[staticmethod]
- #[pyo3(signature = (dimension, nlist, metric="l2", hnsw=None))]
- fn ivf_hnsw_flat(
- dimension: usize,
- nlist: usize,
- metric: &str,
- hnsw: Option<&HnswConfig>,
- ) -> PyResult<Self> {
- let config = VectorIndexConfig::IvfHnswFlat {
- dimension,
- nlist,
- metric: parse_metric(metric)?,
- hnsw: hnsw_params(hnsw),
- };
- let index = CoreVectorIndexWriter::new(config)
- .map_err(|e| PyValueError::new_err(format!("failed to create
writer: {}", e)))?;
- Ok(Self {
- index: Some(index),
- dimension,
- })
- }
-
- #[staticmethod]
- #[pyo3(signature = (dimension, nlist, metric="l2", hnsw=None))]
- fn ivf_hnsw_sq(
- dimension: usize,
- nlist: usize,
- metric: &str,
- hnsw: Option<&HnswConfig>,
- ) -> PyResult<Self> {
- let config = VectorIndexConfig::IvfHnswSq {
- dimension,
- nlist,
- metric: parse_metric(metric)?,
- hnsw: hnsw_params(hnsw),
- };
- let index = CoreVectorIndexWriter::new(config)
- .map_err(|e| PyValueError::new_err(format!("failed to create
writer: {}", e)))?;
- Ok(Self {
- index: Some(index),
- dimension,
- })
- }
-
#[getter]
fn dimension(&self) -> usize {
self.dimension
}
fn train(&mut self, data: PyReadonlyArray2<f32>) -> PyResult<()> {
- let shape = data.shape();
- let row_count = validate_matrix_shape(shape, self.dimension, "data",
"writer dimension")?;
+ let row_count = data.shape()[0];
let data_slice = data.as_slice().map_err(|_| {
PyValueError::new_err("data must be a contiguous two-dimensional
float32 array")
})?;
@@ -605,16 +256,8 @@ impl VectorIndexWriter {
ids: PyReadonlyArray1<i64>,
data: PyReadonlyArray2<f32>,
) -> PyResult<()> {
- let shape = data.shape();
- let row_count = validate_matrix_shape(shape, self.dimension, "data",
"writer dimension")?;
+ let row_count = data.shape()[0];
let id_slice = ids.as_slice()?;
- if id_slice.len() != row_count {
- return Err(PyValueError::new_err(format!(
- "ids length {} != vector count {}",
- id_slice.len(),
- row_count
- )));
- }
let data_slice = data.as_slice().map_err(|_| {
PyValueError::new_err("data must be a contiguous two-dimensional
float32 array")
})?;
@@ -705,11 +348,9 @@ impl VectorIndexReader {
metric: metric_name(metadata.metric).to_string(),
total_vectors: metadata.total_vectors,
pq_m: metadata.pq_m,
- hnsw: metadata.hnsw.map(|h| HnswConfig {
- m: h.m,
- ef_construction: h.ef_construction,
- max_level: h.max_level,
- }),
+ hnsw_m: metadata.hnsw.map(|h| h.m),
+ hnsw_ef_construction: metadata.hnsw.map(|h| h.ef_construction),
+ hnsw_max_level: metadata.hnsw.map(|h| h.max_level),
}
}
@@ -725,16 +366,6 @@ impl VectorIndexReader {
filter_bytes: Option<&Bound<'_, PyAny>>,
) -> PyResult<(Bound<'py, PyArray1<i64>>, Bound<'py, PyArray1<f32>>)> {
let query_slice = query.as_slice()?;
- let dimension = self.inner.metadata().dimension;
- if query_slice.len() != dimension {
- return Err(PyValueError::new_err(format!(
- "query length {} != index dimension {}",
- query_slice.len(),
- dimension
- )));
- }
- validate_positive(top_k, "top_k")?;
- validate_positive(nprobe, "nprobe")?;
let params = VectorSearchParams::with_ef_search(top_k, nprobe,
ef_search);
let (ids, dists) = if let Some(bytes) =
decode_filter_bytes(filter_bytes)? {
@@ -764,11 +395,7 @@ impl VectorIndexReader {
ef_search: usize,
filter_bytes: Option<&Bound<'_, PyAny>>,
) -> PyResult<(Bound<'py, PyArray2<i64>>, Bound<'py, PyArray2<f32>>)> {
- let dimension = self.inner.metadata().dimension;
- let shape = queries.shape();
- let query_count = validate_matrix_shape(shape, dimension, "query",
"index dimension")?;
- validate_positive(top_k, "top_k")?;
- validate_positive(nprobe, "nprobe")?;
+ let query_count = queries.shape()[0];
let query_slice = queries.as_slice().map_err(|_| {
PyValueError::new_err("queries must be a contiguous
two-dimensional float32 array")
})?;
@@ -812,11 +439,6 @@ impl VectorIndexReader {
#[pymodule]
fn paimon_vindex(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
- m.add_class::<HnswConfig>()?;
- m.add_class::<IvfFlatConfig>()?;
- m.add_class::<IvfPqConfig>()?;
- m.add_class::<IvfHnswFlatConfig>()?;
- m.add_class::<IvfHnswSqConfig>()?;
m.add_class::<VectorIndexMetadata>()?;
m.add_class::<VectorIndexReader>()?;
m.add_class::<VectorIndexWriter>()?;
@@ -869,10 +491,7 @@ mod tests {
output
}
- fn vector_index_input<'py>(
- py: Python<'py>,
- output: &Bound<'py, PyAny>,
- ) -> Bound<'py, PyAny> {
+ fn vector_index_input<'py>(py: Python<'py>, output: &Bound<'py, PyAny>) ->
Bound<'py, PyAny> {
let data = output
.call_method0("getvalue")
.unwrap()
@@ -889,6 +508,14 @@ mod tests {
.into_any()
}
+ fn options<'py>(py: Python<'py>, values: &[(&str, &str)]) -> Bound<'py,
PyAny> {
+ let dict = PyDict::new_bound(py);
+ for (key, value) in values {
+ dict.set_item(*key, *value).unwrap();
+ }
+ dict.into_any()
+ }
+
#[pyclass]
struct PyBytesVectorIndexInput {
data: Vec<u8>,
@@ -904,9 +531,9 @@ mod tests {
let result = PyList::empty_bound(py);
for item in ranges.iter() {
let (pos, len): (usize, usize) = item.extract()?;
- let end = pos.checked_add(len).ok_or_else(|| {
- PyIOError::new_err("pread_many range position overflow")
- })?;
+ let end = pos
+ .checked_add(len)
+ .ok_or_else(|| PyIOError::new_err("pread_many range
position overflow"))?;
if end > self.data.len() {
return Err(PyIOError::new_err(format!(
"pread_many range {}..{} out of bounds {}",
@@ -926,34 +553,57 @@ mod tests {
Python::with_gil(|py| {
let configs: Vec<(Bound<'_, PyAny>, usize, &str)> = vec![
(
- Py::new(py, IvfFlatConfig::new(16, 4, "l2").unwrap())
- .unwrap()
- .into_bound(py)
- .into_any(),
+ options(
+ py,
+ &[
+ ("index.type", "ivf_flat"),
+ ("dimension", "16"),
+ ("nlist", "4"),
+ ("metric", "l2"),
+ ],
+ ),
16,
"ivf_flat",
),
(
- Py::new(py, IvfPqConfig::new(16, 4, 4, "l2",
false).unwrap())
- .unwrap()
- .into_bound(py)
- .into_any(),
+ options(
+ py,
+ &[
+ ("index.type", "ivf_pq"),
+ ("dimension", "16"),
+ ("nlist", "4"),
+ ("pq.m", "4"),
+ ("metric", "l2"),
+ ("use-opq", "false"),
+ ],
+ ),
16,
"ivf_pq",
),
(
- Py::new(py, IvfHnswFlatConfig::new(16, 4, "l2",
None).unwrap())
- .unwrap()
- .into_bound(py)
- .into_any(),
+ options(
+ py,
+ &[
+ ("index.type", "ivf_hnsw_flat"),
+ ("dimension", "16"),
+ ("nlist", "4"),
+ ("metric", "l2"),
+ ],
+ ),
16,
"ivf_hnsw_flat",
),
(
- Py::new(py, IvfHnswSqConfig::new(16, 4, "l2",
None).unwrap())
- .unwrap()
- .into_bound(py)
- .into_any(),
+ options(
+ py,
+ &[
+ ("index.type", "ivf_hnsw_sq"),
+ ("dimension", "16"),
+ ("nlist", "4"),
+ ("metric", "l2"),
+ ("hnsw.m", "12"),
+ ],
+ ),
16,
"ivf_hnsw_sq",
),
@@ -969,9 +619,7 @@ mod tests {
let data = generate_clustered_data(1, d, 1);
let query = PyArray1::from_vec_bound(py, data[0..d].to_vec());
- let (result_ids, _) = reader
- .search(py, query.readonly(), 5, 4, 32, None)
- .unwrap();
+ let (result_ids, _) = reader.search(py, query.readonly(), 5,
4, 32, None).unwrap();
assert_eq!(result_ids.len(), 5);
assert_eq!(result_ids.readonly().as_slice().unwrap()[0], 0);
}
@@ -981,10 +629,15 @@ mod tests {
#[test]
fn python_batch_search_accepts_roaring_filter_bytes() {
Python::with_gil(|py| {
- let config = Py::new(py, IvfFlatConfig::new(2, 1, "l2").unwrap())
- .unwrap()
- .into_bound(py)
- .into_any();
+ let config = options(
+ py,
+ &[
+ ("index.type", "ivf_flat"),
+ ("dimension", "2"),
+ ("nlist", "1"),
+ ("metric", "l2"),
+ ],
+ );
let io = py.import_bound("io").unwrap();
let output = io.getattr("BytesIO").unwrap().call0().unwrap();
let mut writer = VectorIndexWriter::new(&config).unwrap();
@@ -1016,33 +669,67 @@ mod tests {
.unwrap();
assert_eq!(result_ids.shape(), &[2, 2]);
- assert_eq!(
- result_ids.readonly().as_slice().unwrap(),
- &[12, -1, 12, -1]
- );
+ assert_eq!(result_ids.readonly().as_slice().unwrap(), &[12, -1,
12, -1]);
assert_eq!(result_dists.readonly().as_slice().unwrap()[1],
f32::MAX);
});
}
#[test]
- fn python_batch_search_validates_query_shape() {
+ fn python_delegates_validation_to_core() {
Python::with_gil(|py| {
- let config = Py::new(py, IvfPqConfig::new(16, 4, 4, "l2",
false).unwrap())
- .unwrap()
- .into_bound(py)
- .into_any();
+ let config = options(
+ py,
+ &[
+ ("index.type", "ivf_pq"),
+ ("dimension", "16"),
+ ("nlist", "4"),
+ ("pq.m", "4"),
+ ("metric", "l2"),
+ ],
+ );
+ let mut writer = VectorIndexWriter::new(&config).unwrap();
+ let wrong_train = PyArray::from_vec2_bound(py, &[vec![0.0f32;
17]]).unwrap();
+
+ let err = writer.train(wrong_train.readonly()).unwrap_err();
+ assert!(err
+ .to_string()
+ .contains("training data length 17 does not match vector count
* dimension 16"));
+
+ let data = PyArray::from_vec2_bound(py, &[vec![0.0f32;
16]]).unwrap();
+ let ids = PyArray1::from_vec_bound(py, vec![1i64, 2]);
+
+ let err = writer
+ .add_vectors(ids.readonly(), data.readonly())
+ .unwrap_err();
+ assert!(err
+ .to_string()
+ .contains("ids length 2 does not match vector count 1"));
+
let output = write_index_bytes(py, &config, 16);
let input = vector_index_input(py, &output);
let mut reader = VectorIndexReader::new(input.unbind()).unwrap();
+ let wrong_query = PyArray1::from_vec_bound(py, vec![0.0f32; 15]);
let wrong_dim = PyArray::from_vec2_bound(py, &[vec![0.0f32;
15]]).unwrap();
let err = reader
- .search_batch(py, wrong_dim.readonly(), 5, 2, 0, None)
+ .search(py, wrong_query.readonly(), 5, 2, 0, None)
+ .unwrap_err();
+ assert!(err
+ .to_string()
+ .contains("query length 15 does not match index dimension
16"));
+
+ let query = PyArray1::from_vec_bound(py, vec![0.0f32; 16]);
+ let err = reader
+ .search(py, query.readonly(), 0, 2, 0, None)
.unwrap_err();
+ assert!(err.to_string().contains("k must be greater than 0"));
+ let err = reader
+ .search_batch(py, wrong_dim.readonly(), 5, 2, 0, None)
+ .unwrap_err();
assert!(err
.to_string()
- .contains("query dimension 15 != index dimension 16"));
+ .contains("queries length 15 does not match nq * dimension
16"));
});
}
@@ -1066,10 +753,16 @@ class ShortWriter:
.unwrap()
.call0()
.unwrap();
- let config = Py::new(py, IvfPqConfig::new(16, 4, 4, "l2",
false).unwrap())
- .unwrap()
- .into_bound(py)
- .into_any();
+ let config = options(
+ py,
+ &[
+ ("index.type", "ivf_pq"),
+ ("dimension", "16"),
+ ("nlist", "4"),
+ ("pq.m", "4"),
+ ("metric", "l2"),
+ ],
+ );
let mut writer = VectorIndexWriter::new(&config).unwrap();
let data = generate_clustered_data(500, 16, 4);
let train = PyArray::from_vec2_bound(