This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new e572a458b Fix Unsound Binary Casting in Unreleased Arrow (#3691) (#3692)
e572a458b is described below

commit e572a458b777b52c6f7a2876f2c15f42a5df5303
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Feb 10 19:49:04 2023 +0000

    Fix Unsound Binary Casting in Unreleased Arrow (#3691) (#3692)
    
    * Fix binary casting (#3691)
    
    * Clippy
    
    * More clippy
    
    * Update test
---
 arrow-array/src/array/string_array.rs | 51 +++++++++++++++--------
 arrow-cast/src/cast.rs                | 77 ++++++++++++++++-------------------
 arrow-row/src/lib.rs                  |  2 +-
 3 files changed, 69 insertions(+), 61 deletions(-)

diff --git a/arrow-array/src/array/string_array.rs b/arrow-array/src/array/string_array.rs
index cb401540d..2ff1118bc 100644
--- a/arrow-array/src/array/string_array.rs
+++ b/arrow-array/src/array/string_array.rs
@@ -21,7 +21,7 @@ use crate::{
 };
 use arrow_buffer::{bit_util, MutableBuffer};
 use arrow_data::ArrayData;
-use arrow_schema::DataType;
+use arrow_schema::{ArrowError, DataType};
 
 /// Generic struct for \[Large\]StringArray
 ///
@@ -99,6 +99,34 @@ impl<OffsetSize: OffsetSizeTrait> GenericStringArray<OffsetSize> {
     ) -> impl Iterator<Item = Option<&str>> + 'a {
         indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index)))
     }
+
+    /// Fallibly creates a [`GenericStringArray`] from a [`GenericBinaryArray`] returning
+    /// an error if [`GenericBinaryArray`] contains invalid UTF-8 data
+    pub fn try_from_binary(
+        v: GenericBinaryArray<OffsetSize>,
+    ) -> Result<Self, ArrowError> {
+        let offsets = v.value_offsets();
+        let values = v.value_data();
+
+        // We only need to validate that all values are valid UTF-8
+        let validated = std::str::from_utf8(values).map_err(|e| {
+            ArrowError::CastError(format!("Encountered non UTF-8 data: {e}"))
+        })?;
+
+        for offset in offsets.iter() {
+            let o = offset.as_usize();
+            if !validated.is_char_boundary(o) {
+                return Err(ArrowError::CastError(format!(
+                    "Split UTF-8 codepoint at offset {o}"
+                )));
+            }
+        }
+
+        let builder = v.into_data().into_builder().data_type(Self::DATA_TYPE);
+        // SAFETY:
+        // Validated UTF-8 above
+        Ok(Self::from(unsafe { builder.build_unchecked() }))
+    }
 }
 
 impl<'a, Ptr, OffsetSize: OffsetSizeTrait> FromIterator<&'a Option<Ptr>>
@@ -172,22 +200,7 @@ impl<OffsetSize: OffsetSizeTrait> From<GenericBinaryArray<OffsetSize>>
     for GenericStringArray<OffsetSize>
 {
     fn from(v: GenericBinaryArray<OffsetSize>) -> Self {
-        let offsets = v.value_offsets();
-        let values = v.value_data();
-
-        // We only need to validate that all values are valid UTF-8
-        let validated = std::str::from_utf8(values).expect("Invalid UTF-8 sequence");
-        for offset in offsets.iter() {
-            assert!(
-                validated.is_char_boundary(offset.as_usize()),
-                "Invalid UTF-8 sequence"
-            )
-        }
-
-        let builder = v.into_data().into_builder().data_type(Self::DATA_TYPE);
-        // SAFETY:
-        // Validated UTF-8 above
-        Self::from(unsafe { builder.build_unchecked() })
+        Self::try_from_binary(v).unwrap()
     }
 }
 
@@ -650,7 +663,9 @@ mod tests {
     }
 
     #[test]
