This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new e572a458b Fix Unsound Binary Casting in Unreleased Arrow (#3691) (#3692)
e572a458b is described below
commit e572a458b777b52c6f7a2876f2c15f42a5df5303
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Feb 10 19:49:04 2023 +0000
Fix Unsound Binary Casting in Unreleased Arrow (#3691) (#3692)
* Fix binary casting (#3691)
* Clippy
* More clippy
* Update test
---
arrow-array/src/array/string_array.rs | 51 +++++++++++++++--------
arrow-cast/src/cast.rs | 77 ++++++++++++++++-------------------
arrow-row/src/lib.rs | 2 +-
3 files changed, 69 insertions(+), 61 deletions(-)
diff --git a/arrow-array/src/array/string_array.rs b/arrow-array/src/array/string_array.rs
index cb401540d..2ff1118bc 100644
--- a/arrow-array/src/array/string_array.rs
+++ b/arrow-array/src/array/string_array.rs
@@ -21,7 +21,7 @@ use crate::{
};
use arrow_buffer::{bit_util, MutableBuffer};
use arrow_data::ArrayData;
-use arrow_schema::DataType;
+use arrow_schema::{ArrowError, DataType};
/// Generic struct for \[Large\]StringArray
///
@@ -99,6 +99,34 @@ impl<OffsetSize: OffsetSizeTrait> GenericStringArray<OffsetSize> {
) -> impl Iterator<Item = Option<&str>> + 'a {
indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index)))
}
+
+ /// Fallibly creates a [`GenericStringArray`] from a [`GenericBinaryArray`] returning
+ /// an error if [`GenericBinaryArray`] contains invalid UTF-8 data
+ pub fn try_from_binary(
+ v: GenericBinaryArray<OffsetSize>,
+ ) -> Result<Self, ArrowError> {
+ let offsets = v.value_offsets();
+ let values = v.value_data();
+
+ // We only need to validate that all values are valid UTF-8
+ let validated = std::str::from_utf8(values).map_err(|e| {
+ ArrowError::CastError(format!("Encountered non UTF-8 data: {e}"))
+ })?;
+
+ for offset in offsets.iter() {
+ let o = offset.as_usize();
+ if !validated.is_char_boundary(o) {
+ return Err(ArrowError::CastError(format!(
+ "Split UTF-8 codepoint at offset {o}"
+ )));
+ }
+ }
+
+ let builder = v.into_data().into_builder().data_type(Self::DATA_TYPE);
+ // SAFETY:
+ // Validated UTF-8 above
+ Ok(Self::from(unsafe { builder.build_unchecked() }))
+ }
}
impl<'a, Ptr, OffsetSize: OffsetSizeTrait> FromIterator<&'a Option<Ptr>>
@@ -172,22 +200,7 @@ impl<OffsetSize: OffsetSizeTrait> From<GenericBinaryArray<OffsetSize>>
for GenericStringArray<OffsetSize>
{
fn from(v: GenericBinaryArray<OffsetSize>) -> Self {
- let offsets = v.value_offsets();
- let values = v.value_data();
-
- // We only need to validate that all values are valid UTF-8
- let validated = std::str::from_utf8(values).expect("Invalid UTF-8 sequence");
- for offset in offsets.iter() {
- assert!(
- validated.is_char_boundary(offset.as_usize()),
- "Invalid UTF-8 sequence"
- )
- }
-
- let builder = v.into_data().into_builder().data_type(Self::DATA_TYPE);
- // SAFETY:
- // Validated UTF-8 above
- Self::from(unsafe { builder.build_unchecked() })
+ Self::try_from_binary(v).unwrap()
}
}
@@ -650,7 +663,9 @@ mod tests {
}
#[test]
- #[should_panic(expected = "Invalid UTF-8 sequence: Utf8Error")]
+ #[should_panic(
+ expected = "Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 0"
+ )]
fn test_list_array_utf8_validation() {
let mut builder =
ListBuilder::new(PrimitiveBuilder::<UInt8Type>::new());
builder.values().append_value(0xFF);
diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index 1631f2e00..49461b14c 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -3202,49 +3202,25 @@ fn cast_binary_to_string<O: OffsetSizeTrait>(
.downcast_ref::<GenericByteArray<GenericBinaryType<O>>>()
.unwrap();
- if !cast_options.safe {
- let offsets = array.value_offsets();
- let values = array.value_data();
-
- // We only need to validate that all values are valid UTF-8
- let validated = std::str::from_utf8(values)
- .map_err(|_| ArrowError::CastError("Invalid UTF-8 sequence".to_string()))?;
- // Checks if the offsets are valid but does not re-encode
- for offset in offsets.iter() {
- if !validated.is_char_boundary(offset.as_usize()) {
- return Err(ArrowError::CastError("Invalid UTF-8 sequence".to_string()));
+ match GenericStringArray::<O>::try_from_binary(array.clone()) {
+ Ok(a) => Ok(Arc::new(a)),
+ Err(e) => match cast_options.safe {
+ true => {
+ // Fallback to slow method to convert invalid sequences to nulls
+ let mut builder = GenericStringBuilder::<O>::with_capacity(
+ array.len(),
+ array.value_data().len(),
+ );
+
+ let iter = array
+ .iter()
+ .map(|v| v.and_then(|v| std::str::from_utf8(v).ok()));
+
+ builder.extend(iter);
+ Ok(Arc::new(builder.finish()))
}
- }
-
- let builder = array
- .into_data()
- .into_builder()
- .data_type(GenericStringArray::<O>::DATA_TYPE);
- // SAFETY:
- // Validated UTF-8 above
- Ok(Arc::new(GenericStringArray::<O>::from(unsafe {
- builder.build_unchecked()
- })))
- } else {
- let mut null_builder = BooleanBufferBuilder::new(array.len());
- array.iter().for_each(|maybe_value| {
- null_builder.append(
- maybe_value
- .and_then(|value| std::str::from_utf8(value).ok())
- .is_some(),
- );
- });
-
- let builder = array
- .into_data()
- .into_builder()
- .null_bit_buffer(Some(null_builder.finish()))
- .data_type(GenericStringArray::<O>::DATA_TYPE);
- // SAFETY:
- // Validated UTF-8 above
- Ok(Arc::new(GenericStringArray::<O>::from(unsafe {
- builder.build_unchecked()
- })))
+ false => Err(e),
+ },
}
}
@@ -7588,4 +7564,21 @@ mod tests {
test_tz("+00:00".to_owned());
test_tz("+02:00".to_owned());
}
+
+ #[test]
+ fn test_cast_invalid_utf8() {
+ let v1: &[u8] = b"\xFF invalid";
+ let v2: &[u8] = b"\x00 Foo";
+ let s = BinaryArray::from(vec![v1, v2]);
+ let options = CastOptions { safe: true };
+ let array = cast_with_options(&s, &DataType::Utf8, &options).unwrap();
+ let a = as_string_array(array.as_ref());
+ a.data().validate_full().unwrap();
+
+ assert_eq!(a.null_count(), 1);
+ assert_eq!(a.len(), 2);
+ assert!(a.is_null(0));
+ assert_eq!(a.value(0), "");
+ assert_eq!(a.value(1), "\x00 Foo");
+ }
}
diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs
index 1d54a008f..2e489c974 100644
--- a/arrow-row/src/lib.rs
+++ b/arrow-row/src/lib.rs
@@ -1734,7 +1734,7 @@ mod tests {
}
#[test]
- #[should_panic(expected = "Invalid UTF-8 sequence")]
+ #[should_panic(expected = "Encountered non UTF-8 data")]
fn test_invalid_utf8() {
let mut converter =
RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap();