This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new e2d6c05  ARROW-11357: [Rust]: Fix out-of-bounds reads in `take` and 
other undefined behavior
e2d6c05 is described below

commit e2d6c057684b587151afffe50f7eaef94533e017
Author: Jorge C. Leitao <[email protected]>
AuthorDate: Tue Feb 16 15:21:44 2021 -0500

    ARROW-11357: [Rust]: Fix out-of-bounds reads in `take` and other undefined 
behavior
    
    This PR fixes two major issues in our `take` kernel for primitive arrays.
    
    Background
    
    * When using `values()` from an array, it is important to remember that 
only certain values are not arbitrary (those whose validity bit is set, or all 
of them when the array has no null buffer).
    
    * When reading values from an array using `array.value(i)`, it is important 
to remember that it currently performs no bound checks.
    
    * `take` offers an option to deactivate bound checks, by turning an error 
into a panic. That option defaults to not checking (i.e. `panic`).
    
    However, the `take` kernel respects none of the points above:
    * it reads and uses arbitrary indices
    * it accesses out of bound values
    * it does not panic when `check_bounds = false` (it instead reads out of 
bounds)
    
    Specifically, it is currently doing something like the following (ignoring 
some details):
    
    ```rust
    let indices = indices.values();
    for index in indices {
        let index = index.to_usize();
        let taken_value = values.value(index)
    }
    ```
    
    I.e.
    
    * there is no check that `index` is a valid slot
    * there is no check that `index < values.len()`.
    
    Independently, each of them is unsound. Combined, they allow for 
spectacular unbounded memory reads: reading from a pointer offset by an 
arbitrary value. 🤯
    
    This PR fixes this behavior. This PR also improves performance by 20-40% 
(thanks to great suggestions from @Dandandan and @tyrelr).
    
    ```
    Your branch is ahead of 'origin/take_fix' by 1 commit.
      (use "git push" to publish your local commits)
       Compiling arrow v4.0.0-SNAPSHOT 
(/Users/jorgecarleitao/projects/arrow/rust/arrow)
        Finished bench [optimized] target(s) in 1m 25s
         Running 
/Users/jorgecarleitao/projects/arrow/rust/target/release/deps/take_kernels-683d8fd1eeba5497
    Gnuplot not found, using plotters backend
    take i32 512            time:   [1.0477 us 1.0519 us 1.0564 us]
                            change: [-25.007% -23.339% -21.840%] (p = 0.00 < 
0.05)
                            Performance has improved.
    Found 4 outliers among 100 measurements (4.00%)
      2 (2.00%) high mild
      2 (2.00%) high severe
    
    take i32 1024           time:   [1.4949 us 1.4996 us 1.5049 us]
                            change: [-34.066% -31.907% -29.853%] (p = 0.00 < 
0.05)
                            Performance has improved.
    Found 7 outliers among 100 measurements (7.00%)
      4 (4.00%) high mild
      3 (3.00%) high severe
    
    take i32 nulls 512      time:   [976.65 ns 983.09 ns 991.99 ns]
                            change: [-45.076% -39.795% -35.157%] (p = 0.00 < 
0.05)
                            Performance has improved.
    Found 9 outliers among 100 measurements (9.00%)
      1 (1.00%) low mild
      3 (3.00%) high mild
      5 (5.00%) high severe
    
    take i32 nulls 1024     time:   [1.3249 us 1.3278 us 1.3309 us]
                            change: [-32.887% -31.666% -30.524%] (p = 0.00 < 
0.05)
                            Performance has improved.
    Found 8 outliers among 100 measurements (8.00%)
      3 (3.00%) high mild
      5 (5.00%) high severe
    ```
    
    Closes #9301 from jorgecarleitao/take_fix
    
    Authored-by: Jorge C. Leitao <[email protected]>
    Signed-off-by: Andrew Lamb <[email protected]>