-    #[should_panic(expected = "Invalid UTF-8 sequence: Utf8Error")]
+    #[should_panic(
+        expected = "Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 0"
+    )]
     fn test_list_array_utf8_validation() {
         let mut builder = ListBuilder::new(PrimitiveBuilder::<UInt8Type>::new());
         builder.values().append_value(0xFF);
diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index 1631f2e00..49461b14c 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -3202,49 +3202,25 @@ fn cast_binary_to_string<O: OffsetSizeTrait>(
         .downcast_ref::<GenericByteArray<GenericBinaryType<O>>>()
         .unwrap();
 
-    if !cast_options.safe {
-        let offsets = array.value_offsets();
-        let values = array.value_data();
-
-        // We only need to validate that all values are valid UTF-8
-        let validated = std::str::from_utf8(values)
-            .map_err(|_| ArrowError::CastError("Invalid UTF-8 sequence".to_string()))?;
-        // Checks if the offsets are valid but does not re-encode
-        for offset in offsets.iter() {
-            if !validated.is_char_boundary(offset.as_usize()) {
-                return Err(ArrowError::CastError("Invalid UTF-8 sequence".to_string()));
+    match GenericStringArray::<O>::try_from_binary(array.clone()) {
+        Ok(a) => Ok(Arc::new(a)),
+        Err(e) => match cast_options.safe {
+            true => {
+                // Fallback to slow method to convert invalid sequences to nulls
+                let mut builder = GenericStringBuilder::<O>::with_capacity(
+                    array.len(),
+                    array.value_data().len(),
+                );
+
+                let iter = array
+                    .iter()
+                    .map(|v| v.and_then(|v| std::str::from_utf8(v).ok()));
+
+                builder.extend(iter);
+                Ok(Arc::new(builder.finish()))
             }
-        }
-
-        let builder = array
-            .into_data()
-            .into_builder()
-            .data_type(GenericStringArray::<O>::DATA_TYPE);
-        // SAFETY:
-        // Validated UTF-8 above
-        Ok(Arc::new(GenericStringArray::<O>::from(unsafe {
-            builder.build_unchecked()
-        })))
-    } else {
-        let mut null_builder = BooleanBufferBuilder::new(array.len());
-        array.iter().for_each(|maybe_value| {
-            null_builder.append(
-                maybe_value
-                    .and_then(|value| std::str::from_utf8(value).ok())
-                    .is_some(),
-            );
-        });
-
-        let builder = array
-            .into_data()
-            .into_builder()
-            .null_bit_buffer(Some(null_builder.finish()))
-            .data_type(GenericStringArray::<O>::DATA_TYPE);
-        // SAFETY:
-        // Validated UTF-8 above
-        Ok(Arc::new(GenericStringArray::<O>::from(unsafe {
-            builder.build_unchecked()
-        })))
+            false => Err(e),
+        },
     }
 }
 
@@ -7588,4 +7564,21 @@ mod tests {
         test_tz("+00:00".to_owned());
         test_tz("+02:00".to_owned());
     }
+
+    #[test]
+    fn test_cast_invalid_utf8() {
+        let v1: &[u8] = b"\xFF invalid";
+        let v2: &[u8] = b"\x00 Foo";
+        let s = BinaryArray::from(vec![v1, v2]);
+        let options = CastOptions { safe: true };
+        let array = cast_with_options(&s, &DataType::Utf8, &options).unwrap();
+        let a = as_string_array(array.as_ref());
+        a.data().validate_full().unwrap();
+
+        assert_eq!(a.null_count(), 1);
+        assert_eq!(a.len(), 2);
+        assert!(a.is_null(0));
+        assert_eq!(a.value(0), "");
+        assert_eq!(a.value(1), "\x00 Foo");
+    }
 }
diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs
index 1d54a008f..2e489c974 100644
--- a/arrow-row/src/lib.rs
+++ b/arrow-row/src/lib.rs
@@ -1734,7 +1734,7 @@ mod tests {
     }
 
     #[test]
-    #[should_panic(expected = "Invalid UTF-8 sequence")]
+    #[should_panic(expected = "Encountered non UTF-8 data")]
     fn test_invalid_utf8() {
         let mut converter =
             RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap();

Reply via email to