This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-vector-index.git


The following commit(s) were added to refs/heads/main by this push:
     new 9b7afc1  Reject non-finite vector inputs (#33)
9b7afc1 is described below

commit 9b7afc1fae91a46fafb86763ce290a779a4d3bd1
Author: QuakeWang <[email protected]>
AuthorDate: Thu Jun 11 11:06:09 2026 +0800

    Reject non-finite vector inputs (#33)
---
 core/src/index.rs                                  | 174 +++++++++++++++++++++
 .../vector/VectorIndexNativePanicBoundaryTest.java |  39 ++++-
 .../vector/VectorIndexNativeValidationTest.java    |  91 +++++++++++
 3 files changed, 302 insertions(+), 2 deletions(-)

diff --git a/core/src/index.rs b/core/src/index.rs
index 6fcaa00..e66796b 100644
--- a/core/src/index.rs
+++ b/core/src/index.rs
@@ -531,6 +531,7 @@ impl<R: SeekRead> VectorIndexReader<R> {
         query: &[f32],
         params: VectorSearchParams,
     ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        validate_query(query, self.dimension())?;
         match self {
             Self::IvfFlat(reader) => reader.search(query, params.top_k, 
params.nprobe),
             Self::IvfPq(reader) => search_with_reader(reader, query, 
params.top_k, params.nprobe),
@@ -549,6 +550,7 @@ impl<R: SeekRead> VectorIndexReader<R> {
         params: VectorSearchParams,
         roaring_filter_bytes: &[u8],
     ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        validate_query(query, self.dimension())?;
         match self {
             Self::IvfFlat(reader) => reader.search_with_roaring_filter(
                 query,
@@ -586,6 +588,7 @@ impl<R: SeekRead> VectorIndexReader<R> {
         query_count: usize,
         params: VectorSearchParams,
     ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        validate_queries(queries, query_count, self.dimension())?;
         match self {
             Self::IvfFlat(reader) => search_batch_ivfflat_reader(
                 reader,
@@ -623,6 +626,7 @@ impl<R: SeekRead> VectorIndexReader<R> {
         params: VectorSearchParams,
         roaring_filter_bytes: &[u8],
     ) -> io::Result<(Vec<i64>, Vec<f32>)> {
+        validate_queries(queries, query_count, self.dimension())?;
         match self {
             Self::IvfFlat(reader) => 
search_batch_ivfflat_reader_roaring_filter(
                 reader,
@@ -720,6 +724,57 @@ fn validate_vectors(data: &[f32], n: usize, dimension: 
usize, value_name: &str)
             ),
         ));
     }
+    validate_finite_values(data, expected_len, value_name)?;
+    Ok(())
+}
+
+fn validate_query(query: &[f32], dimension: usize) -> io::Result<()> {
+    if query.len() != dimension {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!(
+                "query length {} does not match index dimension {}",
+                query.len(),
+                dimension
+            ),
+        ));
+    }
+    validate_finite_values(query, dimension, "query")
+}
+
+fn validate_queries(queries: &[f32], query_count: usize, dimension: usize) -> 
io::Result<()> {
+    validate_positive(query_count, "query count")?;
+    let expected_len = query_count.checked_mul(dimension).ok_or_else(|| {
+        io::Error::new(
+            io::ErrorKind::InvalidInput,
+            "nq * dimension overflows usize",
+        )
+    })?;
+    if queries.len() != expected_len {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidInput,
+            format!(
+                "queries length {} does not match nq * dimension {}",
+                queries.len(),
+                expected_len
+            ),
+        ));
+    }
+    validate_finite_values(queries, expected_len, "queries")
+}
+
+fn validate_finite_values(values: &[f32], len: usize, value_name: &str) -> 
io::Result<()> {
+    for (offset, &value) in values[..len].iter().enumerate() {
+        if !value.is_finite() {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
+                format!(
+                    "{} contains non-finite value at offset {}: {}",
+                    value_name, offset, value
+                ),
+            ));
+        }
+    }
     Ok(())
 }
 
@@ -769,6 +824,32 @@ mod tests {
         assert_eq!(result_ids[0], 0);
     }
 
+    fn build_ivfflat_reader() -> VectorIndexReader<Cursor<Vec<u8>>> {
+        let mut writer = VectorIndexWriter::new(VectorIndexConfig::IvfFlat {
+            dimension: 1,
+            nlist: 1,
+            metric: MetricType::L2,
+        })
+        .unwrap();
+        writer.train(&[0.0, 1.0], 2).unwrap();
+        writer.add_vectors(&[1, 2], &[0.0, 1.0], 2).unwrap();
+
+        let mut bytes = Vec::new();
+        writer.write(&mut PosWriter::new(&mut bytes)).unwrap();
+        VectorIndexReader::open(Cursor::new(bytes)).unwrap()
+    }
+
+    fn assert_invalid_input_contains(result: io::Result<()>, expected: &str) {
+        let err = result.expect_err("invalid input should be rejected");
+        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+        assert!(
+            err.to_string().contains(expected),
+            "error '{}' should contain '{}'",
+            err,
+            expected
+        );
+    }
+
     #[test]
     fn unified_reader_writer_roundtrips_all_index_types() {
         roundtrip(VectorIndexConfig::IvfFlat {
@@ -877,6 +958,32 @@ mod tests {
         }
     }
 
+    #[test]
+    fn unified_writer_rejects_non_finite_training_data() {
+        for (value, expected) in [
+            (
+                f32::NAN,
+                "training data contains non-finite value at offset 0: NaN",
+            ),
+            (
+                f32::INFINITY,
+                "training data contains non-finite value at offset 0: inf",
+            ),
+            (
+                f32::NEG_INFINITY,
+                "training data contains non-finite value at offset 0: -inf",
+            ),
+        ] {
+            let mut writer = VectorIndexWriter::new(VectorIndexConfig::IvfFlat 
{
+                dimension: 1,
+                nlist: 1,
+                metric: MetricType::L2,
+            })
+            .unwrap();
+            assert_invalid_input_contains(writer.train(&[value, 1.0], 2), 
expected);
+        }
+    }
+
     #[test]
     fn config_from_options_rejects_unknown_options() {
         let err = VectorIndexConfig::from_options(&options(&[
@@ -927,4 +1034,71 @@ mod tests {
         .unwrap_err();
         assert!(err.to_string().contains("unknown metric"));
     }
+
+    #[test]
+    fn unified_writer_rejects_non_finite_vector_data() {
+        for (value, expected) in [
+            (
+                f32::NAN,
+                "vector data contains non-finite value at offset 0: NaN",
+            ),
+            (
+                f32::INFINITY,
+                "vector data contains non-finite value at offset 0: inf",
+            ),
+            (
+                f32::NEG_INFINITY,
+                "vector data contains non-finite value at offset 0: -inf",
+            ),
+        ] {
+            let mut writer = VectorIndexWriter::new(VectorIndexConfig::IvfFlat 
{
+                dimension: 1,
+                nlist: 1,
+                metric: MetricType::L2,
+            })
+            .unwrap();
+            writer.train(&[0.0, 1.0], 2).unwrap();
+            assert_invalid_input_contains(writer.add_vectors(&[1, 2], &[value, 
1.0], 2), expected);
+        }
+    }
+
+    #[test]
+    fn unified_reader_rejects_non_finite_query() {
+        let mut reader = build_ivfflat_reader();
+        let err = reader
+            .search(&[f32::NAN], VectorSearchParams::new(1, 1))
+            .expect_err("non-finite query should be rejected");
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+        assert!(err
+            .to_string()
+            .contains("query contains non-finite value at offset 0: NaN"));
+    }
+
+    #[test]
+    fn unified_reader_rejects_non_finite_batch_query() {
+        let mut reader = build_ivfflat_reader();
+        let err = reader
+            .search_batch(&[f32::NEG_INFINITY], 1, VectorSearchParams::new(1, 
1))
+            .expect_err("non-finite batch query should be rejected");
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+        assert!(err
+            .to_string()
+            .contains("queries contains non-finite value at offset 0: -inf"));
+    }
+
+    #[test]
+    fn unified_reader_rejects_non_finite_query_before_decoding_filter() {
+        let mut reader = build_ivfflat_reader();
+        let err = reader
+            .search_with_roaring_filter(&[f32::NAN], 
VectorSearchParams::new(1, 1), &[0xFF])
+            .expect_err("non-finite filtered query should be rejected");
+
+        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
+        assert!(err
+            .to_string()
+            .contains("query contains non-finite value at offset 0: NaN"));
+        assert!(!err.to_string().contains("invalid RoaringTreemap"));
+    }
 }
diff --git 
a/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativePanicBoundaryTest.java
 
b/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativePanicBoundaryTest.java
index 98e99b9..739defe 100644
--- 
a/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativePanicBoundaryTest.java
+++ 
b/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativePanicBoundaryTest.java
@@ -56,14 +56,16 @@ public class VectorIndexNativePanicBoundaryTest {
         VectorIndexWriter writer = new VectorIndexWriter(ivfFlatOptions());
         try {
             writer.train(new float[] {0.0f, 1.0f}, 2);
-            writer.addVectors(new long[] {1L, 2L}, new float[] {Float.NaN, 
1.0f}, 2);
+            writer.addVectors(new long[] {1L, 2L}, new float[] {0.0f, 1.0f}, 
2);
             writer.writeIndex(output);
         } finally {
             writer.close();
         }
+        byte[] indexBytes = output.toByteArray();
+        corruptFirstIvfFlatVector(indexBytes, Float.NaN);
 
         VectorIndexReader reader =
-                new VectorIndexReader(new 
ByteArraySeekableInputStream(output.toByteArray()));
+                new VectorIndexReader(new 
ByteArraySeekableInputStream(indexBytes));
         try {
             assertEquals(1, reader.dimension());
             assertThrows(RuntimeException.class, new ThrowingRunnable() {
@@ -115,6 +117,39 @@ public class VectorIndexNativePanicBoundaryTest {
         throw new AssertionError("expected " + expected.getName());
     }
 
+    private static void corruptFirstIvfFlatVector(byte[] indexBytes, float 
value) {
+        int dimension = readIntLe(indexBytes, 8);
+        int nlist = readIntLe(indexBytes, 12);
+        int offsetTable = 64 + dimension * nlist * Float.BYTES;
+        int listOffset = (int) readLongLe(indexBytes, offsetTable);
+        int idBytesLength = readIntLe(indexBytes, listOffset + Long.BYTES);
+        int firstVectorOffset = listOffset + Long.BYTES + Integer.BYTES + 
idBytesLength;
+        writeFloatLe(indexBytes, firstVectorOffset, value);
+    }
+
+    private static int readIntLe(byte[] bytes, int offset) {
+        return (bytes[offset] & 0xFF)
+                | ((bytes[offset + 1] & 0xFF) << 8)
+                | ((bytes[offset + 2] & 0xFF) << 16)
+                | ((bytes[offset + 3] & 0xFF) << 24);
+    }
+
+    private static long readLongLe(byte[] bytes, int offset) {
+        long result = 0L;
+        for (int i = 0; i < Long.BYTES; i++) {
+            result |= (long) (bytes[offset + i] & 0xFF) << (8 * i);
+        }
+        return result;
+    }
+
+    private static void writeFloatLe(byte[] bytes, int offset, float value) {
+        int bits = Float.floatToRawIntBits(value);
+        bytes[offset] = (byte) bits;
+        bytes[offset + 1] = (byte) (bits >>> 8);
+        bytes[offset + 2] = (byte) (bits >>> 16);
+        bytes[offset + 3] = (byte) (bits >>> 24);
+    }
+
     private interface ThrowingRunnable {
         void run() throws Throwable;
     }
diff --git 
a/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativeValidationTest.java
 
b/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativeValidationTest.java
index ac0c3cc..6e40171 100644
--- 
a/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativeValidationTest.java
+++ 
b/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativeValidationTest.java
@@ -31,7 +31,9 @@ public class VectorIndexNativeValidationTest {
         System.load(args[0]);
 
         testWriterValidationComesFromCore();
+        testWriterRejectsNonFiniteValues();
         testReaderValidationComesFromCore();
+        testReaderRejectsNonFiniteQueries();
     }
 
     private static void testWriterValidationComesFromCore() {
@@ -104,6 +106,74 @@ public class VectorIndexNativeValidationTest {
         }
     }
 
+    private static void testWriterRejectsNonFiniteValues() {
+        final VectorIndexWriter trainingWriter = new 
VectorIndexWriter(ivfFlatOptions());
+        try {
+            assertThrowsMessage(
+                    RuntimeException.class,
+                    "training data contains non-finite value at offset 0: NaN",
+                    new ThrowingRunnable() {
+                        @Override
+                        public void run() {
+                            trainingWriter.train(new float[] {Float.NaN, 
1.0f}, 2);
+                        }
+                    });
+        } finally {
+            trainingWriter.close();
+        }
+
+        final VectorIndexWriter vectorWriter = new 
VectorIndexWriter(ivfFlatOptions());
+        try {
+            vectorWriter.train(new float[] {0.0f, 1.0f}, 2);
+            assertThrowsMessage(
+                    RuntimeException.class,
+                    "vector data contains non-finite value at offset 0: inf",
+                    new ThrowingRunnable() {
+                        @Override
+                        public void run() {
+                            vectorWriter.addVectors(
+                                    new long[] {1L, 2L},
+                                    new float[] {Float.POSITIVE_INFINITY, 
1.0f},
+                                    2);
+                        }
+                    });
+        } finally {
+            vectorWriter.close();
+        }
+    }
+
+    private static void testReaderRejectsNonFiniteQueries() {
+        VectorIndexReader reader = new VectorIndexReader(new 
ByteArraySeekableInputStream(buildIndexBytes()));
+        try {
+            assertInvalidInput(
+                    new ThrowingRunnable() {
+                        @Override
+                        public void run() {
+                            reader.search(new float[] {Float.NaN}, 1, 1);
+                        }
+                    },
+                    "query contains non-finite value at offset 0: NaN");
+            assertInvalidInput(
+                    new ThrowingRunnable() {
+                        @Override
+                        public void run() {
+                            reader.searchBatch(new float[] 
{Float.NEGATIVE_INFINITY}, 1, 1, 1);
+                        }
+                    },
+                    "queries contains non-finite value at offset 0: -inf");
+            assertInvalidInput(
+                    new ThrowingRunnable() {
+                        @Override
+                        public void run() {
+                            reader.search(new float[] {Float.NaN}, 1, 1, new 
byte[] {(byte) 0xFF});
+                        }
+                    },
+                    "query contains non-finite value at offset 0: NaN");
+        } finally {
+            reader.close();
+        }
+    }
+
     private static byte[] buildIndexBytes() {
         VectorIndexWriter writer = new VectorIndexWriter(ivfFlatOptions());
         ByteArrayPositionOutputStream output = new 
ByteArrayPositionOutputStream();
@@ -144,6 +214,27 @@ public class VectorIndexNativeValidationTest {
         throw new AssertionError("expected " + expected.getName());
     }
 
+    private static void assertInvalidInput(ThrowingRunnable runnable, String 
expectedMessage) {
+        try {
+            runnable.run();
+        } catch (RuntimeException e) {
+            String message = e.getMessage();
+            if (message == null || !message.contains(expectedMessage)) {
+                throw new AssertionError("unexpected exception message: " + 
message, e);
+            }
+            if (message.contains("Rust panic in JNI call")) {
+                throw new AssertionError("invalid input should not cross the 
panic boundary", e);
+            }
+            if (message.contains("invalid RoaringTreemap")) {
+                throw new AssertionError("query validation should run before 
filter decoding", e);
+            }
+            return;
+        } catch (Throwable t) {
+            throw new AssertionError("expected RuntimeException but got " + 
t.getClass().getName(), t);
+        }
+        throw new AssertionError("expected RuntimeException");
+    }
+
     private interface ThrowingRunnable {
         void run() throws Throwable;
     }

Reply via email to