This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new c1a57cb548 [Variant] Add negative tests for reading invalid primitive variant values (#7779)
c1a57cb548 is described below

commit c1a57cb548ddbc49f70ccb4b4d401b4c012ae6f6
Author: superserious-dev <[email protected]>
AuthorDate: Sun Jun 29 02:04:18 2025 -0700

    [Variant] Add negative tests for reading invalid primitive variant values (#7779)
    
    # Which issue does this PR close?
    
    - Closes #7645
    
    # Rationale for this change
    
    Follow-up from the previous PR that added decoders for primitive values.
    
    # What changes are included in this PR?
    
    - Verifies that an error is emitted if a decoder does not have enough
    bytes to consume
    - Ensures that decimal scale values cannot exceed the maximum allowed by the
    spec, and adds tests to verify this
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 parquet-variant/Cargo.toml     |   4 +
 parquet-variant/src/decoder.rs | 227 +++++++++++++++++++++++------------------
 2 files changed, 130 insertions(+), 101 deletions(-)

diff --git a/parquet-variant/Cargo.toml b/parquet-variant/Cargo.toml
index 838ca7de88..6bec373d02 100644
--- a/parquet-variant/Cargo.toml
+++ b/parquet-variant/Cargo.toml
@@ -38,4 +38,8 @@ chrono = { workspace = true }
 serde_json = "1.0"
 base64 = "0.22"
 
+[dev-dependencies]
+paste = { version = "1.0" }
+
+
 [lib]
diff --git a/parquet-variant/src/decoder.rs b/parquet-variant/src/decoder.rs
index e73911aa29..6b5c131078 100644
--- a/parquet-variant/src/decoder.rs
+++ b/parquet-variant/src/decoder.rs
@@ -283,157 +283,182 @@ pub(crate) fn decode_short_string(metadata: u8, data: &[u8]) -> Result<ShortStri
 #[cfg(test)]
 mod tests {
     use super::*;
-
-    #[test]
-    fn test_i8() -> Result<(), ArrowError> {
-        let data = [0x2a];
-        let result = decode_int8(&data)?;
-        assert_eq!(result, 42);
-        Ok(())
+    use paste::paste;
+
+    macro_rules! test_decoder_bounds {
+        ($test_name:ident, $data:expr, $decode_fn:ident, $expected:expr) => {
+            paste! {
+                #[test]
+                fn [<$test_name _exact_length>]() {
+                    let result = $decode_fn(&$data).unwrap();
+                    assert_eq!(result, $expected);
+                }
+
+                #[test]
+                fn [<$test_name _truncated_length>]() {
+                    // Remove the last byte of data so that there is not enough to decode
+                    let truncated_data = &$data[.. $data.len() - 1];
+                    let result = $decode_fn(truncated_data);
+                    assert!(matches!(result, Err(ArrowError::InvalidArgumentError(_))));
+                }
+            }
+        };
     }
 
-    #[test]
-    fn test_i16() -> Result<(), ArrowError> {
-        let data = [0xd2, 0x04];
-        let result = decode_int16(&data)?;
-        assert_eq!(result, 1234);
-        Ok(())
+    mod integer {
+        use super::*;
+
+        test_decoder_bounds!(test_i8, [0x2a], decode_int8, 42);
+        test_decoder_bounds!(test_i16, [0xd2, 0x04], decode_int16, 1234);
+        test_decoder_bounds!(test_i32, [0x40, 0xe2, 0x01, 0x00], decode_int32, 123456);
+        test_decoder_bounds!(
+            test_i64,
+            [0x15, 0x81, 0xe9, 0x7d, 0xf4, 0x10, 0x22, 0x11],
+            decode_int64,
+            1234567890123456789
+        );
     }
 
-    #[test]
-    fn test_i32() -> Result<(), ArrowError> {
-        let data = [0x40, 0xe2, 0x01, 0x00];
-        let result = decode_int32(&data)?;
-        assert_eq!(result, 123456);
-        Ok(())
-    }
+    mod decimal {
+        use super::*;
+
+        test_decoder_bounds!(
+            test_decimal4,
+            [
+                0x02, // Scale
+                0xd2, 0x04, 0x00, 0x00, // Unscaled Value
+            ],
+            decode_decimal4,
+            (1234, 2)
+        );
 
-    #[test]
-    fn test_i64() -> Result<(), ArrowError> {
-        let data = [0x15, 0x81, 0xe9, 0x7d, 0xf4, 0x10, 0x22, 0x11];
-        let result = decode_int64(&data)?;
-        assert_eq!(result, 1234567890123456789);
-        Ok(())
-    }
+        test_decoder_bounds!(
+            test_decimal8,
+            [
+                0x02, // Scale
+                0xd2, 0x02, 0x96, 0x49, 0x00, 0x00, 0x00, 0x00, // Unscaled Value
+            ],
+            decode_decimal8,
+            (1234567890, 2)
+        );
 
-    #[test]
-    fn test_decimal4() -> Result<(), ArrowError> {
-        let data = [
-            0x02, // Scale
-            0xd2, 0x04, 0x00, 0x00, // Integer
-        ];
-        let result = decode_decimal4(&data)?;
-        assert_eq!(result, (1234, 2));
-        Ok(())
+        test_decoder_bounds!(
+            test_decimal16,
+            [
+                0x02, // Scale
+                0xd2, 0xb6, 0x23, 0xc0, 0xf4, 0x10, 0x22, 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                0x00, 0x00, // Unscaled Value
+            ],
+            decode_decimal16,
+            (1234567891234567890, 2)
+        );
     }
 
-    #[test]
-    fn test_decimal8() -> Result<(), ArrowError> {
-        let data = [
-            0x02, // Scale
-            0xd2, 0x02, 0x96, 0x49, 0x00, 0x00, 0x00, 0x00, // Integer
-        ];
-        let result = decode_decimal8(&data)?;
-        assert_eq!(result, (1234567890, 2));
-        Ok(())
-    }
+    mod float {
+        use super::*;
 
-    #[test]
-    fn test_decimal16() -> Result<(), ArrowError> {
-        let data = [
-            0x02, // Scale
-            0xd2, 0xb6, 0x23, 0xc0, 0xf4, 0x10, 0x22, 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-            0x00, 0x00, // Integer
-        ];
-        let result = decode_decimal16(&data)?;
-        assert_eq!(result, (1234567891234567890, 2));
-        Ok(())
-    }
+        test_decoder_bounds!(
+            test_float,
+            [0x06, 0x2c, 0x93, 0x4e],
+            decode_float,
+            1234567890.1234
+        );
 
-    #[test]
-    fn test_float() -> Result<(), ArrowError> {
-        let data = [0x06, 0x2c, 0x93, 0x4e];
-        let result = decode_float(&data)?;
-        assert_eq!(result, 1234567890.1234);
-        Ok(())
+        test_decoder_bounds!(
+            test_double,
+            [0xc9, 0xe5, 0x87, 0xb4, 0x80, 0x65, 0xd2, 0x41],
+            decode_double,
+            1234567890.1234
+        );
     }
 
-    #[test]
-    fn test_double() -> Result<(), ArrowError> {
-        let data = [0xc9, 0xe5, 0x87, 0xb4, 0x80, 0x65, 0xd2, 0x41];
-        let result = decode_double(&data)?;
-        assert_eq!(result, 1234567890.1234);
-        Ok(())
-    }
+    mod datetime {
+        use super::*;
 
-    #[test]
-    fn test_date() -> Result<(), ArrowError> {
-        let data = [0xe2, 0x4e, 0x0, 0x0];
-        let result = decode_date(&data)?;
-        assert_eq!(result, NaiveDate::from_ymd_opt(2025, 4, 16).unwrap());
-        Ok(())
-    }
+        test_decoder_bounds!(
+            test_date,
+            [0xe2, 0x4e, 0x0, 0x0],
+            decode_date,
+            NaiveDate::from_ymd_opt(2025, 4, 16).unwrap()
+        );
 
-    #[test]
-    fn test_timestamp_micros() -> Result<(), ArrowError> {
-        let data = [0xe0, 0x52, 0x97, 0xdd, 0xe7, 0x32, 0x06, 0x00];
-        let result = decode_timestamp_micros(&data)?;
-        assert_eq!(
-            result,
+        test_decoder_bounds!(
+            test_timestamp_micros,
+            [0xe0, 0x52, 0x97, 0xdd, 0xe7, 0x32, 0x06, 0x00],
+            decode_timestamp_micros,
             NaiveDate::from_ymd_opt(2025, 4, 16)
                 .unwrap()
                 .and_hms_milli_opt(16, 34, 56, 780)
                 .unwrap()
                 .and_utc()
         );
-        Ok(())
-    }
 
-    #[test]
-    fn test_timestampntz_micros() -> Result<(), ArrowError> {
-        let data = [0xe0, 0x52, 0x97, 0xdd, 0xe7, 0x32, 0x06, 0x00];
-        let result = decode_timestampntz_micros(&data)?;
-        assert_eq!(
-            result,
+        test_decoder_bounds!(
+            test_timestampntz_micros,
+            [0xe0, 0x52, 0x97, 0xdd, 0xe7, 0x32, 0x06, 0x00],
+            decode_timestampntz_micros,
             NaiveDate::from_ymd_opt(2025, 4, 16)
                 .unwrap()
                 .and_hms_milli_opt(16, 34, 56, 780)
                 .unwrap()
         );
-        Ok(())
     }
 
     #[test]
-    fn test_binary() -> Result<(), ArrowError> {
+    fn test_binary_exact_length() {
         let data = [
             0x09, 0, 0, 0, // Length of binary data, 4-byte little-endian
             0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe,
         ];
-        let result = decode_binary(&data)?;
+        let result = decode_binary(&data).unwrap();
         assert_eq!(
             result,
             [0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe]
         );
-        Ok(())
     }
 
     #[test]
-    fn test_short_string() -> Result<(), ArrowError> {
+    fn test_binary_truncated_length() {
+        let data = [
+            0x09, 0, 0, 0, // Length of binary data, 4-byte little-endian
+            0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca,
+        ];
+        let result = decode_binary(&data);
+        assert!(matches!(result, Err(ArrowError::InvalidArgumentError(_))));
+    }
+
+    #[test]
+    fn test_short_string_exact_length() {
         let data = [b'H', b'e', b'l', b'l', b'o', b'o'];
-        let result = decode_short_string(1 | 5 << 2, &data)?;
+        let result = decode_short_string(1 | 5 << 2, &data).unwrap();
         assert_eq!(result.0, "Hello");
-        Ok(())
     }
 
     #[test]
-    fn test_string() -> Result<(), ArrowError> {
+    fn test_short_string_truncated_length() {
+        let data = [b'H', b'e', b'l'];
+        let result = decode_short_string(1 | 5 << 2, &data);
+        assert!(matches!(result, Err(ArrowError::InvalidArgumentError(_))));
+    }
+
+    #[test]
+    fn test_string_exact_length() {
         let data = [
             0x05, 0, 0, 0, // Length of string, 4-byte little-endian
             b'H', b'e', b'l', b'l', b'o', b'o',
         ];
-        let result = decode_long_string(&data)?;
+        let result = decode_long_string(&data).unwrap();
         assert_eq!(result, "Hello");
-        Ok(())
+    }
+
+    #[test]
+    fn test_string_truncated_length() {
+        let data = [
+            0x05, 0, 0, 0, // Length of string, 4-byte little-endian
+            b'H', b'e', b'l',
+        ];
+        let result = decode_long_string(&data);
+        assert!(matches!(result, Err(ArrowError::InvalidArgumentError(_))));
     }
 
     #[test]

Reply via email to