This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new f1fb2b1 fix take kernel null handling on structs (#531)
f1fb2b1 is described below
commit f1fb2b11bbd6350365de010d3e1d676a27602d3a
Author: Ben Chambers <[email protected]>
AuthorDate: Fri Jul 9 11:45:56 2021 -0700
fix take kernel null handling on structs (#531)
This closes #530.
Co-authored-by: Ben Chambers <[email protected]>
---
arrow/src/compute/kernels/take.rs | 151 ++++++++++++++++++++++----------------
1 file changed, 88 insertions(+), 63 deletions(-)
diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs
index bf9d1df..f36e29e 100644
--- a/arrow/src/compute/kernels/take.rs
+++ b/arrow/src/compute/kernels/take.rs
@@ -231,9 +231,22 @@ where
.map(|a| take_impl(a.as_ref(), indices, Some(options.clone())))
.collect();
let arrays = arrays?;
- let pairs: Vec<(Field, ArrayRef)> =
+ let fields: Vec<(Field, ArrayRef)> =
fields.clone().into_iter().zip(arrays).collect();
- Ok(Arc::new(StructArray::from(pairs)) as ArrayRef)
+
+ // Create the null bit buffer.
+ let is_valid: Buffer = indices
+ .iter()
+ .map(|index| {
+ if let Some(index) = index {
+ struct_.is_valid(ArrowNativeType::to_usize(&index).unwrap())
+ } else {
+ false
+ }
+ })
+ .collect();
+
+ Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
}
DataType::Dictionary(key_type, _) => match key_type.as_ref() {
DataType::Int8 => downcast_dict_take!(Int8Type, values, indices),
@@ -848,20 +861,34 @@ mod tests {
}
// create a simple struct for testing purposes
- fn create_test_struct() -> StructArray {
- let boolean_data = BooleanArray::from(vec![true, false, false, true])
- .data()
- .clone();
- let int_data = Int32Array::from(vec![42, 28, 19, 31]).data().clone();
- let mut field_types = vec![];
- field_types.push(Field::new("a", DataType::Boolean, true));
- field_types.push(Field::new("b", DataType::Int32, true));
- let struct_array_data = ArrayData::builder(DataType::Struct(field_types))
- .len(4)
- .add_child_data(boolean_data)
- .add_child_data(int_data)
- .build();
- StructArray::from(struct_array_data)
+ fn create_test_struct(
+ values: Vec<Option<(Option<bool>, Option<i32>)>>,
+ ) -> StructArray {
+ let mut struct_builder = StructBuilder::new(
+ vec![
+ Field::new("a", DataType::Boolean, true),
+ Field::new("b", DataType::Int32, true),
+ ],
+ vec![
+ Box::new(BooleanBuilder::new(values.len())),
+ Box::new(Int32Builder::new(values.len())),
+ ],
+ );
+
+ for value in values {
+ struct_builder
+ .field_builder::<BooleanBuilder>(0)
+ .unwrap()
+ .append_option(value.and_then(|v| v.0))
+ .unwrap();
+ struct_builder
+ .field_builder::<Int32Builder>(1)
+ .unwrap()
+ .append_option(value.and_then(|v| v.1))
+ .unwrap();
+ struct_builder.append(value.is_some()).unwrap();
+ }
+ struct_builder.finish()
}
#[test]
@@ -1576,61 +1603,59 @@ mod tests {
#[test]
fn test_take_struct() {
- let array = create_test_struct();
-
- let index = UInt32Array::from(vec![0, 3, 1, 0, 2]);
- let a = take(&array, &index, None).unwrap();
- let a: &StructArray = a.as_any().downcast_ref::<StructArray>().unwrap();
- assert_eq!(index.len(), a.len());
- assert_eq!(0, a.null_count());
+ let array = create_test_struct(vec![
+ Some((Some(true), Some(42))),
+ Some((Some(false), Some(28))),
+ Some((Some(false), Some(19))),
+ Some((Some(true), Some(31))),
+ None,
+ ]);
- let expected_bool_data = BooleanArray::from(vec![true, true, false, true, false])
- .data()
- .clone();
- let expected_int_data = Int32Array::from(vec![42, 31, 28, 42, 19]).data().clone();
- let mut field_types = vec![];
- field_types.push(Field::new("a", DataType::Boolean, true));
- field_types.push(Field::new("b", DataType::Int32, true));
- let struct_array_data = ArrayData::builder(DataType::Struct(field_types))
- .len(5)
- .add_child_data(expected_bool_data)
- .add_child_data(expected_int_data)
- .build();
- let struct_array = StructArray::from(struct_array_data);
+ let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
+ let actual = take(&array, &index, None).unwrap();
+ let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
+ assert_eq!(index.len(), actual.len());
+ assert_eq!(1, actual.null_count());
+
+ let expected = create_test_struct(vec![
+ Some((Some(true), Some(42))),
+ Some((Some(true), Some(31))),
+ Some((Some(false), Some(28))),
+ Some((Some(true), Some(42))),
+ Some((Some(false), Some(19))),
+ None,
+ ]);
- assert_eq!(a, &struct_array);
+ assert_eq!(&expected, actual);
}
#[test]
- fn test_take_struct_with_nulls() {
- let array = create_test_struct();
+ fn test_take_struct_with_null_indices() {
+ let array = create_test_struct(vec![
+ Some((Some(true), Some(42))),
+ Some((Some(false), Some(28))),
+ Some((Some(false), Some(19))),
+ Some((Some(true), Some(31))),
+ None,
+ ]);
- let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0)]);
- let a = take(&array, &index, None).unwrap();
- let a: &StructArray = a.as_any().downcast_ref::<StructArray>().unwrap();
- assert_eq!(index.len(), a.len());
- assert_eq!(0, a.null_count());
+ let index =
+ UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
+ let actual = take(&array, &index, None).unwrap();
+ let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
+ assert_eq!(index.len(), actual.len());
+ assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array
- let expected_bool_data =
- BooleanArray::from(vec![None, Some(true), Some(false), None, Some(true)])
- .data()
- .clone();
- let expected_int_data =
- Int32Array::from(vec![None, Some(31), Some(28), None, Some(42)])
- .data()
- .clone();
+ let expected = create_test_struct(vec![
+ None,
+ Some((Some(true), Some(31))),
+ Some((Some(false), Some(28))),
+ None,
+ Some((Some(true), Some(42))),
+ None,
+ ]);
- let mut field_types = vec![];
- field_types.push(Field::new("a", DataType::Boolean, true));
- field_types.push(Field::new("b", DataType::Int32, true));
- let struct_array_data = ArrayData::builder(DataType::Struct(field_types))
- .len(5)
- // TODO: see https://issues.apache.org/jira/browse/ARROW-5408 for why count != 2
- .add_child_data(expected_bool_data)
- .add_child_data(expected_int_data)
- .build();
- let struct_array = StructArray::from(struct_array_data);
- assert_eq!(a, &struct_array);
+ assert_eq!(&expected, actual);
}
#[test]