This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 8e476aee95a feat: implement take for dense union array (#5873)
8e476aee95a is described below
commit 8e476aee95affa20122bc72fc7a8b701763a26ad
Author: gstvg <[email protected]>
AuthorDate: Thu Jun 13 12:44:03 2024 -0300
feat: implement take for dense union array (#5873)
---
arrow-select/src/filter.rs | 5 +-
arrow-select/src/take.rs | 128 ++++++++++++++++++++++++++++++++++++++++++++-
2 files changed, 131 insertions(+), 2 deletions(-)
diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs
index 8e06b07f5ef..65ccbe1e01a 100644
--- a/arrow-select/src/filter.rs
+++ b/arrow-select/src/filter.rs
@@ -552,7 +552,10 @@ fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate)
}
/// `filter` implementation for primitive arrays
-fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate)
-> PrimitiveArray<T>
+pub(crate) fn filter_primitive<T>(
+ array: &PrimitiveArray<T>,
+ predicate: &FilterPredicate,
+) -> PrimitiveArray<T>
where
T: ArrowPrimitiveType,
{
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index b8d59142db7..d6892eb0a9e 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -31,6 +31,8 @@ use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
use num::{One, Zero};
+use crate::filter::{filter_primitive, FilterBuilder};
+
/// Take elements by index from [Array], creating a new [Array] from those indexes.
///
/// ```text
@@ -240,6 +242,44 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
let array = UnionArray::try_new(fields.clone(), type_ids, None,
children)?;
Ok(Arc::new(array))
}
+ DataType::Union(fields, UnionMode::Dense) => {
+ let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
+
+ let type_ids =
<PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
+ let offsets =
<PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(),
indices), None);
+
+ let children = fields.iter()
+ .map(|(field_type_id, _)| {
+ let mask = BooleanArray::from_unary(&type_ids,
|value_type_id| value_type_id == field_type_id);
+ let predicate = FilterBuilder::new(&mask).build();
+
+ let indices = filter_primitive(&offsets, &predicate);
+
+ let values = values.child(field_type_id);
+
+ take_impl(values, &indices)
+ })
+ .collect::<Result<_, _>>()?;
+
+ let mut child_offsets = [0; 128];
+
+ let offsets = type_ids.values()
+ .iter()
+ .map(|&i| {
+ let offset = child_offsets[i as usize];
+
+ child_offsets[i as usize] += 1;
+
+ offset
+ })
+ .collect();
+
+ let (_, type_ids, _) = type_ids.into_parts();
+
+ let array = UnionArray::try_new(fields.clone(), type_ids,
Some(offsets), children)?;
+
+ Ok(Arc::new(array))
+ }
t => unimplemented!("Take not supported for data type {:?}", t)
}
}
@@ -2146,7 +2186,7 @@ mod tests {
}
#[test]
- fn test_take_union() {
+ fn test_take_union_sparse() {
let structs = create_test_struct(vec![
Some((Some(true), Some(42))),
Some((Some(false), Some(28))),
@@ -2183,4 +2223,90 @@ mod tests {
let expected = vec![Some("a"), None, None, Some("a"), Some("c"),
Some("d")];
assert_eq!(expected, actual);
}
+
+ #[test]
+ fn test_take_union_dense() {
+ let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
+ let offsets = vec![0, 0, 1, 1, 2, 2, 3];
+ let ints = vec![10, 20, 30, 40];
+ let strings = vec![Some("a"), None, Some("c"), Some("d")];
+
+ let indices = vec![0, 3, 1, 0, 2, 4];
+
+ let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
+ let taken_offsets = vec![0, 1, 0, 2, 1, 3];
+ let taken_ints = vec![10, 20, 10, 30];
+ let taken_strings = vec![Some("a"), None];
+
+ let type_ids = <ScalarBuffer<i8>>::from(type_ids);
+ let offsets = <ScalarBuffer<i32>>::from(offsets);
+ let ints = UInt32Array::from(ints);
+ let strings = StringArray::from(strings);
+
+ let union_fields = [
+ (
+ 0,
+ Arc::new(Field::new("f1", ints.data_type().clone(), true)),
+ ),
+ (
+ 1,
+ Arc::new(Field::new("f2", strings.data_type().clone(), true)),
+ ),
+ ]
+ .into_iter()
+ .collect();
+
+ let array = UnionArray::try_new(
+ union_fields,
+ type_ids,
+ Some(offsets),
+ vec![Arc::new(ints), Arc::new(strings)],
+ )
+ .unwrap();
+
+ let index = UInt32Array::from(indices);
+
+ let actual = take(&array, &index, None).unwrap();
+ let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
+
+ assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
+ assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
+ assert_eq!(
+ UInt32Array::from(actual.child(0).to_data()),
+ UInt32Array::from(taken_ints)
+ );
+ assert_eq!(
+ StringArray::from(actual.child(1).to_data()),
+ StringArray::from(taken_strings)
+ );
+ }
+
+ #[test]
+ fn test_take_union_dense_using_builder() {
+ let mut builder = UnionBuilder::new_dense();
+
+ builder.append::<Int32Type>("a", 1).unwrap();
+ builder.append::<Float64Type>("b", 3.0).unwrap();
+ builder.append::<Int32Type>("a", 4).unwrap();
+ builder.append::<Int32Type>("a", 5).unwrap();
+ builder.append::<Float64Type>("b", 2.0).unwrap();
+
+ let union = builder.build().unwrap();
+
+ let indices = UInt32Array::from(vec![2, 0, 1, 2]);
+
+ let mut builder = UnionBuilder::new_dense();
+
+ builder.append::<Int32Type>("a", 4).unwrap();
+ builder.append::<Int32Type>("a", 1).unwrap();
+ builder.append::<Float64Type>("b", 3.0).unwrap();
+ builder.append::<Int32Type>("a", 4).unwrap();
+
+ let taken = builder.build().unwrap();
+
+ assert_eq!(
+ taken.to_data(),
+ take(&union, &indices, None).unwrap().to_data()
+ );
+ }
}