This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git
The following commit(s) were added to refs/heads/main by this push:
new 9b7afc1 Reject non-finite vector inputs (#33)
9b7afc1 is described below
commit 9b7afc1fae91a46fafb86763ce290a779a4d3bd1
Author: QuakeWang <[email protected]>
AuthorDate: Thu Jun 11 11:06:09 2026 +0800
Reject non-finite vector inputs (#33)
---
core/src/index.rs | 174 +++++++++++++++++++++
.../vector/VectorIndexNativePanicBoundaryTest.java | 39 ++++-
.../vector/VectorIndexNativeValidationTest.java | 91 +++++++++++
3 files changed, 302 insertions(+), 2 deletions(-)
diff --git a/core/src/index.rs b/core/src/index.rs
index 6fcaa00..e66796b 100644
--- a/core/src/index.rs
+++ b/core/src/index.rs
@@ -531,6 +531,7 @@ impl<R: SeekRead> VectorIndexReader<R> {
query: &[f32],
params: VectorSearchParams,
) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ validate_query(query, self.dimension())?;
match self {
Self::IvfFlat(reader) => reader.search(query, params.top_k,
params.nprobe),
Self::IvfPq(reader) => search_with_reader(reader, query,
params.top_k, params.nprobe),
@@ -549,6 +550,7 @@ impl<R: SeekRead> VectorIndexReader<R> {
params: VectorSearchParams,
roaring_filter_bytes: &[u8],
) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ validate_query(query, self.dimension())?;
match self {
Self::IvfFlat(reader) => reader.search_with_roaring_filter(
query,
@@ -586,6 +588,7 @@ impl<R: SeekRead> VectorIndexReader<R> {
query_count: usize,
params: VectorSearchParams,
) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ validate_queries(queries, query_count, self.dimension())?;
match self {
Self::IvfFlat(reader) => search_batch_ivfflat_reader(
reader,
@@ -623,6 +626,7 @@ impl<R: SeekRead> VectorIndexReader<R> {
params: VectorSearchParams,
roaring_filter_bytes: &[u8],
) -> io::Result<(Vec<i64>, Vec<f32>)> {
+ validate_queries(queries, query_count, self.dimension())?;
match self {
Self::IvfFlat(reader) =>
search_batch_ivfflat_reader_roaring_filter(
reader,
@@ -720,6 +724,57 @@ fn validate_vectors(data: &[f32], n: usize, dimension:
usize, value_name: &str)
),
));
}
+ validate_finite_values(data, expected_len, value_name)?;
+ Ok(())
+}
+
+/// Validates a single query vector before it is dispatched to the underlying reader.
+///
+/// # Errors
+///
+/// Returns an `InvalidInput` error if `query.len()` differs from `dimension`, or if any
+/// component is NaN or infinite (message prefix: "query contains non-finite value").
+fn validate_query(query: &[f32], dimension: usize) -> io::Result<()> {
+ if query.len() != dimension {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!(
+ "query length {} does not match index dimension {}",
+ query.len(),
+ dimension
+ ),
+ ));
+ }
+ // The check above guarantees `dimension == query.len()`, so the finiteness scan
+ // covers the whole query and its `[..dimension]` slice cannot go out of bounds.
+ validate_finite_values(query, dimension, "query")
+}
+
+/// Validates a flattened batch of `query_count` queries, each `dimension` values long.
+///
+/// # Errors
+///
+/// Propagates any error from `validate_positive` (defined elsewhere) for `query_count`.
+/// Returns an `InvalidInput` error if `query_count * dimension` overflows `usize`, if
+/// `queries.len()` differs from that product, or if any value is NaN or infinite.
+/// Offsets in the non-finite error are flat indices into `queries`, not per-query
+/// component indices.
+fn validate_queries(queries: &[f32], query_count: usize, dimension: usize) ->
io::Result<()> {
+ validate_positive(query_count, "query count")?;
+ // Compute the expected flat length with overflow checking before comparing lengths.
+ let expected_len = query_count.checked_mul(dimension).ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "nq * dimension overflows usize",
+ )
+ })?;
+ if queries.len() != expected_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!(
+ "queries length {} does not match nq * dimension {}",
+ queries.len(),
+ expected_len
+ ),
+ ));
+ }
+ // Length now equals `expected_len`, so the finiteness scan covers the whole batch.
+ validate_finite_values(queries, expected_len, "queries")
+}
+
+/// Rejects NaN and +/-infinity in the first `len` entries of `values`.
+///
+/// The error names `value_name`, the flat offset of the first offending entry, and the
+/// value itself, e.g. "training data contains non-finite value at offset 0: NaN";
+/// `f32`'s `Display` renders these values as `NaN`, `inf`, or `-inf`.
+///
+/// # Errors
+///
+/// Returns an `InvalidInput` error for the first non-finite value found.
+///
+/// # Panics
+///
+/// Panics if `len > values.len()` (slice indexing); callers check the slice length
+/// against `len` before calling.
+fn validate_finite_values(values: &[f32], len: usize, value_name: &str) ->
io::Result<()> {
+ for (offset, &value) in values[..len].iter().enumerate() {
+ if !value.is_finite() {
+ // Fail fast: only the first non-finite value is reported.
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!(
+ "{} contains non-finite value at offset {}: {}",
+ value_name, offset, value
+ ),
+ ));
+ }
+ }
Ok(())
}
@@ -769,6 +824,32 @@ mod tests {
assert_eq!(result_ids[0], 0);
}
+ /// Builds a tiny in-memory IVF-Flat index and opens a reader over it.
+ ///
+ /// The index uses dimension 1, a single list, and the L2 metric. It holds two vectors,
+ /// ids 1 and 2 with values 0.0 and 1.0, and is serialized into a `Vec<u8>` before
+ /// being reopened.
+ fn build_ivfflat_reader() -> VectorIndexReader<Cursor<Vec<u8>>> {
+ let mut writer = VectorIndexWriter::new(VectorIndexConfig::IvfFlat {
+ dimension: 1,
+ nlist: 1,
+ metric: MetricType::L2,
+ })
+ .unwrap();
+ writer.train(&[0.0, 1.0], 2).unwrap();
+ writer.add_vectors(&[1, 2], &[0.0, 1.0], 2).unwrap();
+
+ let mut bytes = Vec::new();
+ writer.write(&mut PosWriter::new(&mut bytes)).unwrap();
+ VectorIndexReader::open(Cursor::new(bytes)).unwrap()
+ }
+
+ /// Asserts that `result` is an `InvalidInput` error whose message contains `expected`.
+ ///
+ /// On failure, the assertion message shows both the actual error and the expected text.
+ fn assert_invalid_input_contains(result: io::Result<()>, expected: &str) {
+ let err = result.expect_err("invalid input should be rejected");
+ assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+ assert!(
+ err.to_string().contains(expected),
+ "error '{}' should contain '{}'",
+ err,
+ expected
+ );
+ }
+
#[test]
fn unified_reader_writer_roundtrips_all_index_types() {
roundtrip(VectorIndexConfig::IvfFlat {
@@ -877,6 +958,32 @@ mod tests {
}
}
+ /// `train` must reject NaN, +inf and -inf training data with an `InvalidInput` error
+ /// that names the offending offset and value.
+ #[test]
+ fn unified_writer_rejects_non_finite_training_data() {
+ for (value, expected) in [
+ (
+ f32::NAN,
+ "training data contains non-finite value at offset 0: NaN",
+ ),
+ (
+ f32::INFINITY,
+ "training data contains non-finite value at offset 0: inf",
+ ),
+ (
+ f32::NEG_INFINITY,
+ "training data contains non-finite value at offset 0: -inf",
+ ),
+ ] {
+ // A fresh writer per case, so one rejected call cannot affect the next case.
+ let mut writer = VectorIndexWriter::new(VectorIndexConfig::IvfFlat
{
+ dimension: 1,
+ nlist: 1,
+ metric: MetricType::L2,
+ })
+ .unwrap();
+ assert_invalid_input_contains(writer.train(&[value, 1.0], 2),
expected);
+ }
+ }
+
#[test]
fn config_from_options_rejects_unknown_options() {
let err = VectorIndexConfig::from_options(&options(&[
@@ -927,4 +1034,71 @@ mod tests {
.unwrap_err();
assert!(err.to_string().contains("unknown metric"));
}
+
+ /// After successful training, `add_vectors` must reject NaN, +inf and -inf vector data
+ /// with an `InvalidInput` error that names the offending offset and value.
+ #[test]
+ fn unified_writer_rejects_non_finite_vector_data() {
+ for (value, expected) in [
+ (
+ f32::NAN,
+ "vector data contains non-finite value at offset 0: NaN",
+ ),
+ (
+ f32::INFINITY,
+ "vector data contains non-finite value at offset 0: inf",
+ ),
+ (
+ f32::NEG_INFINITY,
+ "vector data contains non-finite value at offset 0: -inf",
+ ),
+ ] {
+ let mut writer = VectorIndexWriter::new(VectorIndexConfig::IvfFlat
{
+ dimension: 1,
+ nlist: 1,
+ metric: MetricType::L2,
+ })
+ .unwrap();
+ // Train on finite data first, so the failure comes from `add_vectors`.
+ writer.train(&[0.0, 1.0], 2).unwrap();
+ assert_invalid_input_contains(writer.add_vectors(&[1, 2], &[value,
1.0], 2), expected);
+ }
+ }
+
+ /// A single-vector `search` with a NaN component must fail with `InvalidInput`.
+ #[test]
+ fn unified_reader_rejects_non_finite_query() {
+ let mut reader = build_ivfflat_reader();
+ let err = reader
+ .search(&[f32::NAN], VectorSearchParams::new(1, 1))
+ .expect_err("non-finite query should be rejected");
+
+ assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+ assert!(err
+ .to_string()
+ .contains("query contains non-finite value at offset 0: NaN"));
+ }
+
+ /// `search_batch` must report non-finite values using the "queries" label.
+ #[test]
+ fn unified_reader_rejects_non_finite_batch_query() {
+ let mut reader = build_ivfflat_reader();
+ let err = reader
+ .search_batch(&[f32::NEG_INFINITY], 1, VectorSearchParams::new(1,
1))
+ .expect_err("non-finite batch query should be rejected");
+
+ assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+ assert!(err
+ .to_string()
+ .contains("queries contains non-finite value at offset 0: -inf"));
+ }
+
+ /// Query validation must run before the roaring filter is decoded. The query is
+ /// non-finite and the filter is the single byte `0xFF`, which is presumably not a
+ /// valid `RoaringTreemap`. The returned error must be the query error, not a
+ /// filter-decoding error.
+ #[test]
+ fn unified_reader_rejects_non_finite_query_before_decoding_filter() {
+ let mut reader = build_ivfflat_reader();
+ let err = reader
+ .search_with_roaring_filter(&[f32::NAN],
VectorSearchParams::new(1, 1), &[0xFF])
+ .expect_err("non-finite filtered query should be rejected");
+
+ assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+ assert!(err
+ .to_string()
+ .contains("query contains non-finite value at offset 0: NaN"));
+ assert!(!err.to_string().contains("invalid RoaringTreemap"));
+ }
}
diff --git
a/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativePanicBoundaryTest.java
b/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativePanicBoundaryTest.java
index 98e99b9..739defe 100644
---
a/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativePanicBoundaryTest.java
+++
b/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativePanicBoundaryTest.java
@@ -56,14 +56,16 @@ public class VectorIndexNativePanicBoundaryTest {
VectorIndexWriter writer = new VectorIndexWriter(ivfFlatOptions());
try {
writer.train(new float[] {0.0f, 1.0f}, 2);
- writer.addVectors(new long[] {1L, 2L}, new float[] {Float.NaN,
1.0f}, 2);
+ writer.addVectors(new long[] {1L, 2L}, new float[] {0.0f, 1.0f},
2);
writer.writeIndex(output);
} finally {
writer.close();
}
+ byte[] indexBytes = output.toByteArray();
+ corruptFirstIvfFlatVector(indexBytes, Float.NaN);
VectorIndexReader reader =
- new VectorIndexReader(new
ByteArraySeekableInputStream(output.toByteArray()));
+ new VectorIndexReader(new
ByteArraySeekableInputStream(indexBytes));
try {
assertEquals(1, reader.dimension());
assertThrows(RuntimeException.class, new ThrowingRunnable() {
@@ -115,6 +117,39 @@ public class VectorIndexNativePanicBoundaryTest {
throw new AssertionError("expected " + expected.getName());
}
+ /**
+ * Overwrites the first component of the first stored vector in a serialized IVF-Flat
+ * index with {@code value}.
+ *
+ * <p>The writer now rejects non-finite input, so a NaN can only reach the reader by
+ * patching the bytes after serialization.
+ *
+ * <p>The offsets assume the layout the Rust writer produces; NOTE(review): confirm
+ * against the format. Dimension is read as a little-endian int at byte 8 and nlist at
+ * byte 12. The offset table is assumed to start after a 64-byte region plus
+ * {@code dimension * nlist} floats (presumably the centroids). Each list is assumed to
+ * start with an 8-byte field, then a 4-byte id-bytes length, the id bytes, and the
+ * vectors.
+ */
+ private static void corruptFirstIvfFlatVector(byte[] indexBytes, float
value) {
+ int dimension = readIntLe(indexBytes, 8);
+ int nlist = readIntLe(indexBytes, 12);
+ int offsetTable = 64 + dimension * nlist * Float.BYTES;
+ // Narrowing cast: fine for the small in-memory test index.
+ int listOffset = (int) readLongLe(indexBytes, offsetTable);
+ int idBytesLength = readIntLe(indexBytes, listOffset + Long.BYTES);
+ int firstVectorOffset = listOffset + Long.BYTES + Integer.BYTES +
idBytesLength;
+ writeFloatLe(indexBytes, firstVectorOffset, value);
+ }
+
+ /** Reads a 4-byte little-endian int from {@code bytes} starting at {@code offset}. */
+ private static int readIntLe(byte[] bytes, int offset) {
+ // Mask each byte with 0xFF so sign extension does not corrupt the higher bits.
+ return (bytes[offset] & 0xFF)
+ | ((bytes[offset + 1] & 0xFF) << 8)
+ | ((bytes[offset + 2] & 0xFF) << 16)
+ | ((bytes[offset + 3] & 0xFF) << 24);
+ }
+
+ /** Reads an 8-byte little-endian long from {@code bytes} starting at {@code offset}. */
+ private static long readLongLe(byte[] bytes, int offset) {
+ long result = 0L;
+ for (int i = 0; i < Long.BYTES; i++) {
+ // Widen to long before shifting, so shifts of 32 bits or more are not truncated.
+ result |= (long) (bytes[offset + i] & 0xFF) << (8 * i);
+ }
+ return result;
+ }
+
+ /** Writes {@code value} as 4 little-endian bytes into {@code bytes} at {@code offset}. */
+ private static void writeFloatLe(byte[] bytes, int offset, float value) {
+ // Raw bits keep the exact NaN bit pattern instead of collapsing it to the canonical NaN.
+ int bits = Float.floatToRawIntBits(value);
+ bytes[offset] = (byte) bits;
+ bytes[offset + 1] = (byte) (bits >>> 8);
+ bytes[offset + 2] = (byte) (bits >>> 16);
+ bytes[offset + 3] = (byte) (bits >>> 24);
+ }
+
private interface ThrowingRunnable {
void run() throws Throwable;
}
diff --git
a/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativeValidationTest.java
b/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativeValidationTest.java
index ac0c3cc..6e40171 100644
---
a/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativeValidationTest.java
+++
b/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativeValidationTest.java
@@ -31,7 +31,9 @@ public class VectorIndexNativeValidationTest {
System.load(args[0]);
testWriterValidationComesFromCore();
+ testWriterRejectsNonFiniteValues();
testReaderValidationComesFromCore();
+ testReaderRejectsNonFiniteQueries();
}
private static void testWriterValidationComesFromCore() {
@@ -104,6 +106,74 @@ public class VectorIndexNativeValidationTest {
}
}
+ /**
+ * Checks that the Rust core's non-finite validation reaches Java callers. A NaN in
+ * training data and a +inf in vector data must each raise a RuntimeException
+ * carrying the core's error message.
+ */
+ private static void testWriterRejectsNonFiniteValues() {
+ final VectorIndexWriter trainingWriter = new
VectorIndexWriter(ivfFlatOptions());
+ try {
+ // assertThrowsMessage is defined elsewhere in this class; presumably it checks
+ // the exception type and message text -- confirm.
+ assertThrowsMessage(
+ RuntimeException.class,
+ "training data contains non-finite value at offset 0: NaN",
+ new ThrowingRunnable() {
+ @Override
+ public void run() {
+ trainingWriter.train(new float[] {Float.NaN,
1.0f}, 2);
+ }
+ });
+ } finally {
+ trainingWriter.close();
+ }
+
+ // Use a separate writer, trained on finite data, so the failure comes from addVectors.
+ final VectorIndexWriter vectorWriter = new
VectorIndexWriter(ivfFlatOptions());
+ try {
+ vectorWriter.train(new float[] {0.0f, 1.0f}, 2);
+ assertThrowsMessage(
+ RuntimeException.class,
+ "vector data contains non-finite value at offset 0: inf",
+ new ThrowingRunnable() {
+ @Override
+ public void run() {
+ vectorWriter.addVectors(
+ new long[] {1L, 2L},
+ new float[] {Float.POSITIVE_INFINITY,
1.0f},
+ 2);
+ }
+ });
+ } finally {
+ vectorWriter.close();
+ }
+ }
+
+ /**
+ * Checks that non-finite queries are rejected as ordinary invalid input. This covers
+ * search, searchBatch, and filtered search, and none of them may surface as a Rust
+ * panic. In the filtered case the query must be validated before the malformed
+ * one-byte filter is decoded.
+ */
+ private static void testReaderRejectsNonFiniteQueries() {
+ VectorIndexReader reader = new VectorIndexReader(new
ByteArraySeekableInputStream(buildIndexBytes()));
+ try {
+ assertInvalidInput(
+ new ThrowingRunnable() {
+ @Override
+ public void run() {
+ reader.search(new float[] {Float.NaN}, 1, 1);
+ }
+ },
+ "query contains non-finite value at offset 0: NaN");
+ assertInvalidInput(
+ new ThrowingRunnable() {
+ @Override
+ public void run() {
+ reader.searchBatch(new float[]
{Float.NEGATIVE_INFINITY}, 1, 1, 1);
+ }
+ },
+ "queries contains non-finite value at offset 0: -inf");
+ // 0xFF is presumably not a valid RoaringTreemap; the query error must win.
+ assertInvalidInput(
+ new ThrowingRunnable() {
+ @Override
+ public void run() {
+ reader.search(new float[] {Float.NaN}, 1, 1, new
byte[] {(byte) 0xFF});
+ }
+ },
+ "query contains non-finite value at offset 0: NaN");
+ } finally {
+ reader.close();
+ }
+ }
+
private static byte[] buildIndexBytes() {
VectorIndexWriter writer = new VectorIndexWriter(ivfFlatOptions());
ByteArrayPositionOutputStream output = new
ByteArrayPositionOutputStream();
@@ -144,6 +214,27 @@ public class VectorIndexNativeValidationTest {
throw new AssertionError("expected " + expected.getName());
}
+ /**
+ * Runs {@code runnable} and requires a RuntimeException whose message contains
+ * {@code expectedMessage}.
+ *
+ * <p>The test fails if the message also contains "Rust panic in JNI call" (presumably
+ * the marker of the JNI panic boundary) or "invalid RoaringTreemap" (filter decoding
+ * ran before query validation). It also fails if nothing is thrown or a different
+ * Throwable is thrown.
+ */
+ private static void assertInvalidInput(ThrowingRunnable runnable, String
expectedMessage) {
+ try {
+ runnable.run();
+ } catch (RuntimeException e) {
+ String message = e.getMessage();
+ if (message == null || !message.contains(expectedMessage)) {
+ throw new AssertionError("unexpected exception message: " +
message, e);
+ }
+ if (message.contains("Rust panic in JNI call")) {
+ throw new AssertionError("invalid input should not cross the
panic boundary", e);
+ }
+ if (message.contains("invalid RoaringTreemap")) {
+ throw new AssertionError("query validation should run before
filter decoding", e);
+ }
+ return;
+ // AssertionErrors thrown in the catch block above propagate unchanged.
+ // Catch clauses apply only to the try block.
+ } catch (Throwable t) {
+ throw new AssertionError("expected RuntimeException but got " +
t.getClass().getName(), t);
+ }
+ throw new AssertionError("expected RuntimeException");
+ }
+
private interface ThrowingRunnable {
void run() throws Throwable;
}