---
 rust/arrow/src/compute/kernels/take.rs | 321 ++++++++++++++++++++++++++-------
 rust/arrow/src/datatypes/native.rs     |  22 +++
 2 files changed, 274 insertions(+), 69 deletions(-)

diff --git a/rust/arrow/src/compute/kernels/take.rs 
b/rust/arrow/src/compute/kernels/take.rs
index 8f3ed56..adae71d 100644
--- a/rust/arrow/src/compute/kernels/take.rs
+++ b/rust/arrow/src/compute/kernels/take.rs
@@ -254,6 +254,143 @@ impl Default for TakeOptions {
     }
 }
 
+#[inline(always)]
+fn maybe_usize<I: ArrowPrimitiveType>(index: I::Native) -> Result<usize> {
+    index
+        .to_usize()
+        .ok_or_else(|| ArrowError::ComputeError("Cast to usize 
failed".to_string()))
+}
+
+// take implementation when neither values nor indices contain nulls
+fn take_no_nulls<T, I>(
+    values: &[T::Native],
+    indices: &[I::Native],
+) -> Result<(Buffer, Option<Buffer>)>
+where
+    T: ArrowPrimitiveType,
+    I: ArrowNumericType,
+{
+    let values = indices
+        .iter()
+        .map(|index| Result::Ok(values[maybe_usize::<I>(*index)?]));
+    // Soundness: `slice.map` is `TrustedLen`.
+    let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
+
+    Ok((buffer, None))
+}
+
+// take implementation when only values contain nulls
+fn take_values_nulls<T, I>(
+    values: &PrimitiveArray<T>,
+    indices: &[I::Native],
+) -> Result<(Buffer, Option<Buffer>)>
+where
+    T: ArrowPrimitiveType,
+    I: ArrowNumericType,
+    I::Native: ToPrimitive,
+{
+    let num_bytes = bit_util::ceil(indices.len(), 8);
+    let mut nulls = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
+    let null_slice = nulls.as_slice_mut();
+    let mut null_count = 0;
+
+    let values_values = values.values();
+
+    let values = indices.iter().enumerate().map(|(i, index)| {
+        let index = maybe_usize::<I>(*index)?;
+        if values.is_null(index) {
+            null_count += 1;
+            bit_util::unset_bit(null_slice, i);
+        }
+        Result::Ok(values_values[index])
+    });
+    // Soundness: `slice.map` is `TrustedLen`.
+    let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
+
+    let nulls = if null_count == 0 {
+        // if only non-null values were taken
+        None
+    } else {
+        Some(nulls.into())
+    };
+
+    Ok((buffer, nulls))
+}
+
+// take implementation when only indices contain nulls
+fn take_indices_nulls<T, I>(
+    values: &[T::Native],
+    indices: &PrimitiveArray<I>,
+) -> Result<(Buffer, Option<Buffer>)>
+where
+    T: ArrowPrimitiveType,
+    I: ArrowNumericType,
+    I::Native: ToPrimitive,
+{
+    let values = indices.values().iter().map(|index| {
+        let index = maybe_usize::<I>(*index)?;
+        Result::Ok(match values.get(index) {
+            Some(value) => *value,
+            None => {
+                if indices.is_null(index) {
+                    T::Native::default()
+                } else {
+                    panic!("Out-of-bounds index {}", index)
+                }
+            }
+        })
+    });
+
+    // Soundness: `slice.map` is `TrustedLen`.
+    let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
+
+    Ok((buffer, indices.data_ref().null_buffer().cloned()))
+}
+
+// take implementation when both values and indices contain nulls
+fn take_values_indices_nulls<T, I>(
+    values: &PrimitiveArray<T>,
+    indices: &PrimitiveArray<I>,
+) -> Result<(Buffer, Option<Buffer>)>
+where
+    T: ArrowPrimitiveType,
+    I: ArrowNumericType,
+    I::Native: ToPrimitive,
+{
+    let num_bytes = bit_util::ceil(indices.len(), 8);
+    let mut nulls = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
+    let null_slice = nulls.as_slice_mut();
+    let mut null_count = 0;
+
+    let values_values = values.values();
+    let values = indices.iter().enumerate().map(|(i, index)| match index {
+        Some(index) => {
+            let index = maybe_usize::<I>(index)?;
+            if values.is_null(index) {
+                null_count += 1;
+                bit_util::unset_bit(null_slice, i);
+            }
+            Result::Ok(values_values[index])
+        }
+        None => {
+            null_count += 1;
+            bit_util::unset_bit(null_slice, i);
+            Ok(T::Native::default())
+        }
+    });
+    // Soundness: `slice.map` is `TrustedLen`.
+    let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
+
+    let nulls = if null_count == 0 {
+        // if only non-null values were taken
+        None
+    } else {
+        Some(nulls.into())
+    };
+
+    Ok((buffer, nulls))
+}
+
 /// `take` implementation for all primitive arrays
 ///
 /// This checks if an `indices` slot is populated, and gets the value from 
`values`
@@ -269,56 +406,36 @@ fn take_primitive<T, I>(
 ) -> Result<PrimitiveArray<T>>
 where
     T: ArrowPrimitiveType,
-    T::Native: num::Num,
     I: ArrowNumericType,
     I::Native: ToPrimitive,
 {
-    let data_len = indices.len();
-
-    let mut buffer =
-        MutableBuffer::from_len_zeroed(data_len * 
std::mem::size_of::<T::Native>());
-    let data = buffer.typed_data_mut();
-
-    let nulls;
-
-    if values.null_count() == 0 {
-        // Take indices without null checking
-        for (i, elem) in data.iter_mut().enumerate() {
-            let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| 
{
-                ArrowError::ComputeError("Cast to usize failed".to_string())
-            })?;
-
-            *elem = values.value(index);
+    let indices_has_nulls = indices.null_count() > 0;
+    let values_has_nulls = values.null_count() > 0;
+    // note: this function should only panic when "an index is not null and 
out of bounds".
+    // if the index is null, its value is undefined and therefore we should 
not read from it.
+
+    let (buffer, nulls) = match (values_has_nulls, indices_has_nulls) {
+        (false, false) => {
+            // * no nulls
+            // * all `indices.values()` are valid
+            take_no_nulls::<T, I>(values.values(), indices.values())?
         }
-        nulls = indices.data_ref().null_buffer().cloned();
-    } else {
-        let num_bytes = bit_util::ceil(data_len, 8);
-        let mut null_buf = 
MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
-
-        let null_slice = null_buf.as_slice_mut();
-
-        for (i, elem) in data.iter_mut().enumerate() {
-            let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| 
{
-                ArrowError::ComputeError("Cast to usize failed".to_string())
-            })?;
-
-            if values.is_null(index) {
-                bit_util::unset_bit(null_slice, i);
-            }
-
-            *elem = values.value(index);
+        (true, false) => {
+            // * nulls come from `values` alone
+            // * all `indices.values()` are valid
+            take_values_nulls::<T, I>(values, indices.values())?
         }
-        nulls = match indices.data_ref().null_buffer() {
-            Some(buffer) => Some(buffer_bin_and(
-                buffer,
-                0,
-                &null_buf.into(),
-                0,
-                indices.len(),
-            )),
-            None => Some(null_buf.into()),
-        };
-    }
+        (false, true) => {
+            // in this branch it is unsound to read and use `index.values()`,
+            // as doing so is UB when they come from a null slot.
+            take_indices_nulls::<T, I>(values.values(), indices)?
+        }
+        (true, true) => {
+            // in this branch it is unsound to read and use `index.values()`,
+            // as doing so is UB when they come from a null slot.
+            take_values_indices_nulls::<T, I>(values, indices)?
+        }
+    };
 
     let data = ArrayData::new(
         T::DATA_TYPE,
@@ -326,7 +443,7 @@ where
         None,
         nulls,
         0,
-        vec![buffer.into()],
+        vec![buffer],
         vec![],
     );
     Ok(PrimitiveArray::<T>::from(Arc::new(data)))
@@ -663,14 +780,16 @@ mod tests {
         index: &UInt32Array,
         options: Option<TakeOptions>,
         expected_data: Vec<Option<T::Native>>,
-    ) where
+    ) -> Result<()>
+    where
         T: ArrowPrimitiveType,
         PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
     {
         let output = PrimitiveArray::<T>::from(data);
         let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as 
ArrayRef;
-        let output = take(&output, index, options).unwrap();
-        assert_eq!(&output, &expected)
+        let output = take(&output, index, options)?;
+        assert_eq!(&output, &expected);
+        Ok(())
     }
 
     fn test_take_impl_primitive_arrays<T, I>(
@@ -707,6 +826,42 @@ mod tests {
     }
 
     #[test]
+    fn test_take_primitive_non_null_indices() {
+        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
+        test_take_primitive_arrays::<Int8Type>(
+            vec![None, Some(3), Some(5), Some(2), Some(3), None],
+            &index,
+            None,
+            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
+        )
+        .unwrap();
+    }
+
+    #[test]
+    fn test_take_primitive_non_null_values() {
+        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), 
Some(2)]);
+        test_take_primitive_arrays::<Int8Type>(
+            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
+            &index,
+            None,
+            vec![Some(3), None, Some(1), Some(3), Some(2)],
+        )
+        .unwrap();
+    }
+
+    #[test]
+    fn test_take_primitive_non_null() {
+        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
+        test_take_primitive_arrays::<Int8Type>(
+            vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
+            &index,
+            None,
+            vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
+        )
+        .unwrap();
+    }
+
+    #[test]
     fn test_take_primitive() {
         let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), 
