This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 18385e5634 Fix take_bytes Null and Overflow Handling (#4576) (#4579)
18385e5634 is described below

commit 18385e56343c64bbbc76f271c5fbb4f27b5e7e8d
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Sat Jul 29 15:24:53 2023 +0100

    Fix take_bytes Null and Overflow Handling (#4576) (#4579)
    
    * Cleanup take_bytes
    
    * Use extend
    
    * Tweak
    
    * Review feedback
---
 arrow-select/src/take.rs | 85 +++++++++++++++++++++---------------------------
 1 file changed, 37 insertions(+), 48 deletions(-)

diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index 0f5689ff99..cee9cbaf84 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -331,94 +331,70 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
     let data_len = indices.len();
 
     let bytes_offset = (data_len + 1) * std::mem::size_of::<T::Offset>();
-    let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset);
+    let mut offsets = MutableBuffer::new(bytes_offset);
+    offsets.push(T::Offset::default());
 
-    let offsets = offsets_buffer.typed_data_mut();
     let mut values = MutableBuffer::new(0);
-    let mut length_so_far = T::Offset::from_usize(0).unwrap();
-    offsets[0] = length_so_far;
 
     let nulls;
     if array.null_count() == 0 && indices.null_count() == 0 {
-        for (i, offset) in offsets.iter_mut().skip(1).enumerate() {
-            let index = indices.value(i).to_usize().ok_or_else(|| {
-                ArrowError::ComputeError("Cast to usize failed".to_string())
-            })?;
-
-            let s = array.value(index);
-
-            let s: &[u8] = s.as_ref();
-            length_so_far += T::Offset::from_usize(s.len()).unwrap();
+        offsets.extend(indices.values().iter().map(|index| {
+            let s: &[u8] = array.value(index.as_usize()).as_ref();
             values.extend_from_slice(s);
-            *offset = length_so_far;
-        }
+            T::Offset::usize_as(values.len())
+        }));
         nulls = None
     } else if indices.null_count() == 0 {
         let num_bytes = bit_util::ceil(data_len, 8);
 
         let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
         let null_slice = null_buf.as_slice_mut();
-
-        for (i, offset) in offsets.iter_mut().skip(1).enumerate() {
-            let index = indices.value(i).to_usize().ok_or_else(|| {
-                ArrowError::ComputeError("Cast to usize failed".to_string())
-            })?;
-
+        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
+            let index = index.as_usize();
             if array.is_valid(index) {
                 let s: &[u8] = array.value(index).as_ref();
-
-                length_so_far += T::Offset::from_usize(s.len()).unwrap();
                 values.extend_from_slice(s.as_ref());
             } else {
                 bit_util::unset_bit(null_slice, i);
             }
-            *offset = length_so_far;
-        }
+            T::Offset::usize_as(values.len())
+        }));
         nulls = Some(null_buf.into());
     } else if array.null_count() == 0 {
-        for (i, offset) in offsets.iter_mut().skip(1).enumerate() {
+        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
             if indices.is_valid(i) {
-                let index = indices.value(i).to_usize().ok_or_else(|| {
-                    ArrowError::ComputeError("Cast to usize failed".to_string())
-                })?;
-
-                let s: &[u8] = array.value(index).as_ref();
-
-                length_so_far += T::Offset::from_usize(s.len()).unwrap();
+                let s: &[u8] = array.value(index.as_usize()).as_ref();
                 values.extend_from_slice(s);
             }
-            *offset = length_so_far;
-        }
+            T::Offset::usize_as(values.len())
+        }));
         nulls = indices.nulls().map(|b| b.inner().sliced());
     } else {
         let num_bytes = bit_util::ceil(data_len, 8);
 
         let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
         let null_slice = null_buf.as_slice_mut();
-
-        for (i, offset) in offsets.iter_mut().skip(1).enumerate() {
-            let index = indices.value(i).to_usize().ok_or_else(|| {
-                ArrowError::ComputeError("Cast to usize failed".to_string())
-            })?;
-
-            if array.is_valid(index) && indices.is_valid(i) {
+        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
+            // check index is valid before using index. The value in
+            // NULL index slots may not be within bounds of array
+            let index = index.as_usize();
+            if indices.is_valid(i) && array.is_valid(index) {
                 let s: &[u8] = array.value(index).as_ref();
-
-                length_so_far += T::Offset::from_usize(s.len()).unwrap();
                 values.extend_from_slice(s);
             } else {
                 // set null bit
                 bit_util::unset_bit(null_slice, i);
             }
-            *offset = length_so_far;
-        }
-
+            T::Offset::usize_as(values.len())
+        }));
         nulls = Some(null_buf.into())
     }
 
+    T::Offset::from_usize(values.len()).expect("offset overflow");
+
     let array_data = ArrayData::builder(T::DATA_TYPE)
         .len(data_len)
-        .add_buffer(offsets_buffer.into())
+        .add_buffer(offsets.into())
         .add_buffer(values.into())
         .null_bit_buffer(nulls);
 
@@ -1937,6 +1913,7 @@ mod tests {
 
     #[test]
     fn test_take_null_indices() {
+        // Build indices with values that are out of bounds, but masked by null mask
         let indices = Int32Array::new(
             vec![1, 2, 400, 400].into(),
             Some(NullBuffer::from(vec![true, true, false, false])),
@@ -1949,4 +1926,16 @@ mod tests {
             .collect::<Vec<_>>();
         assert_eq!(&values, &[Some(23), Some(4), None, None])
     }
+
+    #[test]
+    fn test_take_bytes_null_indices() {
+        let indices = Int32Array::new(
+            vec![0, 1, 400, 400].into(),
+            Some(NullBuffer::from_iter(vec![true, true, false, false])),
+        );
+        let values = StringArray::from(vec![Some("foo"), None]);
+        let r = take(&values, &indices, None).unwrap();
+        let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
+        assert_eq!(&values, &[Some("foo"), None, None, None])
+    }
 }

Reply via email to