This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new e2d6c05 ARROW-11357: [Rust]: Fix out-of-bounds reads in `take` and
other undefined behavior
e2d6c05 is described below
commit e2d6c057684b587151afffe50f7eaef94533e017
Author: Jorge C. Leitao <[email protected]>
AuthorDate: Tue Feb 16 15:21:44 2021 -0500
ARROW-11357: [Rust]: Fix out-of-bounds reads in `take` and other undefined
behavior
This PR fixes two major issues in our `take` kernel for primitive arrays.
Background
* When using `values()` from an array, it is important to remember that
only some values are meaningful: those in slots whose validity bit is set (or
all of them, when there is no null buffer). The others are arbitrary.
* When reading values from an array using `array.value(i)`, it is important
to remember that it currently performs no bounds checks.
* `take` offers an option (`check_bounds`) to enable bounds checks, which turn
an out-of-bounds index into an error instead of a panic. The option defaults
to not checking (i.e. to panicking).
However, the `take` kernel respects none of the points above:
* it reads and uses arbitrary indices
* it accesses out of bound values
* it does not panic when `check_bounds = false` (it instead reads out of
bounds)
Specifically, it is currently doing something like the following (ignoring
some details):
```rust
let indices = indices.values();
for index in indices {
let index = index.to_usize();
let taken_value = values.value(index)
}
```
I.e.:
* there is no check that the slot of `indices` being read is valid (not null)
* there is no check that `index < values.len()`.
Each of them is unsound on its own. Combined, they allow for
spectacular unbounded memory reads: reading from a pointer offset by an
arbitrary value. 🤯
This PR fixes this behavior. It also improves performance by 20-40%
(thanks to great suggestions from @Dandandan and @tyrelr).
```
Your branch is ahead of 'origin/take_fix' by 1 commit.
(use "git push" to publish your local commits)
Compiling arrow v4.0.0-SNAPSHOT
(/Users/jorgecarleitao/projects/arrow/rust/arrow)
Finished bench [optimized] target(s) in 1m 25s
Running
/Users/jorgecarleitao/projects/arrow/rust/target/release/deps/take_kernels-683d8fd1eeba5497
Gnuplot not found, using plotters backend
take i32 512 time: [1.0477 us 1.0519 us 1.0564 us]
change: [-25.007% -23.339% -21.840%] (p = 0.00 <
0.05)
Performance has improved.
Found 4 outliers among 100 measurements (4.00%)
2 (2.00%) high mild
2 (2.00%) high severe
take i32 1024 time: [1.4949 us 1.4996 us 1.5049 us]
change: [-34.066% -31.907% -29.853%] (p = 0.00 <
0.05)
Performance has improved.
Found 7 outliers among 100 measurements (7.00%)
4 (4.00%) high mild
3 (3.00%) high severe
take i32 nulls 512 time: [976.65 ns 983.09 ns 991.99 ns]
change: [-45.076% -39.795% -35.157%] (p = 0.00 <
0.05)
Performance has improved.
Found 9 outliers among 100 measurements (9.00%)
1 (1.00%) low mild
3 (3.00%) high mild
5 (5.00%) high severe
take i32 nulls 1024 time: [1.3249 us 1.3278 us 1.3309 us]
change: [-32.887% -31.666% -30.524%] (p = 0.00 <
0.05)
Performance has improved.
Found 8 outliers among 100 measurements (8.00%)
3 (3.00%) high mild
5 (5.00%) high severe
```
Closes #9301 from jorgecarleitao/take_fix
Authored-by: Jorge C. Leitao <[email protected]>
Signed-off-by: Andrew Lamb <[email protected]>
---
rust/arrow/src/compute/kernels/take.rs | 321 ++++++++++++++++++++++++++-------
rust/arrow/src/datatypes/native.rs | 22 +++
2 files changed, 274 insertions(+), 69 deletions(-)
diff --git a/rust/arrow/src/compute/kernels/take.rs
b/rust/arrow/src/compute/kernels/take.rs
index 8f3ed56..adae71d 100644
--- a/rust/arrow/src/compute/kernels/take.rs
+++ b/rust/arrow/src/compute/kernels/take.rs
@@ -254,6 +254,143 @@ impl Default for TakeOptions {
}
}
+/// Converts one value of an `indices` array into a `usize` position.
+///
+/// Returns `ArrowError::ComputeError` when the value has no `usize`
+/// representation (for instance a negative signed integer).
+#[inline(always)]
+fn maybe_usize<I: ArrowPrimitiveType>(index: I::Native) -> Result<usize> {
+    match index.to_usize() {
+        Some(position) => Ok(position),
+        None => Err(ArrowError::ComputeError(
+            "Cast to usize failed".to_string(),
+        )),
+    }
+}
+
+// take implementation when neither values nor indices contain nulls.
+//
+// Since every slot of `indices` is valid, every index value can be read and
+// used directly. `values[...]` is a checked slice index, so an out-of-bounds
+// index panics rather than reading past the end of `values`; an index that
+// cannot be converted to `usize` yields an error, propagated by the `?` below.
+//
+// Returns the buffer of taken values and no validity buffer (`None`).
+fn take_no_nulls<T, I>(
+ values: &[T::Native],
+ indices: &[I::Native],
+) -> Result<(Buffer, Option<Buffer>)>
+where
+ T: ArrowPrimitiveType,
+ I: ArrowNumericType,
+{
+ // The new `values` binding (an iterator of taken values) shadows the
+ // `values` slice only after this statement; the closure reads the slice.
+ let values = indices
+ .iter()
+ .map(|index| Result::Ok(values[maybe_usize::<I>(*index)?]));
+ // Soundness: a slice iterator followed by `map` is `TrustedLen` (it reports
+ // its exact length), which the `unsafe` `try_from_trusted_len_iter` relies on.
+ let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
+
+ Ok((buffer, None))
+}
+
+// take implementation when only values contain nulls.
+//
+// Every slot of `indices` is valid, so index values are read directly. The
+// output validity bitmap has one bit per output slot; it starts with every
+// bit set (valid) and a bit is cleared whenever the taken value is null.
+//
+// Returns the buffer of taken values and, if at least one taken value was
+// null, the validity buffer; otherwise `None`.
+fn take_values_nulls<T, I>(
+ values: &PrimitiveArray<T>,
+ indices: &[I::Native],
+) -> Result<(Buffer, Option<Buffer>)>
+where
+ T: ArrowPrimitiveType,
+ I: ArrowNumericType,
+ I::Native: ToPrimitive,
+{
+ let num_bytes = bit_util::ceil(indices.len(), 8);
+ let mut nulls = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
+ let null_slice = nulls.as_slice_mut();
+ let mut null_count = 0;
+
+ let values_values = values.values();
+
+ let values = indices.iter().enumerate().map(|(i, index)| {
+ let index = maybe_usize::<I>(*index)?;
+ // NOTE(review): for an out-of-bounds `index`, `is_null` runs before the
+ // checked slice index below — confirm it is bounds-checked as well.
+ if values.is_null(index) {
+ null_count += 1;
+ bit_util::unset_bit(null_slice, i);
+ }
+ Result::Ok(values_values[index])
+ });
+ // Soundness: `slice.iter().enumerate().map(..)` is `TrustedLen` (exact
+ // length), which the `unsafe` `try_from_trusted_len_iter` relies on.
+ let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
+
+ let nulls = if null_count == 0 {
+ // if only non-null values were taken
+ None
+ } else {
+ Some(nulls.into())
+ };
+
+ Ok((buffer, nulls))
+}
+
+// take implementation when only indices contain nulls.
+//
+// Values stored in null slots of `indices` are arbitrary, so they must never
+// cause an error or a panic: a null slot yields `T::Native::default()` and is
+// masked out by the validity buffer of `indices`, which is reused as-is.
+// The nullness of slot `i` is only consulted on the slow paths (index not
+// convertible to `usize`, or out of bounds of `values`), keeping the common
+// path free of bitmap lookups.
+//
+// Returns the buffer of taken values and a clone of `indices`' validity buffer.
+//
+// Panics if a non-null index is out of bounds of `values`.
+fn take_indices_nulls<T, I>(
+    values: &[T::Native],
+    indices: &PrimitiveArray<I>,
+) -> Result<(Buffer, Option<Buffer>)>
+where
+    T: ArrowPrimitiveType,
+    I: ArrowNumericType,
+    I::Native: ToPrimitive,
+{
+    let values = indices.values().iter().enumerate().map(|(i, index)| {
+        // `i` is the slot position in `indices`; `index` is the value stored
+        // there. Nullness must be checked on `i`: checking it on `index`
+        // inspects an unrelated slot, or one past the end of `indices`.
+        let index = match maybe_usize::<I>(*index) {
+            Ok(index) => index,
+            // a null slot may hold any bits, including unconvertible ones
+            Err(_) if indices.is_null(i) => return Ok(T::Native::default()),
+            Err(e) => return Err(e),
+        };
+        Result::Ok(match values.get(index) {
+            Some(value) => *value,
+            None => {
+                if indices.is_null(i) {
+                    T::Native::default()
+                } else {
+                    panic!("Out-of-bounds index {}", index)
+                }
+            }
+        })
+    });
+
+    // Soundness: `slice.iter().enumerate().map(..)` is `TrustedLen`; every
+    // element (including early `return`s from the closure) yields one item.
+    let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
+
+    Ok((buffer, indices.data_ref().null_buffer().cloned()))
+}
+
+// take implementation when both values and indices contain nulls.
+//
+// A null slot in `indices` is never read as an index: it produces a null
+// output slot holding `T::Native::default()`. A valid slot whose taken value
+// is null also produces a null output slot. The output validity bitmap starts
+// with every bit set (valid) and the bit of each null output slot is cleared.
+//
+// Returns the buffer of taken values and, if any output slot is null, the
+// validity buffer; otherwise `None`.
+fn take_values_indices_nulls<T, I>(
+ values: &PrimitiveArray<T>,
+ indices: &PrimitiveArray<I>,
+) -> Result<(Buffer, Option<Buffer>)>
+where
+ T: ArrowPrimitiveType,
+ I: ArrowNumericType,
+ I::Native: ToPrimitive,
+{
+ let num_bytes = bit_util::ceil(indices.len(), 8);
+ let mut nulls = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
+ let null_slice = nulls.as_slice_mut();
+ let mut null_count = 0;
+
+ let values_values = values.values();
+ // `indices.iter()` yields `None` for null slots, so their arbitrary
+ // underlying values are never converted or used.
+ let values = indices.iter().enumerate().map(|(i, index)| match index {
+ Some(index) => {
+ let index = maybe_usize::<I>(index)?;
+ // NOTE(review): for an out-of-bounds `index`, `is_null` runs before the
+ // checked slice index below — confirm it is bounds-checked as well.
+ if values.is_null(index) {
+ null_count += 1;
+ bit_util::unset_bit(null_slice, i);
+ }
+ Result::Ok(values_values[index])
+ }
+ None => {
+ null_count += 1;
+ bit_util::unset_bit(null_slice, i);
+ Ok(T::Native::default())
+ }
+ });
+ // Soundness: this iterator is built on `indices.iter()` (the array's own
+ // iterator), not on a slice. `try_from_trusted_len_iter` relies on it
+ // reporting an exact length — TODO confirm the array iterator is `TrustedLen`.
+ let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
+
+ let nulls = if null_count == 0 {
+ // if only non-null values were taken
+ None
+ } else {
+ Some(nulls.into())
+ };
+
+ Ok((buffer, nulls))
+}
+
/// `take` implementation for all primitive arrays
///
/// This checks if an `indices` slot is populated, and gets the value from
`values`
@@ -269,56 +406,36 @@ fn take_primitive<T, I>(
) -> Result<PrimitiveArray<T>>
where
T: ArrowPrimitiveType,
- T::Native: num::Num,
I: ArrowNumericType,
I::Native: ToPrimitive,
{
- let data_len = indices.len();
-
- let mut buffer =
- MutableBuffer::from_len_zeroed(data_len *
std::mem::size_of::<T::Native>());
- let data = buffer.typed_data_mut();
-
- let nulls;
-
- if values.null_count() == 0 {
- // Take indices without null checking
- for (i, elem) in data.iter_mut().enumerate() {
- let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(||
{
- ArrowError::ComputeError("Cast to usize failed".to_string())
- })?;
-
- *elem = values.value(index);
+ let indices_has_nulls = indices.null_count() > 0;
+ let values_has_nulls = values.null_count() > 0;
+ // note: this function should only panic when "an index is not null and
out of bounds".
+ // if the index is null, its value is undefined and therefore we should
not read from it.
+
+ let (buffer, nulls) = match (values_has_nulls, indices_has_nulls) {
+ (false, false) => {
+ // * no nulls
+ // * all `indices.values()` are valid
+ take_no_nulls::<T, I>(values.values(), indices.values())?
}
- nulls = indices.data_ref().null_buffer().cloned();
- } else {
- let num_bytes = bit_util::ceil(data_len, 8);
- let mut null_buf =
MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
-
- let null_slice = null_buf.as_slice_mut();
-
- for (i, elem) in data.iter_mut().enumerate() {
- let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(||
{
- ArrowError::ComputeError("Cast to usize failed".to_string())
- })?;
-
- if values.is_null(index) {
- bit_util::unset_bit(null_slice, i);
- }
-
- *elem = values.value(index);
+ (true, false) => {
+ // * nulls come from `values` alone
+ // * all `indices.values()` are valid
+ take_values_nulls::<T, I>(values, indices.values())?
}
- nulls = match indices.data_ref().null_buffer() {
- Some(buffer) => Some(buffer_bin_and(
- buffer,
- 0,
- &null_buf.into(),
- 0,
- indices.len(),
- )),
- None => Some(null_buf.into()),
- };
- }
+ (false, true) => {
+ // in this branch it is unsound to read and use `index.values()`,
+ // as doing so is UB when they come from a null slot.
+ take_indices_nulls::<T, I>(values.values(), indices)?
+ }
+ (true, true) => {
+ // in this branch it is unsound to read and use `index.values()`,
+ // as doing so is UB when they come from a null slot.
+ take_values_indices_nulls::<T, I>(values, indices)?
+ }
+ };
let data = ArrayData::new(
T::DATA_TYPE,
@@ -326,7 +443,7 @@ where
None,
nulls,
0,
- vec![buffer.into()],
+ vec![buffer],
vec![],
);
Ok(PrimitiveArray::<T>::from(Arc::new(data)))
@@ -663,14 +780,16 @@ mod tests {
index: &UInt32Array,
options: Option<TakeOptions>,
expected_data: Vec<Option<T::Native>>,
- ) where
+ ) -> Result<()>
+ where
T: ArrowPrimitiveType,
PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
{
let output = PrimitiveArray::<T>::from(data);
let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as
ArrayRef;
- let output = take(&output, index, options).unwrap();
- assert_eq!(&output, &expected)
+ let output = take(&output, index, options)?;
+ assert_eq!(&output, &expected);
+ Ok(())
}
fn test_take_impl_primitive_arrays<T, I>(
@@ -707,6 +826,42 @@ mod tests {
}
#[test]
+    fn test_take_primitive_non_null_indices() {
+        // `values` has nulls while `indices` has none: the values-only-nulls path.
+        let indices = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
+        let data = vec![None, Some(3), Some(5), Some(2), Some(3), None];
+        let expected = vec![None, None, Some(2), Some(3), Some(3), Some(5)];
+        let result = test_take_primitive_arrays::<Int8Type>(data, &indices, None, expected);
+        result.unwrap();
+    }
+
+ #[test]
+    fn test_take_primitive_non_null_values() {
+        // `indices` has a null slot while `values` has none: the indices-only-nulls path.
+        let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
+        let data = vec![Some(0), Some(1), Some(2), Some(3), Some(4)];
+        let expected = vec![Some(3), None, Some(1), Some(3), Some(2)];
+        let result = test_take_primitive_arrays::<Int8Type>(data, &indices, None, expected);
+        result.unwrap();
+    }
+
+ #[test]
+    fn test_take_primitive_non_null() {
+        // neither `values` nor `indices` has nulls: the no-nulls path.
+        let indices = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
+        let data = vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)];
+        let expected = vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)];
+        let result = test_take_primitive_arrays::<Int8Type>(data, &indices, None, expected);
+        result.unwrap();
+    }
+
+ #[test]
fn test_take_primitive() {
let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3),
Some(2)]);
@@ -716,7 +871,8 @@ mod tests {
&index,
None,
vec![Some(3), None, None, Some(3), Some(2)],
- );
+ )
+ .unwrap();
// int16
test_take_primitive_arrays::<Int16Type>(
@@ -724,7 +880,8 @@ mod tests {
&index,
None,
vec![Some(3), None, None, Some(3), Some(2)],
- );
+ )
+ .unwrap();
// int32
test_take_primitive_arrays::<Int32Type>(
@@ -732,7 +889,8 @@ mod tests {
&index,
None,
vec![Some(3), None, None, Some(3), Some(2)],
- );
+ )
+ .unwrap();
// int64
test_take_primitive_arrays::<Int64Type>(
@@ -740,7 +898,8 @@ mod tests {
&index,
None,
vec![Some(3), None, None, Some(3), Some(2)],
- );
+ )
+ .unwrap();
// uint8
test_take_primitive_arrays::<UInt8Type>(
@@ -748,7 +907,8 @@ mod tests {
&index,
None,
vec![Some(3), None, None, Some(3), Some(2)],
- );
+ )
+ .unwrap();
// uint16
test_take_primitive_arrays::<UInt16Type>(
@@ -756,7 +916,8 @@ mod tests {
&index,
None,
vec![Some(3), None, None, Some(3), Some(2)],
- );
+ )
+ .unwrap();
// uint32
test_take_primitive_arrays::<UInt32Type>(
@@ -764,7 +925,8 @@ mod tests {
&index,
None,
vec![Some(3), None, None, Some(3), Some(2)],
- );
+ )
+ .unwrap();
// int64
test_take_primitive_arrays::<Int64Type>(
@@ -772,7 +934,8 @@ mod tests {
&index,
None,
vec![Some(-15), None, None, Some(-15), Some(2)],
- );
+ )
+ .unwrap();
// interval_year_month
test_take_primitive_arrays::<IntervalYearMonthType>(
@@ -780,7 +943,8 @@ mod tests {
&index,
None,
vec![Some(-15), None, None, Some(-15), Some(2)],
- );
+ )
+ .unwrap();
// interval_day_time
test_take_primitive_arrays::<IntervalDayTimeType>(
@@ -788,7 +952,8 @@ mod tests {
&index,
None,
vec![Some(-15), None, None, Some(-15), Some(2)],
- );
+ )
+ .unwrap();
// duration_second
test_take_primitive_arrays::<DurationSecondType>(
@@ -796,7 +961,8 @@ mod tests {
&index,
None,
vec![Some(-15), None, None, Some(-15), Some(2)],
- );
+ )
+ .unwrap();
// duration_millisecond
test_take_primitive_arrays::<DurationMillisecondType>(
@@ -804,7 +970,8 @@ mod tests {
&index,
None,
vec![Some(-15), None, None, Some(-15), Some(2)],
- );
+ )
+ .unwrap();
// duration_microsecond
test_take_primitive_arrays::<DurationMicrosecondType>(
@@ -812,7 +979,8 @@ mod tests {
&index,
None,
vec![Some(-15), None, None, Some(-15), Some(2)],
- );
+ )
+ .unwrap();
// duration_nanosecond
test_take_primitive_arrays::<DurationNanosecondType>(
@@ -820,7 +988,8 @@ mod tests {
&index,
None,
vec![Some(-15), None, None, Some(-15), Some(2)],
- );
+ )
+ .unwrap();
// float32
test_take_primitive_arrays::<Float32Type>(
@@ -828,7 +997,8 @@ mod tests {
&index,
None,
vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
- );
+ )
+ .unwrap();
// float64
test_take_primitive_arrays::<Float64Type>(
@@ -836,7 +1006,8 @@ mod tests {
&index,
None,
vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
- );
+ )
+ .unwrap();
}
#[test]
@@ -1350,20 +1521,32 @@ mod tests {
}
#[test]
- #[should_panic(
- expected = "Array index out of bounds, cannot get item at index 6 from
5 entries"
- )]
fn test_take_out_of_bounds() {
let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3),
Some(6)]);
let take_opt = TakeOptions { check_bounds: true };
// int64
- test_take_primitive_arrays::<Int64Type>(
+ // With `check_bounds: true`, the out-of-bounds index 6 (there are only 5
+ // values) must be reported as an error instead of causing a panic. The
+ // expected data `vec![None]` is never compared: `take` fails first.
+ let result = test_take_primitive_arrays::<Int64Type>(
vec![Some(0), None, Some(2), Some(3), None],
&index,
Some(take_opt),
vec![None],
);
+ assert!(result.is_err());
+ }
+
+ #[test]
+ #[should_panic(expected = "index out of bounds: the len is 4 but the index
is 1000")]
+ fn test_take_out_of_bounds_panic() {
+ // With `options: None` (bounds are not pre-checked by default), the
+ // non-null index 1000 into a 4-element array must panic with Rust's
+ // slice-index message instead of reading out of bounds.
+ let index = UInt32Array::from(vec![Some(1000)]);
+
+ test_take_primitive_arrays::<Int64Type>(
+ vec![Some(0), Some(1), Some(2), Some(3)],
+ &index,
+ None,
+ vec![None],
+ )
+ .unwrap();
}
#[test]
diff --git a/rust/arrow/src/datatypes/native.rs
b/rust/arrow/src/datatypes/native.rs
index bfca235..fb1bad4 100644
--- a/rust/arrow/src/datatypes/native.rs
+++ b/rust/arrow/src/datatypes/native.rs
@@ -39,21 +39,25 @@ pub trait ArrowNativeType:
+ JsonSerializable
{
/// Convert native type from usize.
+ #[inline]
fn from_usize(_: usize) -> Option<Self> {
None
}
/// Convert native type to usize.
+ #[inline]
fn to_usize(&self) -> Option<usize> {
None
}
/// Convert native type from i32.
+ #[inline]
fn from_i32(_: i32) -> Option<Self> {
None
}
/// Convert native type from i64.
+ #[inline]
fn from_i64(_: i64) -> Option<Self> {
None
}
@@ -94,10 +98,12 @@ impl JsonSerializable for i8 {
}
impl ArrowNativeType for i8 {
+ #[inline]
fn from_usize(v: usize) -> Option<Self> {
num::FromPrimitive::from_usize(v)
}
+ #[inline]
fn to_usize(&self) -> Option<usize> {
num::ToPrimitive::to_usize(self)
}
@@ -110,10 +116,12 @@ impl JsonSerializable for i16 {
}
impl ArrowNativeType for i16 {
+ #[inline]
fn from_usize(v: usize) -> Option<Self> {
num::FromPrimitive::from_usize(v)
}
+ #[inline]
fn to_usize(&self) -> Option<usize> {
num::ToPrimitive::to_usize(self)
}
@@ -126,15 +134,18 @@ impl JsonSerializable for i32 {
}
impl ArrowNativeType for i32 {
+ #[inline]
fn from_usize(v: usize) -> Option<Self> {
num::FromPrimitive::from_usize(v)
}
+ #[inline]
fn to_usize(&self) -> Option<usize> {
num::ToPrimitive::to_usize(self)
}
/// Convert native type from i32.
+ #[inline]
fn from_i32(val: i32) -> Option<Self> {
Some(val)
}
@@ -147,15 +158,18 @@ impl JsonSerializable for i64 {
}
impl ArrowNativeType for i64 {
+ #[inline]
fn from_usize(v: usize) -> Option<Self> {
num::FromPrimitive::from_usize(v)
}
+ #[inline]
fn to_usize(&self) -> Option<usize> {
num::ToPrimitive::to_usize(self)
}
/// Convert native type from i64.
+ #[inline]
fn from_i64(val: i64) -> Option<Self> {
Some(val)
}
@@ -168,10 +182,12 @@ impl JsonSerializable for u8 {
}
impl ArrowNativeType for u8 {
+ #[inline]
fn from_usize(v: usize) -> Option<Self> {
num::FromPrimitive::from_usize(v)
}
+ #[inline]
fn to_usize(&self) -> Option<usize> {
num::ToPrimitive::to_usize(self)
}
@@ -184,10 +200,12 @@ impl JsonSerializable for u16 {
}
impl ArrowNativeType for u16 {
+ #[inline]
fn from_usize(v: usize) -> Option<Self> {
num::FromPrimitive::from_usize(v)
}
+ #[inline]
fn to_usize(&self) -> Option<usize> {
num::ToPrimitive::to_usize(self)
}
@@ -200,10 +218,12 @@ impl JsonSerializable for u32 {
}
impl ArrowNativeType for u32 {
+ #[inline]
fn from_usize(v: usize) -> Option<Self> {
num::FromPrimitive::from_usize(v)
}
+ #[inline]
fn to_usize(&self) -> Option<usize> {
num::ToPrimitive::to_usize(self)
}
@@ -216,10 +236,12 @@ impl JsonSerializable for u64 {
}
impl ArrowNativeType for u64 {
+ #[inline]
fn from_usize(v: usize) -> Option<Self> {
num::FromPrimitive::from_usize(v)
}
+ #[inline]
fn to_usize(&self) -> Option<usize> {
num::ToPrimitive::to_usize(self)
}