Some(2)]);
 
@@ -716,7 +871,8 @@ mod tests {
             &index,
             None,
             vec![Some(3), None, None, Some(3), Some(2)],
-        );
+        )
+        .unwrap();
 
         // int16
         test_take_primitive_arrays::<Int16Type>(
@@ -724,7 +880,8 @@ mod tests {
             &index,
             None,
             vec![Some(3), None, None, Some(3), Some(2)],
-        );
+        )
+        .unwrap();
 
         // int32
         test_take_primitive_arrays::<Int32Type>(
@@ -732,7 +889,8 @@ mod tests {
             &index,
             None,
             vec![Some(3), None, None, Some(3), Some(2)],
-        );
+        )
+        .unwrap();
 
         // int64
         test_take_primitive_arrays::<Int64Type>(
@@ -740,7 +898,8 @@ mod tests {
             &index,
             None,
             vec![Some(3), None, None, Some(3), Some(2)],
-        );
+        )
+        .unwrap();
 
         // uint8
         test_take_primitive_arrays::<UInt8Type>(
@@ -748,7 +907,8 @@ mod tests {
             &index,
             None,
             vec![Some(3), None, None, Some(3), Some(2)],
-        );
+        )
+        .unwrap();
 
         // uint16
         test_take_primitive_arrays::<UInt16Type>(
@@ -756,7 +916,8 @@ mod tests {
             &index,
             None,
             vec![Some(3), None, None, Some(3), Some(2)],
-        );
+        )
+        .unwrap();
 
         // uint32
         test_take_primitive_arrays::<UInt32Type>(
@@ -764,7 +925,8 @@ mod tests {
             &index,
             None,
             vec![Some(3), None, None, Some(3), Some(2)],
-        );
+        )
+        .unwrap();
 
         // int64
         test_take_primitive_arrays::<Int64Type>(
@@ -772,7 +934,8 @@ mod tests {
             &index,
             None,
             vec![Some(-15), None, None, Some(-15), Some(2)],
-        );
+        )
+        .unwrap();
 
         // interval_year_month
         test_take_primitive_arrays::<IntervalYearMonthType>(
@@ -780,7 +943,8 @@ mod tests {
             &index,
             None,
             vec![Some(-15), None, None, Some(-15), Some(2)],
-        );
+        )
+        .unwrap();
 
         // interval_day_time
         test_take_primitive_arrays::<IntervalDayTimeType>(
@@ -788,7 +952,8 @@ mod tests {
             &index,
             None,
             vec![Some(-15), None, None, Some(-15), Some(2)],
-        );
+        )
+        .unwrap();
 
         // duration_second
         test_take_primitive_arrays::<DurationSecondType>(
@@ -796,7 +961,8 @@ mod tests {
             &index,
             None,
             vec![Some(-15), None, None, Some(-15), Some(2)],
-        );
+        )
+        .unwrap();
 
         // duration_millisecond
         test_take_primitive_arrays::<DurationMillisecondType>(
@@ -804,7 +970,8 @@ mod tests {
             &index,
             None,
             vec![Some(-15), None, None, Some(-15), Some(2)],
-        );
+        )
+        .unwrap();
 
         // duration_microsecond
         test_take_primitive_arrays::<DurationMicrosecondType>(
@@ -812,7 +979,8 @@ mod tests {
             &index,
             None,
             vec![Some(-15), None, None, Some(-15), Some(2)],
-        );
+        )
+        .unwrap();
 
         // duration_nanosecond
         test_take_primitive_arrays::<DurationNanosecondType>(
@@ -820,7 +988,8 @@ mod tests {
             &index,
             None,
             vec![Some(-15), None, None, Some(-15), Some(2)],
-        );
+        )
+        .unwrap();
 
         // float32
         test_take_primitive_arrays::<Float32Type>(
@@ -828,7 +997,8 @@ mod tests {
             &index,
             None,
             vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
-        );
+        )
+        .unwrap();
 
         // float64
         test_take_primitive_arrays::<Float64Type>(
@@ -836,7 +1006,8 @@ mod tests {
             &index,
             None,
             vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
-        );
+        )
+        .unwrap();
     }
 
     #[test]
