This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
     new a160e94aa8 Convert some panics that happen on invalid parquet files to error results (#6738)
a160e94aa8 is described below
commit a160e94aa8f1845a264ef208a2ab0fb8d9137240
Author: Jinpeng <[email protected]>
AuthorDate: Mon Jan 6 16:09:05 2025 -0500
    Convert some panics that happen on invalid parquet files to error results (#6738)
* Reduce panics
    * move integer logical type from format.rs to schema/types.rs
* remove some changes as per reviews
* use wrapping_shl
* fix typo in error message
* return error for invalid decimal length
---------
Co-authored-by: jp0317 <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
parquet/src/errors.rs | 7 +++++
parquet/src/file/metadata/reader.rs | 26 ++++++++---------
parquet/src/file/serialized_reader.rs | 53 +++++++++++++++++++++++++++++-----
parquet/src/file/statistics.rs | 26 +++++++++++++++++
parquet/src/schema/types.rs | 25 +++++++++++++++-
parquet/src/thrift.rs | 35 ++++++++++++++++++----
parquet/tests/arrow_reader/bad_data.rs | 2 +-
7 files changed, 146 insertions(+), 28 deletions(-)
diff --git a/parquet/src/errors.rs b/parquet/src/errors.rs
index 8dc97f4ca2..d749287bba 100644
--- a/parquet/src/errors.rs
+++ b/parquet/src/errors.rs
@@ -17,6 +17,7 @@
//! Common Parquet errors and macros.
+use core::num::TryFromIntError;
use std::error::Error;
use std::{cell, io, result, str};
@@ -81,6 +82,12 @@ impl Error for ParquetError {
}
}
+impl From<TryFromIntError> for ParquetError {
+ fn from(e: TryFromIntError) -> ParquetError {
+ ParquetError::General(format!("Integer overflow: {e}"))
+ }
+}
+
impl From<io::Error> for ParquetError {
fn from(e: io::Error) -> ParquetError {
ParquetError::External(Box::new(e))
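
For illustration, a minimal self-contained sketch of what this conversion enables (a hypothetical MyError stands in for ParquetError): any `try_into()?` on an integer now surfaces overflow as an Err instead of silently truncating with `as`:

    use core::num::TryFromIntError;

    // Hypothetical stand-in for ParquetError, so the sketch compiles on its own.
    #[derive(Debug)]
    struct MyError(String);

    impl From<TryFromIntError> for MyError {
        fn from(e: TryFromIntError) -> MyError {
            MyError(format!("Integer overflow: {e}"))
        }
    }

    // A negative i32 cannot become a u32, so `?` returns Err here,
    // where an `as u32` cast would have produced a huge value.
    fn to_u32(v: i32) -> Result<u32, MyError> {
        Ok(v.try_into()?)
    }

    fn main() {
        assert_eq!(to_u32(7).unwrap(), 7);
        assert!(to_u32(-1).is_err());
    }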
diff --git a/parquet/src/file/metadata/reader.rs b/parquet/src/file/metadata/reader.rs
index ec2cd1094d..c6715a33b5 100644
--- a/parquet/src/file/metadata/reader.rs
+++ b/parquet/src/file/metadata/reader.rs
@@ -627,7 +627,8 @@ impl ParquetMetaDataReader {
for rg in t_file_metadata.row_groups {
row_groups.push(RowGroupMetaData::from_thrift(schema_descr.clone(), rg)?);
}
-        let column_orders = Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr);
+        let column_orders =
+            Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr)?;
let file_metadata = FileMetaData::new(
t_file_metadata.version,
@@ -645,15 +646,13 @@ impl ParquetMetaDataReader {
fn parse_column_orders(
t_column_orders: Option<Vec<TColumnOrder>>,
schema_descr: &SchemaDescriptor,
- ) -> Option<Vec<ColumnOrder>> {
+ ) -> Result<Option<Vec<ColumnOrder>>> {
match t_column_orders {
Some(orders) => {
// Should always be the case
- assert_eq!(
- orders.len(),
- schema_descr.num_columns(),
- "Column order length mismatch"
- );
+ if orders.len() != schema_descr.num_columns() {
+ return Err(general_err!("Column order length mismatch"));
+ };
let mut res = Vec::new();
for (i, column) in schema_descr.columns().iter().enumerate() {
match orders[i] {
@@ -667,9 +666,9 @@ impl ParquetMetaDataReader {
}
}
}
- Some(res)
+ Ok(Some(res))
}
- None => None,
+ None => Ok(None),
}
}
}
@@ -741,7 +740,7 @@ mod tests {
]);
assert_eq!(
-            ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr),
+            ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr).unwrap(),
Some(vec![
ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED),
ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED)
@@ -750,20 +749,21 @@ mod tests {
// Test when no column orders are defined.
assert_eq!(
-            ParquetMetaDataReader::parse_column_orders(None, &schema_descr),
+            ParquetMetaDataReader::parse_column_orders(None, &schema_descr).unwrap(),
None
);
}
#[test]
- #[should_panic(expected = "Column order length mismatch")]
fn test_metadata_column_orders_len_mismatch() {
let schema = SchemaType::group_type_builder("schema").build().unwrap();
let schema_descr = SchemaDescriptor::new(Arc::new(schema));
        let t_column_orders = Some(vec![TColumnOrder::TYPEORDER(TypeDefinedOrder::new())]);
-        ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr);
+        let res = ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr);
+        assert!(res.is_err());
+        assert!(format!("{:?}", res.unwrap_err()).contains("Column order length mismatch"));
}
#[test]
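
The test change above follows the usual pattern when a panic is converted to an error: drop `#[should_panic]` and assert on the returned Result instead. A minimal sketch, with a hypothetical `parse` helper in place of parse_column_orders:

    // Hypothetical helper that fails the way parse_column_orders now does.
    fn parse(len_ok: bool) -> Result<(), String> {
        if !len_ok {
            return Err("Column order length mismatch".to_string());
        }
        Ok(())
    }

    fn main() {
        let res = parse(false);
        assert!(res.is_err());
        // The error value, not a panic message, carries the diagnostic.
        assert!(format!("{:?}", res.unwrap_err()).contains("Column order length mismatch"));
    }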
diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs
index 06f3cf9fb2..a942481f7e 100644
--- a/parquet/src/file/serialized_reader.rs
+++ b/parquet/src/file/serialized_reader.rs
@@ -435,7 +435,7 @@ pub(crate) fn decode_page(
let is_sorted = dict_header.is_sorted.unwrap_or(false);
Page::DictionaryPage {
buf: buffer,
- num_values: dict_header.num_values as u32,
+ num_values: dict_header.num_values.try_into()?,
encoding: Encoding::try_from(dict_header.encoding)?,
is_sorted,
}
@@ -446,7 +446,7 @@ pub(crate) fn decode_page(
                .ok_or_else(|| ParquetError::General("Missing V1 data page header".to_string()))?;
Page::DataPage {
buf: buffer,
- num_values: header.num_values as u32,
+ num_values: header.num_values.try_into()?,
encoding: Encoding::try_from(header.encoding)?,
                def_level_encoding: Encoding::try_from(header.definition_level_encoding)?,
                rep_level_encoding: Encoding::try_from(header.repetition_level_encoding)?,
@@ -460,12 +460,12 @@ pub(crate) fn decode_page(
let is_compressed = header.is_compressed.unwrap_or(true);
Page::DataPageV2 {
buf: buffer,
- num_values: header.num_values as u32,
+ num_values: header.num_values.try_into()?,
encoding: Encoding::try_from(header.encoding)?,
- num_nulls: header.num_nulls as u32,
- num_rows: header.num_rows as u32,
-                def_levels_byte_len: header.definition_levels_byte_length as u32,
-                rep_levels_byte_len: header.repetition_levels_byte_length as u32,
+ num_nulls: header.num_nulls.try_into()?,
+ num_rows: header.num_rows.try_into()?,
+                def_levels_byte_len: header.definition_levels_byte_length.try_into()?,
+                rep_levels_byte_len: header.repetition_levels_byte_length.try_into()?,
is_compressed,
                statistics: statistics::from_thrift(physical_type, header.statistics)?,
}
@@ -578,6 +578,27 @@ impl<R: ChunkReader> Iterator for SerializedPageReader<R> {
}
}
+fn verify_page_header_len(header_len: usize, remaining_bytes: usize) -> Result<()> {
+ if header_len > remaining_bytes {
+ return Err(eof_err!("Invalid page header"));
+ }
+ Ok(())
+}
+
+fn verify_page_size(
+ compressed_size: i32,
+ uncompressed_size: i32,
+ remaining_bytes: usize,
+) -> Result<()> {
+    // The page's compressed size should not exceed the remaining bytes that are
+    // available to read. The page's uncompressed size is the expected size
+    // after decompression, which can never be negative.
+    if compressed_size < 0 || compressed_size as usize > remaining_bytes || uncompressed_size < 0 {
+ return Err(eof_err!("Invalid page header"));
+ }
+ Ok(())
+}
+
impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
fn get_next_page(&mut self) -> Result<Option<Page>> {
loop {
@@ -596,10 +617,16 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
*header
} else {
                    let (header_len, header) = read_page_header_len(&mut read)?;
+ verify_page_header_len(header_len, *remaining)?;
*offset += header_len;
*remaining -= header_len;
header
};
+ verify_page_size(
+ header.compressed_page_size,
+ header.uncompressed_page_size,
+ *remaining,
+ )?;
let data_len = header.compressed_page_size as usize;
*offset += data_len;
*remaining -= data_len;
@@ -683,6 +710,7 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
} else {
let mut read = self.reader.get_read(*offset as u64)?;
                    let (header_len, header) = read_page_header_len(&mut read)?;
+ verify_page_header_len(header_len, *remaining_bytes)?;
*offset += header_len;
*remaining_bytes -= header_len;
                    let page_meta = if let Ok(page_meta) = (&header).try_into() {
@@ -733,12 +761,23 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
next_page_header,
} => {
if let Some(buffered_header) = next_page_header.take() {
+ verify_page_size(
+ buffered_header.compressed_page_size,
+ buffered_header.uncompressed_page_size,
+ *remaining_bytes,
+ )?;
                    // The next page header has already been peeked, so just advance the offset
*offset += buffered_header.compressed_page_size as usize;
                    *remaining_bytes -= buffered_header.compressed_page_size as usize;
} else {
let mut read = self.reader.get_read(*offset as u64)?;
                    let (header_len, header) = read_page_header_len(&mut read)?;
+ verify_page_header_len(header_len, *remaining_bytes)?;
+ verify_page_size(
+ header.compressed_page_size,
+ header.uncompressed_page_size,
+ *remaining_bytes,
+ )?;
let data_page_size = header.compressed_page_size as usize;
*offset += header_len + data_page_size;
*remaining_bytes -= header_len + data_page_size;
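
For illustration, the page-size validation added above as a standalone sketch (String errors stand in for ParquetError): the compressed size must be non-negative and fit within the bytes remaining in the column chunk, and the uncompressed size must be non-negative, otherwise the subsequent `*remaining -= data_len` arithmetic on a usize would underflow:

    fn verify_page_size(
        compressed_size: i32,
        uncompressed_size: i32,
        remaining_bytes: usize,
    ) -> Result<(), String> {
        if compressed_size < 0 || compressed_size as usize > remaining_bytes || uncompressed_size < 0 {
            return Err("Invalid page header".to_string());
        }
        Ok(())
    }

    fn main() {
        assert!(verify_page_size(100, 200, 1024).is_ok());
        assert!(verify_page_size(-1, 200, 1024).is_err()); // negative size
        assert!(verify_page_size(2048, 200, 1024).is_err()); // exceeds remaining bytes
    }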
diff --git a/parquet/src/file/statistics.rs b/parquet/src/file/statistics.rs
index 2e05b83369..b7522a76f0 100644
--- a/parquet/src/file/statistics.rs
+++ b/parquet/src/file/statistics.rs
@@ -157,6 +157,32 @@ pub fn from_thrift(
stats.max_value
};
+        fn check_len(min: &Option<Vec<u8>>, max: &Option<Vec<u8>>, len: usize) -> Result<()> {
+ if let Some(min) = min {
+ if min.len() < len {
+ return Err(ParquetError::General(
+                    "Insufficient bytes to parse min statistic".to_string(),
+ ));
+ }
+ }
+ if let Some(max) = max {
+ if max.len() < len {
+ return Err(ParquetError::General(
+                    "Insufficient bytes to parse max statistic".to_string(),
+ ));
+ }
+ }
+ Ok(())
+ }
+
+ match physical_type {
+ Type::BOOLEAN => check_len(&min, &max, 1),
+ Type::INT32 | Type::FLOAT => check_len(&min, &max, 4),
+ Type::INT64 | Type::DOUBLE => check_len(&min, &max, 8),
+ Type::INT96 => check_len(&min, &max, 12),
+ _ => Ok(()),
+ }?;
+
// Values are encoded using PLAIN encoding definition, except that
// variable-length byte arrays do not include a length prefix.
//
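
A self-contained sketch of the check above (String in place of ParquetError): fixed-width statistics are PLAIN-encoded, so a min/max buffer shorter than the type's width cannot be decoded and is now rejected up front instead of panicking later:

    fn check_len(min: &Option<Vec<u8>>, max: &Option<Vec<u8>>, len: usize) -> Result<(), String> {
        for (name, value) in [("min", min), ("max", max)] {
            if let Some(bytes) = value {
                if bytes.len() < len {
                    return Err(format!("Insufficient bytes to parse {name} statistic"));
                }
            }
        }
        Ok(())
    }

    fn main() {
        // An INT32 statistic needs at least 4 bytes.
        assert!(check_len(&Some(vec![0; 4]), &None, 4).is_ok());
        assert!(check_len(&Some(vec![0; 2]), &None, 4).is_err());
    }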
diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs
index d168e46de0..d9e9b22e80 100644
--- a/parquet/src/schema/types.rs
+++ b/parquet/src/schema/types.rs
@@ -556,7 +556,11 @@ impl<'a> PrimitiveTypeBuilder<'a> {
}
}
PhysicalType::FIXED_LEN_BYTE_ARRAY => {
-                let max_precision = (2f64.powi(8 * self.length - 1) - 1f64).log10().floor() as i32;
+                let length = self
+                    .length
+                    .checked_mul(8)
+                    .ok_or(general_err!("Invalid length {} for Decimal", self.length))?;
+                let max_precision = (2f64.powi(length - 1) - 1f64).log10().floor() as i32;
if self.precision > max_precision {
return Err(general_err!(
@@ -1171,9 +1175,25 @@ pub fn from_thrift(elements: &[SchemaElement]) -> Result<TypePtr> {
));
}
+ if !schema_nodes[0].is_group() {
+ return Err(general_err!("Expected root node to be a group type"));
+ }
+
Ok(schema_nodes.remove(0))
}
+/// Checks if the logical type is valid.
+fn check_logical_type(logical_type: &Option<LogicalType>) -> Result<()> {
+ if let Some(LogicalType::Integer { bit_width, .. }) = *logical_type {
+    if bit_width != 8 && bit_width != 16 && bit_width != 32 && bit_width != 64 {
+ return Err(general_err!(
+ "Bit width must be 8, 16, 32, or 64 for Integer logical type"
+ ));
+ }
+ }
+ Ok(())
+}
+
/// Constructs a new Type from the `elements`, starting at index `index`.
/// The first result is the starting index for the next Type after this one. If it is
/// equal to `elements.len()`, then this Type is the last one.
@@ -1198,6 +1218,9 @@ fn from_thrift_helper(elements: &[SchemaElement], index: usize) -> Result<(usize
.logical_type
.as_ref()
.map(|value| LogicalType::from(value.clone()));
+
+ check_logical_type(&logical_type)?;
+
let field_id = elements[index].field_id;
match elements[index].num_children {
// From parquet-format:
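
For illustration, the checked-multiplication fix above as a standalone sketch (String errors again): the maximum decimal precision of a FIXED_LEN_BYTE_ARRAY is derived from `8 * length` bits, and a corrupt length now yields an error instead of an arithmetic-overflow panic:

    fn max_decimal_precision(length: i32) -> Result<i32, String> {
        let bits = length
            .checked_mul(8)
            .ok_or(format!("Invalid length {length} for Decimal"))?;
        Ok((2f64.powi(bits - 1) - 1f64).log10().floor() as i32)
    }

    fn main() {
        assert_eq!(max_decimal_precision(16).unwrap(), 38); // a 16-byte decimal holds up to 38 digits
        assert!(max_decimal_precision(i32::MAX).is_err()); // 8 * length would overflow
    }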
diff --git a/parquet/src/thrift.rs b/parquet/src/thrift.rs
index ceb6b1c29f..b216fec6f3 100644
--- a/parquet/src/thrift.rs
+++ b/parquet/src/thrift.rs
@@ -67,7 +67,7 @@ impl<'a> TCompactSliceInputProtocol<'a> {
let mut shift = 0;
loop {
let byte = self.read_byte()?;
- in_progress |= ((byte & 0x7F) as u64) << shift;
+ in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift);
shift += 7;
if byte & 0x80 == 0 {
return Ok(in_progress);
@@ -96,13 +96,22 @@ impl<'a> TCompactSliceInputProtocol<'a> {
}
}
+macro_rules! thrift_unimplemented {
+ () => {
+ Err(thrift::Error::Protocol(thrift::ProtocolError {
+ kind: thrift::ProtocolErrorKind::NotImplemented,
+ message: "not implemented".to_string(),
+ }))
+ };
+}
+
impl TInputProtocol for TCompactSliceInputProtocol<'_> {
fn read_message_begin(&mut self) -> thrift::Result<TMessageIdentifier> {
unimplemented!()
}
fn read_message_end(&mut self) -> thrift::Result<()> {
- unimplemented!()
+ thrift_unimplemented!()
}
    fn read_struct_begin(&mut self) -> thrift::Result<Option<TStructIdentifier>> {
@@ -147,7 +156,21 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> {
),
_ => {
if field_delta != 0 {
- self.last_read_field_id += field_delta as i16;
+ self.last_read_field_id = self
+ .last_read_field_id
+ .checked_add(field_delta as i16)
+ .map_or_else(
+ || {
+                            Err(thrift::Error::Protocol(thrift::ProtocolError {
+                                kind: thrift::ProtocolErrorKind::InvalidData,
+ message: format!(
+ "cannot add {} to {}",
+ field_delta, self.last_read_field_id
+ ),
+ }))
+ },
+ Ok,
+ )?;
} else {
self.last_read_field_id = self.read_i16()?;
};
@@ -226,15 +249,15 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> {
}
fn read_set_begin(&mut self) -> thrift::Result<TSetIdentifier> {
- unimplemented!()
+ thrift_unimplemented!()
}
fn read_set_end(&mut self) -> thrift::Result<()> {
- unimplemented!()
+ thrift_unimplemented!()
}
fn read_map_begin(&mut self) -> thrift::Result<TMapIdentifier> {
- unimplemented!()
+ thrift_unimplemented!()
}
fn read_map_end(&mut self) -> thrift::Result<()> {
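
For illustration, why `wrapping_shl` matters in the varint loop above: a malformed, over-long varint can push `shift` past 63, where a plain `<<` on a u64 panics in debug builds; `wrapping_shl` masks the shift amount (shift % 64), keeping the decode loop panic-free so a bogus value surfaces as a decoding error rather than a crash:

    fn main() {
        let byte: u64 = 0x7F;
        let shift: u32 = 70; // only reachable with a malformed, over-long varint
        // wrapping_shl shifts by (70 % 64) == 6 instead of panicking.
        assert_eq!(byte.wrapping_shl(shift), 0x7F << 6);
    }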
diff --git a/parquet/tests/arrow_reader/bad_data.rs b/parquet/tests/arrow_reader/bad_data.rs
index 7434203143..cfd61e82d3 100644
--- a/parquet/tests/arrow_reader/bad_data.rs
+++ b/parquet/tests/arrow_reader/bad_data.rs
@@ -106,7 +106,7 @@ fn test_arrow_rs_gh_6229_dict_header() {
let err = read_file("ARROW-RS-GH-6229-DICTHEADER.parquet").unwrap_err();
assert_eq!(
err.to_string(),
-        "External: Parquet argument error: EOF: eof decoding byte array"
+        "External: Parquet argument error: Parquet error: Integer overflow: out of range integral type conversion attempted"
);
}