This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 8e476aee95a feat: implement take for dense union array (#5873)
8e476aee95a is described below

commit 8e476aee95affa20122bc72fc7a8b701763a26ad
Author: gstvg <[email protected]>
AuthorDate: Thu Jun 13 12:44:03 2024 -0300

    feat: implement take for dense union array (#5873)
---
 arrow-select/src/filter.rs |   5 +-
 arrow-select/src/take.rs   | 128 ++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 131 insertions(+), 2 deletions(-)

diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs
index 8e06b07f5ef..65ccbe1e01a 100644
--- a/arrow-select/src/filter.rs
+++ b/arrow-select/src/filter.rs
@@ -552,7 +552,10 @@ fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate)
 }
 
 /// `filter` implementation for primitive arrays
-fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
+pub(crate) fn filter_primitive<T>(
+    array: &PrimitiveArray<T>,
+    predicate: &FilterPredicate,
+) -> PrimitiveArray<T>
 where
     T: ArrowPrimitiveType,
 {
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index b8d59142db7..d6892eb0a9e 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -31,6 +31,8 @@ use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
 
 use num::{One, Zero};
 
+use crate::filter::{filter_primitive, FilterBuilder};
+
 /// Take elements by index from [Array], creating a new [Array] from those indexes.
 ///
 /// ```text
@@ -240,6 +242,44 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
             let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
             Ok(Arc::new(array))
         }
+        DataType::Union(fields, UnionMode::Dense) => {
+            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
+
+            let type_ids = <PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
+            let offsets = <PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(), indices), None);
+
+            let children = fields.iter()
+                .map(|(field_type_id, _)| {
+                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
+                    let predicate = FilterBuilder::new(&mask).build();
+
+                    let indices = filter_primitive(&offsets, &predicate);
+
+                    let values = values.child(field_type_id);
+
+                    take_impl(values, &indices)
+                })
+                .collect::<Result<_, _>>()?;
+
+            let mut child_offsets = [0; 128];
+
+            let offsets = type_ids.values()
+                .iter()
+                .map(|&i| {
+                    let offset = child_offsets[i as usize];
+
+                    child_offsets[i as usize] += 1;
+
+                    offset
+                })
+                .collect();
+
+            let (_, type_ids, _) = type_ids.into_parts();
+
+            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;
+
+            Ok(Arc::new(array))
+        }
         t => unimplemented!("Take not supported for data type {:?}", t)
     }
 }
@@ -2146,7 +2186,7 @@ mod tests {
     }
 
     #[test]
-    fn test_take_union() {
+    fn test_take_union_sparse() {
         let structs = create_test_struct(vec![
             Some((Some(true), Some(42))),
             Some((Some(false), Some(28))),
@@ -2183,4 +2223,90 @@ mod tests {
         let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
         assert_eq!(expected, actual);
     }
+
+    #[test]
+    fn test_take_union_dense() {
+        let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
+        let offsets = vec![0, 0, 1, 1, 2, 2, 3];
+        let ints = vec![10, 20, 30, 40];
+        let strings = vec![Some("a"), None, Some("c"), Some("d")];
+
+        let indices = vec![0, 3, 1, 0, 2, 4];
+
+        let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
+        let taken_offsets = vec![0, 1, 0, 2, 1, 3];
+        let taken_ints = vec![10, 20, 10, 30];
+        let taken_strings = vec![Some("a"), None];
+
+        let type_ids = <ScalarBuffer<i8>>::from(type_ids);
+        let offsets = <ScalarBuffer<i32>>::from(offsets);
+        let ints = UInt32Array::from(ints);
+        let strings = StringArray::from(strings);
+
+        let union_fields = [
+            (
+                0,
+                Arc::new(Field::new("f1", ints.data_type().clone(), true)),
+            ),
+            (
+                1,
+                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
+            ),
+        ]
+        .into_iter()
+        .collect();
+
+        let array = UnionArray::try_new(
+            union_fields,
+            type_ids,
+            Some(offsets),
+            vec![Arc::new(ints), Arc::new(strings)],
+        )
+        .unwrap();
+
+        let index = UInt32Array::from(indices);
+
+        let actual = take(&array, &index, None).unwrap();
+        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
+
+        assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
+        assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
+        assert_eq!(
+            UInt32Array::from(actual.child(0).to_data()),
+            UInt32Array::from(taken_ints)
+        );
+        assert_eq!(
+            StringArray::from(actual.child(1).to_data()),
+            StringArray::from(taken_strings)
+        );
+    }
+
+    #[test]
+    fn test_take_union_dense_using_builder() {
+        let mut builder = UnionBuilder::new_dense();
+
+        builder.append::<Int32Type>("a", 1).unwrap();
+        builder.append::<Float64Type>("b", 3.0).unwrap();
+        builder.append::<Int32Type>("a", 4).unwrap();
+        builder.append::<Int32Type>("a", 5).unwrap();
+        builder.append::<Float64Type>("b", 2.0).unwrap();
+
+        let union = builder.build().unwrap();
+
+        let indices = UInt32Array::from(vec![2, 0, 1, 2]);
+
+        let mut builder = UnionBuilder::new_dense();
+
+        builder.append::<Int32Type>("a", 4).unwrap();
+        builder.append::<Int32Type>("a", 1).unwrap();
+        builder.append::<Float64Type>("b", 3.0).unwrap();
+        builder.append::<Int32Type>("a", 4).unwrap();
+
+        let taken = builder.build().unwrap();
+
+        assert_eq!(
+            taken.to_data(),
+            take(&union, &indices, None).unwrap().to_data()
+        );
+    }
 }

Reply via email to