@@ -1350,20 +1521,32 @@ mod tests {
     }
 
     #[test]
-    #[should_panic(
-        expected = "Array index out of bounds, cannot get item at index 6 from 
5 entries"
-    )]
     fn test_take_out_of_bounds() {
         let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), 
Some(6)]);
         let take_opt = TakeOptions { check_bounds: true };
 
         // int64
-        test_take_primitive_arrays::<Int64Type>(
+        let result = test_take_primitive_arrays::<Int64Type>(
             vec![Some(0), None, Some(2), Some(3), None],
             &index,
             Some(take_opt),
             vec![None],
         );
+        assert!(result.is_err());
+    }
+
+    #[test]
+    #[should_panic(expected = "index out of bounds: the len is 4 but the index 
is 1000")]
+    fn test_take_out_of_bounds_panic() {
+        let index = UInt32Array::from(vec![Some(1000)]);
+
+        test_take_primitive_arrays::<Int64Type>(
+            vec![Some(0), Some(1), Some(2), Some(3)],
+            &index,
+            None,
+            vec![None],
+        )
+        .unwrap();
     }
 
     #[test]
diff --git a/rust/arrow/src/datatypes/native.rs 
b/rust/arrow/src/datatypes/native.rs
index bfca235..fb1bad4 100644
--- a/rust/arrow/src/datatypes/native.rs
+++ b/rust/arrow/src/datatypes/native.rs
@@ -39,21 +39,25 @@ pub trait ArrowNativeType:
     + JsonSerializable
 {
     /// Convert native type from usize.
+    #[inline]
     fn from_usize(_: usize) -> Option<Self> {
         None
     }
 
     /// Convert native type to usize.
+    #[inline]
     fn to_usize(&self) -> Option<usize> {
         None
     }
 
     /// Convert native type from i32.
+    #[inline]
     fn from_i32(_: i32) -> Option<Self> {
         None
     }
 
     /// Convert native type from i64.
+    #[inline]
     fn from_i64(_: i64) -> Option<Self> {
         None
     }
@@ -94,10 +98,12 @@ impl JsonSerializable for i8 {
 }
 
 impl ArrowNativeType for i8 {
+    #[inline]
     fn from_usize(v: usize) -> Option<Self> {
         num::FromPrimitive::from_usize(v)
     }
 
+    #[inline]
     fn to_usize(&self) -> Option<usize> {
         num::ToPrimitive::to_usize(self)
     }
@@ -110,10 +116,12 @@ impl JsonSerializable for i16 {
 }
 
 impl ArrowNativeType for i16 {
+    #[inline]
     fn from_usize(v: usize) -> Option<Self> {
         num::FromPrimitive::from_usize(v)
     }
 
+    #[inline]
     fn to_usize(&self) -> Option<usize> {
         num::ToPrimitive::to_usize(self)
     }
@@ -126,15 +134,18 @@ impl JsonSerializable for i32 {
 }
 
 impl ArrowNativeType for i32 {
+    #[inline]
     fn from_usize(v: usize) -> Option<Self> {
         num::FromPrimitive::from_usize(v)
     }
 
+    #[inline]
     fn to_usize(&self) -> Option<usize> {
         num::ToPrimitive::to_usize(self)
     }
 
     /// Convert native type from i32.
+    #[inline]
     fn from_i32(val: i32) -> Option<Self> {
         Some(val)
     }
@@ -147,15 +158,18 @@ impl JsonSerializable for i64 {
 }
 
 impl ArrowNativeType for i64 {
+    #[inline]
     fn from_usize(v: usize) -> Option<Self> {
         num::FromPrimitive::from_usize(v)
     }
 
+    #[inline]
     fn to_usize(&self) -> Option<usize> {
         num::ToPrimitive::to_usize(self)
     }
 
     /// Convert native type from i64.
+    #[inline]
     fn from_i64(val: i64) -> Option<Self> {
         Some(val)
     }
@@ -168,10 +182,12 @@ impl JsonSerializable for u8 {
 }
 
 impl ArrowNativeType for u8 {
+    #[inline]
     fn from_usize(v: usize) -> Option<Self> {
         num::FromPrimitive::from_usize(v)
     }
 
+    #[inline]
     fn to_usize(&self) -> Option<usize> {
         num::ToPrimitive::to_usize(self)
     }
@@ -184,10 +200,12 @@ impl JsonSerializable for u16 {
 }
 
 impl ArrowNativeType for u16 {
+    #[inline]
     fn from_usize(v: usize) -> Option<Self> {
         num::FromPrimitive::from_usize(v)
     }
 
+    #[inline]
     fn to_usize(&self) -> Option<usize> {
         num::ToPrimitive::to_usize(self)
     }
@@ -200,10 +218,12 @@ impl JsonSerializable for u32 {
 }
 
 impl ArrowNativeType for u32 {
+    #[inline]
     fn from_usize(v: usize) -> Option<Self> {
         num::FromPrimitive::from_usize(v)
     }
 
+    #[inline]
     fn to_usize(&self) -> Option<usize> {
         num::ToPrimitive::to_usize(self)
     }
@@ -216,10 +236,12 @@ impl JsonSerializable for u64 {
 }
 
 impl ArrowNativeType for u64 {
+    #[inline]
     fn from_usize(v: usize) -> Option<Self> {
         num::FromPrimitive::from_usize(v)
     }
 
+    #[inline]
     fn to_usize(&self) -> Option<usize> {
         num::ToPrimitive::to_usize(self)
     }

Reply via email to