This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new c1a57cb548 [Variant] Add negative tests for reading invalid primitive
variant values (#7779)
c1a57cb548 is described below
commit c1a57cb548ddbc49f70ccb4b4d401b4c012ae6f6
Author: superserious-dev <[email protected]>
AuthorDate: Sun Jun 29 02:04:18 2025 -0700
[Variant] Add negative tests for reading invalid primitive variant values
(#7779)
# Which issue does this PR close?
- Closes #7645
# Rationale for this change
Follow-up from the previous PR that added decoders for primitive values.
# What changes are included in this PR?
- Verifies that an error is emitted if a decoder does not have enough
bytes to consume
- Ensures that decimal scale values can't exceed the maximum allowed by
the spec, and adds tests to verify this
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
parquet-variant/Cargo.toml | 4 +
parquet-variant/src/decoder.rs | 227 +++++++++++++++++++++++------------------
2 files changed, 130 insertions(+), 101 deletions(-)
diff --git a/parquet-variant/Cargo.toml b/parquet-variant/Cargo.toml
index 838ca7de88..6bec373d02 100644
--- a/parquet-variant/Cargo.toml
+++ b/parquet-variant/Cargo.toml
@@ -38,4 +38,8 @@ chrono = { workspace = true }
serde_json = "1.0"
base64 = "0.22"
+[dev-dependencies]
+paste = { version = "1.0" }
+
+
[lib]
diff --git a/parquet-variant/src/decoder.rs b/parquet-variant/src/decoder.rs
index e73911aa29..6b5c131078 100644
--- a/parquet-variant/src/decoder.rs
+++ b/parquet-variant/src/decoder.rs
@@ -283,157 +283,182 @@ pub(crate) fn decode_short_string(metadata: u8, data:
&[u8]) -> Result<ShortStri
#[cfg(test)]
mod tests {
use super::*;
-
- #[test]
- fn test_i8() -> Result<(), ArrowError> {
- let data = [0x2a];
- let result = decode_int8(&data)?;
- assert_eq!(result, 42);
- Ok(())
+ use paste::paste;
+
+ macro_rules! test_decoder_bounds {
+ ($test_name:ident, $data:expr, $decode_fn:ident, $expected:expr) => {
+ paste! {
+ #[test]
+ fn [<$test_name _exact_length>]() {
+ let result = $decode_fn(&$data).unwrap();
+ assert_eq!(result, $expected);
+ }
+
+ #[test]
+ fn [<$test_name _truncated_length>]() {
+ // Remove the last byte of data so that there is not
enough to decode
+ let truncated_data = &$data[.. $data.len() - 1];
+ let result = $decode_fn(truncated_data);
+ assert!(matches!(result,
Err(ArrowError::InvalidArgumentError(_))));
+ }
+ }
+ };
}
- #[test]
- fn test_i16() -> Result<(), ArrowError> {
- let data = [0xd2, 0x04];
- let result = decode_int16(&data)?;
- assert_eq!(result, 1234);
- Ok(())
+ mod integer {
+ use super::*;
+
+ test_decoder_bounds!(test_i8, [0x2a], decode_int8, 42);
+ test_decoder_bounds!(test_i16, [0xd2, 0x04], decode_int16, 1234);
+ test_decoder_bounds!(test_i32, [0x40, 0xe2, 0x01, 0x00], decode_int32,
123456);
+ test_decoder_bounds!(
+ test_i64,
+ [0x15, 0x81, 0xe9, 0x7d, 0xf4, 0x10, 0x22, 0x11],
+ decode_int64,
+ 1234567890123456789
+ );
}
- #[test]
- fn test_i32() -> Result<(), ArrowError> {
- let data = [0x40, 0xe2, 0x01, 0x00];
- let result = decode_int32(&data)?;
- assert_eq!(result, 123456);
- Ok(())
- }
+ mod decimal {
+ use super::*;
+
+ test_decoder_bounds!(
+ test_decimal4,
+ [
+ 0x02, // Scale
+ 0xd2, 0x04, 0x00, 0x00, // Unscaled Value
+ ],
+ decode_decimal4,
+ (1234, 2)
+ );
- #[test]
- fn test_i64() -> Result<(), ArrowError> {
- let data = [0x15, 0x81, 0xe9, 0x7d, 0xf4, 0x10, 0x22, 0x11];
- let result = decode_int64(&data)?;
- assert_eq!(result, 1234567890123456789);
- Ok(())
- }
+ test_decoder_bounds!(
+ test_decimal8,
+ [
+ 0x02, // Scale
+ 0xd2, 0x02, 0x96, 0x49, 0x00, 0x00, 0x00, 0x00, // Unscaled
Value
+ ],
+ decode_decimal8,
+ (1234567890, 2)
+ );
- #[test]
- fn test_decimal4() -> Result<(), ArrowError> {
- let data = [
- 0x02, // Scale
- 0xd2, 0x04, 0x00, 0x00, // Integer
- ];
- let result = decode_decimal4(&data)?;
- assert_eq!(result, (1234, 2));
- Ok(())
+ test_decoder_bounds!(
+ test_decimal16,
+ [
+ 0x02, // Scale
+ 0xd2, 0xb6, 0x23, 0xc0, 0xf4, 0x10, 0x22, 0x11, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, // Unscaled Value
+ ],
+ decode_decimal16,
+ (1234567891234567890, 2)
+ );
}
- #[test]
- fn test_decimal8() -> Result<(), ArrowError> {
- let data = [
- 0x02, // Scale
- 0xd2, 0x02, 0x96, 0x49, 0x00, 0x00, 0x00, 0x00, // Integer
- ];
- let result = decode_decimal8(&data)?;
- assert_eq!(result, (1234567890, 2));
- Ok(())
- }
+ mod float {
+ use super::*;
- #[test]
- fn test_decimal16() -> Result<(), ArrowError> {
- let data = [
- 0x02, // Scale
- 0xd2, 0xb6, 0x23, 0xc0, 0xf4, 0x10, 0x22, 0x11, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00,
- 0x00, 0x00, // Integer
- ];
- let result = decode_decimal16(&data)?;
- assert_eq!(result, (1234567891234567890, 2));
- Ok(())
- }
+ test_decoder_bounds!(
+ test_float,
+ [0x06, 0x2c, 0x93, 0x4e],
+ decode_float,
+ 1234567890.1234
+ );
- #[test]
- fn test_float() -> Result<(), ArrowError> {
- let data = [0x06, 0x2c, 0x93, 0x4e];
- let result = decode_float(&data)?;
- assert_eq!(result, 1234567890.1234);
- Ok(())
+ test_decoder_bounds!(
+ test_double,
+ [0xc9, 0xe5, 0x87, 0xb4, 0x80, 0x65, 0xd2, 0x41],
+ decode_double,
+ 1234567890.1234
+ );
}
- #[test]
- fn test_double() -> Result<(), ArrowError> {
- let data = [0xc9, 0xe5, 0x87, 0xb4, 0x80, 0x65, 0xd2, 0x41];
- let result = decode_double(&data)?;
- assert_eq!(result, 1234567890.1234);
- Ok(())
- }
+ mod datetime {
+ use super::*;
- #[test]
- fn test_date() -> Result<(), ArrowError> {
- let data = [0xe2, 0x4e, 0x0, 0x0];
- let result = decode_date(&data)?;
- assert_eq!(result, NaiveDate::from_ymd_opt(2025, 4, 16).unwrap());
- Ok(())
- }
+ test_decoder_bounds!(
+ test_date,
+ [0xe2, 0x4e, 0x0, 0x0],
+ decode_date,
+ NaiveDate::from_ymd_opt(2025, 4, 16).unwrap()
+ );
- #[test]
- fn test_timestamp_micros() -> Result<(), ArrowError> {
- let data = [0xe0, 0x52, 0x97, 0xdd, 0xe7, 0x32, 0x06, 0x00];
- let result = decode_timestamp_micros(&data)?;
- assert_eq!(
- result,
+ test_decoder_bounds!(
+ test_timestamp_micros,
+ [0xe0, 0x52, 0x97, 0xdd, 0xe7, 0x32, 0x06, 0x00],
+ decode_timestamp_micros,
NaiveDate::from_ymd_opt(2025, 4, 16)
.unwrap()
.and_hms_milli_opt(16, 34, 56, 780)
.unwrap()
.and_utc()
);
- Ok(())
- }
- #[test]
- fn test_timestampntz_micros() -> Result<(), ArrowError> {
- let data = [0xe0, 0x52, 0x97, 0xdd, 0xe7, 0x32, 0x06, 0x00];
- let result = decode_timestampntz_micros(&data)?;
- assert_eq!(
- result,
+ test_decoder_bounds!(
+ test_timestampntz_micros,
+ [0xe0, 0x52, 0x97, 0xdd, 0xe7, 0x32, 0x06, 0x00],
+ decode_timestampntz_micros,
NaiveDate::from_ymd_opt(2025, 4, 16)
.unwrap()
.and_hms_milli_opt(16, 34, 56, 780)
.unwrap()
);
- Ok(())
}
#[test]
- fn test_binary() -> Result<(), ArrowError> {
+ fn test_binary_exact_length() {
let data = [
0x09, 0, 0, 0, // Length of binary data, 4-byte little-endian
0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe,
];
- let result = decode_binary(&data)?;
+ let result = decode_binary(&data).unwrap();
assert_eq!(
result,
[0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe]
);
- Ok(())
}
#[test]
- fn test_short_string() -> Result<(), ArrowError> {
+ fn test_binary_truncated_length() {
+ let data = [
+ 0x09, 0, 0, 0, // Length of binary data, 4-byte little-endian
+ 0x03, 0x13, 0x37, 0xde, 0xad, 0xbe, 0xef, 0xca,
+ ];
+ let result = decode_binary(&data);
+ assert!(matches!(result, Err(ArrowError::InvalidArgumentError(_))));
+ }
+
+ #[test]
+ fn test_short_string_exact_length() {
let data = [b'H', b'e', b'l', b'l', b'o', b'o'];
- let result = decode_short_string(1 | 5 << 2, &data)?;
+ let result = decode_short_string(1 | 5 << 2, &data).unwrap();
assert_eq!(result.0, "Hello");
- Ok(())
}
#[test]
- fn test_string() -> Result<(), ArrowError> {
+ fn test_short_string_truncated_length() {
+ let data = [b'H', b'e', b'l'];
+ let result = decode_short_string(1 | 5 << 2, &data);
+ assert!(matches!(result, Err(ArrowError::InvalidArgumentError(_))));
+ }
+
+ #[test]
+ fn test_string_exact_length() {
let data = [
0x05, 0, 0, 0, // Length of string, 4-byte little-endian
b'H', b'e', b'l', b'l', b'o', b'o',
];
- let result = decode_long_string(&data)?;
+ let result = decode_long_string(&data).unwrap();
assert_eq!(result, "Hello");
- Ok(())
+ }
+
+ #[test]
+ fn test_string_truncated_length() {
+ let data = [
+ 0x05, 0, 0, 0, // Length of string, 4-byte little-endian
+ b'H', b'e', b'l',
+ ];
+ let result = decode_long_string(&data);
+ assert!(matches!(result, Err(ArrowError::InvalidArgumentError(_))));
}
#[test]