This is an automated email from the ASF dual-hosted git repository.
avantgardner pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 4320a753be Implement Take for UnionArray (#4883)
4320a753be is described below
commit 4320a753beaee0a1a6870c59ef46b59e88c9c323
Author: Brent Gardner <[email protected]>
AuthorDate: Mon Oct 2 09:21:51 2023 -0600
Implement Take for UnionArray (#4883)
Implement Take for UnionArray (#4883)
---
arrow-select/src/take.rs | 54 +++++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 53 insertions(+), 1 deletion(-)
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index 70b80e5878..a546949f86 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -28,7 +28,7 @@ use arrow_buffer::{
ScalarBuffer,
};
use arrow_data::{ArrayData, ArrayDataBuilder};
-use arrow_schema::{ArrowError, DataType, FieldRef};
+use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
use num::{One, Zero};
@@ -223,6 +223,21 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
Ok(new_null_array(&DataType::Null, indices.len()))
}
}
+ DataType::Union(fields, UnionMode::Sparse) => {
+ let mut field_type_ids = Vec::with_capacity(fields.len());
+ let mut children = Vec::with_capacity(fields.len());
+ let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
+ let type_ids = take_native(values.type_ids(), indices).into_inner();
+ for (type_id, field) in fields.iter() {
+ let values = values.child(type_id);
+ let values = take_impl(values, indices)?;
+ let field = (**field).clone();
+ children.push((field, values));
+ field_type_ids.push(type_id);
+ }
+ let array = UnionArray::try_new(field_type_ids.as_slice(), type_ids, None, children)?;
+ Ok(Arc::new(array))
+ }
t => unimplemented!("Take not supported for data type {:?}", t)
}
}
@@ -2013,4 +2028,41 @@ mod tests {
let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
assert_eq!(&values, &[Some("foo"), None, None, None])
}
+
+ #[test]
+ fn test_take_union() {
+ let structs = create_test_struct(vec![
+ Some((Some(true), Some(42))),
+ Some((Some(false), Some(28))),
+ Some((Some(false), Some(19))),
+ Some((Some(true), Some(31))),
+ None,
+ ]);
+ let strings =
+ StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
+ let type_ids = Buffer::from_slice_ref(vec![1i8; 5]);
+
+ let children: Vec<(Field, Arc<dyn Array>)> = vec![
+ (
+ Field::new("f1", structs.data_type().clone(), true),
+ Arc::new(structs),
+ ),
+ (
+ Field::new("f2", strings.data_type().clone(), true),
+ Arc::new(strings),
+ ),
+ ];
+ let array = UnionArray::try_new(&[0, 1], type_ids, None, children).unwrap();
+
+ let indices = vec![0, 3, 1, 0, 2, 4];
+ let index = UInt32Array::from(indices.clone());
+ let actual = take(&array, &index, None).unwrap();
+ let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
+ let strings = actual.child(1);
+ let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
+
+ let actual = strings.iter().collect::<Vec<_>>();
+ let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
+ assert_eq!(expected, actual);
+ }
}