This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new a160e94aa8 Convert some panics that happen on invalid parquet files to error results (#6738)
a160e94aa8 is described below

commit a160e94aa8f1845a264ef208a2ab0fb8d9137240
Author: Jinpeng <[email protected]>
AuthorDate: Mon Jan 6 16:09:05 2025 -0500

    Convert some panics that happen on invalid parquet files to error results (#6738)
    
    * Reduce panics
    
    * move integer logical type from format.rs to schema/types.rs
    
    * remove some changes as per reviews
    
    * use wrapping_shl
    
    * fix typo in error message
    
    * return error for invalid decimal length
    
    ---------
    
    Co-authored-by: jp0317 <[email protected]>
    Co-authored-by: Andrew Lamb <[email protected]>
---
 parquet/src/errors.rs                  |  7 +++++
 parquet/src/file/metadata/reader.rs    | 26 ++++++++---------
 parquet/src/file/serialized_reader.rs  | 53 +++++++++++++++++++++++++++++-----
 parquet/src/file/statistics.rs         | 26 +++++++++++++++++
 parquet/src/schema/types.rs            | 25 +++++++++++++++-
 parquet/src/thrift.rs                  | 35 ++++++++++++++++++----
 parquet/tests/arrow_reader/bad_data.rs |  2 +-
 7 files changed, 146 insertions(+), 28 deletions(-)

diff --git a/parquet/src/errors.rs b/parquet/src/errors.rs
index 8dc97f4ca2..d749287bba 100644
--- a/parquet/src/errors.rs
+++ b/parquet/src/errors.rs
@@ -17,6 +17,7 @@
 
 //! Common Parquet errors and macros.
 
+use core::num::TryFromIntError;
 use std::error::Error;
 use std::{cell, io, result, str};
 
@@ -81,6 +82,12 @@ impl Error for ParquetError {
     }
 }
 
+impl From<TryFromIntError> for ParquetError {
+    fn from(e: TryFromIntError) -> ParquetError {
+        ParquetError::General(format!("Integer overflow: {e}"))
+    }
+}
+
 impl From<io::Error> for ParquetError {
     fn from(e: io::Error) -> ParquetError {
         ParquetError::External(Box::new(e))
diff --git a/parquet/src/file/metadata/reader.rs b/parquet/src/file/metadata/reader.rs
index ec2cd1094d..c6715a33b5 100644
--- a/parquet/src/file/metadata/reader.rs
+++ b/parquet/src/file/metadata/reader.rs
@@ -627,7 +627,8 @@ impl ParquetMetaDataReader {
         for rg in t_file_metadata.row_groups {
             
row_groups.push(RowGroupMetaData::from_thrift(schema_descr.clone(), rg)?);
         }
-        let column_orders = 
Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr);
+        let column_orders =
+            Self::parse_column_orders(t_file_metadata.column_orders, 
&schema_descr)?;
 
         let file_metadata = FileMetaData::new(
             t_file_metadata.version,
@@ -645,15 +646,13 @@ impl ParquetMetaDataReader {
     fn parse_column_orders(
         t_column_orders: Option<Vec<TColumnOrder>>,
         schema_descr: &SchemaDescriptor,
-    ) -> Option<Vec<ColumnOrder>> {
+    ) -> Result<Option<Vec<ColumnOrder>>> {
         match t_column_orders {
             Some(orders) => {
                 // Should always be the case
-                assert_eq!(
-                    orders.len(),
-                    schema_descr.num_columns(),
-                    "Column order length mismatch"
-                );
+                if orders.len() != schema_descr.num_columns() {
+                    return Err(general_err!("Column order length mismatch"));
+                };
                 let mut res = Vec::new();
                 for (i, column) in schema_descr.columns().iter().enumerate() {
                     match orders[i] {
@@ -667,9 +666,9 @@ impl ParquetMetaDataReader {
                         }
                     }
                 }
-                Some(res)
+                Ok(Some(res))
             }
-            None => None,
+            None => Ok(None),
         }
     }
 }
@@ -741,7 +740,7 @@ mod tests {
         ]);
 
         assert_eq!(
-            ParquetMetaDataReader::parse_column_orders(t_column_orders, 
&schema_descr),
+            ParquetMetaDataReader::parse_column_orders(t_column_orders, 
&schema_descr).unwrap(),
             Some(vec![
                 ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED),
                 ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED)
@@ -750,20 +749,21 @@ mod tests {
 
         // Test when no column orders are defined.
         assert_eq!(
-            ParquetMetaDataReader::parse_column_orders(None, &schema_descr),
+            ParquetMetaDataReader::parse_column_orders(None, 
&schema_descr).unwrap(),
             None
         );
     }
 
     #[test]
-    #[should_panic(expected = "Column order length mismatch")]
     fn test_metadata_column_orders_len_mismatch() {
         let schema = SchemaType::group_type_builder("schema").build().unwrap();
         let schema_descr = SchemaDescriptor::new(Arc::new(schema));
 
         let t_column_orders = 
Some(vec![TColumnOrder::TYPEORDER(TypeDefinedOrder::new())]);
 
-        ParquetMetaDataReader::parse_column_orders(t_column_orders, 
&schema_descr);
+        let res = ParquetMetaDataReader::parse_column_orders(t_column_orders, 
&schema_descr);
+        assert!(res.is_err());
+        assert!(format!("{:?}", res.unwrap_err()).contains("Column order length mismatch"));
     }
 
     #[test]
diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs
index 06f3cf9fb2..a942481f7e 100644
--- a/parquet/src/file/serialized_reader.rs
+++ b/parquet/src/file/serialized_reader.rs
@@ -435,7 +435,7 @@ pub(crate) fn decode_page(
             let is_sorted = dict_header.is_sorted.unwrap_or(false);
             Page::DictionaryPage {
                 buf: buffer,
-                num_values: dict_header.num_values as u32,
+                num_values: dict_header.num_values.try_into()?,
                 encoding: Encoding::try_from(dict_header.encoding)?,
                 is_sorted,
             }
@@ -446,7 +446,7 @@ pub(crate) fn decode_page(
                 .ok_or_else(|| ParquetError::General("Missing V1 data page header".to_string()))?;
             Page::DataPage {
                 buf: buffer,
-                num_values: header.num_values as u32,
+                num_values: header.num_values.try_into()?,
                 encoding: Encoding::try_from(header.encoding)?,
                 def_level_encoding: 
Encoding::try_from(header.definition_level_encoding)?,
                 rep_level_encoding: 
Encoding::try_from(header.repetition_level_encoding)?,
@@ -460,12 +460,12 @@ pub(crate) fn decode_page(
             let is_compressed = header.is_compressed.unwrap_or(true);
             Page::DataPageV2 {
                 buf: buffer,
-                num_values: header.num_values as u32,
+                num_values: header.num_values.try_into()?,
                 encoding: Encoding::try_from(header.encoding)?,
-                num_nulls: header.num_nulls as u32,
-                num_rows: header.num_rows as u32,
-                def_levels_byte_len: header.definition_levels_byte_length as 
u32,
-                rep_levels_byte_len: header.repetition_levels_byte_length as 
u32,
+                num_nulls: header.num_nulls.try_into()?,
+                num_rows: header.num_rows.try_into()?,
+                def_levels_byte_len: 
header.definition_levels_byte_length.try_into()?,
+                rep_levels_byte_len: 
header.repetition_levels_byte_length.try_into()?,
                 is_compressed,
                 statistics: statistics::from_thrift(physical_type, 
header.statistics)?,
             }
@@ -578,6 +578,27 @@ impl<R: ChunkReader> Iterator for SerializedPageReader<R> {
     }
 }
 
+fn verify_page_header_len(header_len: usize, remaining_bytes: usize) -> 
Result<()> {
+    if header_len > remaining_bytes {
+        return Err(eof_err!("Invalid page header"));
+    }
+    Ok(())
+}
+
+fn verify_page_size(
+    compressed_size: i32,
+    uncompressed_size: i32,
+    remaining_bytes: usize,
+) -> Result<()> {
+    // The page's compressed size should not exceed the remaining bytes that are
+    // available to read. The page's uncompressed size is the expected size
+    // after decompression, which can never be negative.
+    if compressed_size < 0 || compressed_size as usize > remaining_bytes || 
uncompressed_size < 0 {
+        return Err(eof_err!("Invalid page header"));
+    }
+    Ok(())
+}
+
 impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
     fn get_next_page(&mut self) -> Result<Option<Page>> {
         loop {
@@ -596,10 +617,16 @@ impl<R: ChunkReader> PageReader for 
SerializedPageReader<R> {
                         *header
                     } else {
                         let (header_len, header) = read_page_header_len(&mut 
read)?;
+                        verify_page_header_len(header_len, *remaining)?;
                         *offset += header_len;
                         *remaining -= header_len;
                         header
                     };
+                    verify_page_size(
+                        header.compressed_page_size,
+                        header.uncompressed_page_size,
+                        *remaining,
+                    )?;
                     let data_len = header.compressed_page_size as usize;
                     *offset += data_len;
                     *remaining -= data_len;
@@ -683,6 +710,7 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> 
{
                     } else {
                         let mut read = self.reader.get_read(*offset as u64)?;
                         let (header_len, header) = read_page_header_len(&mut 
read)?;
+                        verify_page_header_len(header_len, *remaining_bytes)?;
                         *offset += header_len;
                         *remaining_bytes -= header_len;
                         let page_meta = if let Ok(page_meta) = 
(&header).try_into() {
@@ -733,12 +761,23 @@ impl<R: ChunkReader> PageReader for 
SerializedPageReader<R> {
                 next_page_header,
             } => {
                 if let Some(buffered_header) = next_page_header.take() {
+                    verify_page_size(
+                        buffered_header.compressed_page_size,
+                        buffered_header.uncompressed_page_size,
+                        *remaining_bytes,
+                    )?;
                     // The next page header has already been peeked, so just advance the offset
                     *offset += buffered_header.compressed_page_size as usize;
                     *remaining_bytes -= buffered_header.compressed_page_size 
as usize;
                 } else {
                     let mut read = self.reader.get_read(*offset as u64)?;
                     let (header_len, header) = read_page_header_len(&mut 
read)?;
+                    verify_page_header_len(header_len, *remaining_bytes)?;
+                    verify_page_size(
+                        header.compressed_page_size,
+                        header.uncompressed_page_size,
+                        *remaining_bytes,
+                    )?;
                     let data_page_size = header.compressed_page_size as usize;
                     *offset += header_len + data_page_size;
                     *remaining_bytes -= header_len + data_page_size;
diff --git a/parquet/src/file/statistics.rs b/parquet/src/file/statistics.rs
index 2e05b83369..b7522a76f0 100644
--- a/parquet/src/file/statistics.rs
+++ b/parquet/src/file/statistics.rs
@@ -157,6 +157,32 @@ pub fn from_thrift(
                 stats.max_value
             };
 
+            fn check_len(min: &Option<Vec<u8>>, max: &Option<Vec<u8>>, len: 
usize) -> Result<()> {
+                if let Some(min) = min {
+                    if min.len() < len {
+                        return Err(ParquetError::General(
+                            "Insufficient bytes to parse min statistic".to_string(),
+                        ));
+                    }
+                }
+                if let Some(max) = max {
+                    if max.len() < len {
+                        return Err(ParquetError::General(
+                            "Insufficient bytes to parse max statistic".to_string(),
+                        ));
+                    }
+                }
+                Ok(())
+            }
+
+            match physical_type {
+                Type::BOOLEAN => check_len(&min, &max, 1),
+                Type::INT32 | Type::FLOAT => check_len(&min, &max, 4),
+                Type::INT64 | Type::DOUBLE => check_len(&min, &max, 8),
+                Type::INT96 => check_len(&min, &max, 12),
+                _ => Ok(()),
+            }?;
+
             // Values are encoded using PLAIN encoding definition, except that
             // variable-length byte arrays do not include a length prefix.
             //
diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs
index d168e46de0..d9e9b22e80 100644
--- a/parquet/src/schema/types.rs
+++ b/parquet/src/schema/types.rs
@@ -556,7 +556,11 @@ impl<'a> PrimitiveTypeBuilder<'a> {
                 }
             }
             PhysicalType::FIXED_LEN_BYTE_ARRAY => {
-                let max_precision = (2f64.powi(8 * self.length - 1) - 
1f64).log10().floor() as i32;
+                let length = self
+                    .length
+                    .checked_mul(8)
+                    .ok_or(general_err!("Invalid length {} for Decimal", 
self.length))?;
+                let max_precision = (2f64.powi(length - 1) - 
1f64).log10().floor() as i32;
 
                 if self.precision > max_precision {
                     return Err(general_err!(
@@ -1171,9 +1175,25 @@ pub fn from_thrift(elements: &[SchemaElement]) -> 
Result<TypePtr> {
         ));
     }
 
+    if !schema_nodes[0].is_group() {
+        return Err(general_err!("Expected root node to be a group type"));
+    }
+
     Ok(schema_nodes.remove(0))
 }
 
+/// Checks if the logical type is valid.
+fn check_logical_type(logical_type: &Option<LogicalType>) -> Result<()> {
+    if let Some(LogicalType::Integer { bit_width, .. }) = *logical_type {
+        if bit_width != 8 && bit_width != 16 && bit_width != 32 && bit_width 
!= 64 {
+            return Err(general_err!(
+                "Bit width must be 8, 16, 32, or 64 for Integer logical type"
+            ));
+        }
+    }
+    Ok(())
+}
+
 /// Constructs a new Type from the `elements`, starting at index `index`.
 /// The first result is the starting index for the next Type after this one. If it is
 /// equal to `elements.len()`, then this Type is the last one.
@@ -1198,6 +1218,9 @@ fn from_thrift_helper(elements: &[SchemaElement], index: 
usize) -> Result<(usize
         .logical_type
         .as_ref()
         .map(|value| LogicalType::from(value.clone()));
+
+    check_logical_type(&logical_type)?;
+
     let field_id = elements[index].field_id;
     match elements[index].num_children {
         // From parquet-format:
diff --git a/parquet/src/thrift.rs b/parquet/src/thrift.rs
index ceb6b1c29f..b216fec6f3 100644
--- a/parquet/src/thrift.rs
+++ b/parquet/src/thrift.rs
@@ -67,7 +67,7 @@ impl<'a> TCompactSliceInputProtocol<'a> {
         let mut shift = 0;
         loop {
             let byte = self.read_byte()?;
-            in_progress |= ((byte & 0x7F) as u64) << shift;
+            in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift);
             shift += 7;
             if byte & 0x80 == 0 {
                 return Ok(in_progress);
@@ -96,13 +96,22 @@ impl<'a> TCompactSliceInputProtocol<'a> {
     }
 }
 
+macro_rules! thrift_unimplemented {
+    () => {
+        Err(thrift::Error::Protocol(thrift::ProtocolError {
+            kind: thrift::ProtocolErrorKind::NotImplemented,
+            message: "not implemented".to_string(),
+        }))
+    };
+}
+
 impl TInputProtocol for TCompactSliceInputProtocol<'_> {
     fn read_message_begin(&mut self) -> thrift::Result<TMessageIdentifier> {
         unimplemented!()
     }
 
     fn read_message_end(&mut self) -> thrift::Result<()> {
-        unimplemented!()
+        thrift_unimplemented!()
     }
 
     fn read_struct_begin(&mut self) -> 
thrift::Result<Option<TStructIdentifier>> {
@@ -147,7 +156,21 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> {
             ),
             _ => {
                 if field_delta != 0 {
-                    self.last_read_field_id += field_delta as i16;
+                    self.last_read_field_id = self
+                        .last_read_field_id
+                        .checked_add(field_delta as i16)
+                        .map_or_else(
+                            || {
+                                
Err(thrift::Error::Protocol(thrift::ProtocolError {
+                                    kind: 
thrift::ProtocolErrorKind::InvalidData,
+                                    message: format!(
+                                        "cannot add {} to {}",
+                                        field_delta, self.last_read_field_id
+                                    ),
+                                }))
+                            },
+                            Ok,
+                        )?;
                 } else {
                     self.last_read_field_id = self.read_i16()?;
                 };
@@ -226,15 +249,15 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> {
     }
 
     fn read_set_begin(&mut self) -> thrift::Result<TSetIdentifier> {
-        unimplemented!()
+        thrift_unimplemented!()
     }
 
     fn read_set_end(&mut self) -> thrift::Result<()> {
-        unimplemented!()
+        thrift_unimplemented!()
     }
 
     fn read_map_begin(&mut self) -> thrift::Result<TMapIdentifier> {
-        unimplemented!()
+        thrift_unimplemented!()
     }
 
     fn read_map_end(&mut self) -> thrift::Result<()> {
diff --git a/parquet/tests/arrow_reader/bad_data.rs b/parquet/tests/arrow_reader/bad_data.rs
index 7434203143..cfd61e82d3 100644
--- a/parquet/tests/arrow_reader/bad_data.rs
+++ b/parquet/tests/arrow_reader/bad_data.rs
@@ -106,7 +106,7 @@ fn test_arrow_rs_gh_6229_dict_header() {
     let err = read_file("ARROW-RS-GH-6229-DICTHEADER.parquet").unwrap_err();
     assert_eq!(
         err.to_string(),
-        "External: Parquet argument error: EOF: eof decoding byte array"
+        "External: Parquet argument error: Parquet error: Integer overflow: out of range integral type conversion attempted"
     );
 }
 

Reply via email to