This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 18385e5634 Fix take_bytes Null and Overflow Handling (#4576) (#4579)
18385e5634 is described below
commit 18385e56343c64bbbc76f271c5fbb4f27b5e7e8d
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Sat Jul 29 15:24:53 2023 +0100
Fix take_bytes Null and Overflow Handling (#4576) (#4579)
* Cleanup take_bytes
* Use extend
* Tweak
* Review feedback
---
arrow-select/src/take.rs | 85 +++++++++++++++++++++---------------------------
1 file changed, 37 insertions(+), 48 deletions(-)
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index 0f5689ff99..cee9cbaf84 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -331,94 +331,70 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
let data_len = indices.len();
let bytes_offset = (data_len + 1) * std::mem::size_of::<T::Offset>();
- let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset);
+ let mut offsets = MutableBuffer::new(bytes_offset);
+ offsets.push(T::Offset::default());
- let offsets = offsets_buffer.typed_data_mut();
let mut values = MutableBuffer::new(0);
- let mut length_so_far = T::Offset::from_usize(0).unwrap();
- offsets[0] = length_so_far;
let nulls;
if array.null_count() == 0 && indices.null_count() == 0 {
- for (i, offset) in offsets.iter_mut().skip(1).enumerate() {
- let index = indices.value(i).to_usize().ok_or_else(|| {
- ArrowError::ComputeError("Cast to usize failed".to_string())
- })?;
-
- let s = array.value(index);
-
- let s: &[u8] = s.as_ref();
- length_so_far += T::Offset::from_usize(s.len()).unwrap();
+ offsets.extend(indices.values().iter().map(|index| {
+ let s: &[u8] = array.value(index.as_usize()).as_ref();
values.extend_from_slice(s);
- *offset = length_so_far;
- }
+ T::Offset::usize_as(values.len())
+ }));
nulls = None
} else if indices.null_count() == 0 {
let num_bytes = bit_util::ceil(data_len, 8);
let mut null_buf =
MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
let null_slice = null_buf.as_slice_mut();
-
- for (i, offset) in offsets.iter_mut().skip(1).enumerate() {
- let index = indices.value(i).to_usize().ok_or_else(|| {
- ArrowError::ComputeError("Cast to usize failed".to_string())
- })?;
-
+ offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
+ let index = index.as_usize();
if array.is_valid(index) {
let s: &[u8] = array.value(index).as_ref();
-
- length_so_far += T::Offset::from_usize(s.len()).unwrap();
values.extend_from_slice(s.as_ref());
} else {
bit_util::unset_bit(null_slice, i);
}
- *offset = length_so_far;
- }
+ T::Offset::usize_as(values.len())
+ }));
nulls = Some(null_buf.into());
} else if array.null_count() == 0 {
- for (i, offset) in offsets.iter_mut().skip(1).enumerate() {
+ offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
if indices.is_valid(i) {
- let index = indices.value(i).to_usize().ok_or_else(|| {
- ArrowError::ComputeError("Cast to usize failed".to_string())
- })?;
-
- let s: &[u8] = array.value(index).as_ref();
-
- length_so_far += T::Offset::from_usize(s.len()).unwrap();
+ let s: &[u8] = array.value(index.as_usize()).as_ref();
values.extend_from_slice(s);
}
- *offset = length_so_far;
- }
+ T::Offset::usize_as(values.len())
+ }));
nulls = indices.nulls().map(|b| b.inner().sliced());
} else {
let num_bytes = bit_util::ceil(data_len, 8);
let mut null_buf =
MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
let null_slice = null_buf.as_slice_mut();
-
- for (i, offset) in offsets.iter_mut().skip(1).enumerate() {
- let index = indices.value(i).to_usize().ok_or_else(|| {
- ArrowError::ComputeError("Cast to usize failed".to_string())
- })?;
-
- if array.is_valid(index) && indices.is_valid(i) {
+ offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
+ // check index is valid before using index. The value in
+ // NULL index slots may not be within bounds of array
+ let index = index.as_usize();
+ if indices.is_valid(i) && array.is_valid(index) {
let s: &[u8] = array.value(index).as_ref();
-
- length_so_far += T::Offset::from_usize(s.len()).unwrap();
values.extend_from_slice(s);
} else {
// set null bit
bit_util::unset_bit(null_slice, i);
}
- *offset = length_so_far;
- }
-
+ T::Offset::usize_as(values.len())
+ }));
nulls = Some(null_buf.into())
}
+ T::Offset::from_usize(values.len()).expect("offset overflow");
+
let array_data = ArrayData::builder(T::DATA_TYPE)
.len(data_len)
- .add_buffer(offsets_buffer.into())
+ .add_buffer(offsets.into())
.add_buffer(values.into())
.null_bit_buffer(nulls);
@@ -1937,6 +1913,7 @@ mod tests {
#[test]
fn test_take_null_indices() {
+ // Build indices with values that are out of bounds, but masked by null mask
let indices = Int32Array::new(
vec![1, 2, 400, 400].into(),
Some(NullBuffer::from(vec![true, true, false, false])),
@@ -1949,4 +1926,16 @@ mod tests {
.collect::<Vec<_>>();
assert_eq!(&values, &[Some(23), Some(4), None, None])
}
+
+ #[test]
+ fn test_take_bytes_null_indices() {
+ let indices = Int32Array::new(
+ vec![0, 1, 400, 400].into(),
+ Some(NullBuffer::from_iter(vec![true, true, false, false])),
+ );
+ let values = StringArray::from(vec![Some("foo"), None]);
+ let r = take(&values, &indices, None).unwrap();
+ let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
+ assert_eq!(&values, &[Some("foo"), None, None, None])
+ }
}