This is an automated email from the ASF dual-hosted git repository.

avantgardner pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 4320a753be Implement Take for UnionArray (#4883)
4320a753be is described below

commit 4320a753beaee0a1a6870c59ef46b59e88c9c323
Author: Brent Gardner <[email protected]>
AuthorDate: Mon Oct 2 09:21:51 2023 -0600

    Implement Take for UnionArray (#4883)
    
    Implement Take for UnionArray (#4883)
---
 arrow-select/src/take.rs | 54 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 53 insertions(+), 1 deletion(-)

diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index 70b80e5878..a546949f86 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -28,7 +28,7 @@ use arrow_buffer::{
     ScalarBuffer,
 };
 use arrow_data::{ArrayData, ArrayDataBuilder};
-use arrow_schema::{ArrowError, DataType, FieldRef};
+use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
 
 use num::{One, Zero};
 
@@ -223,6 +223,21 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
                 Ok(new_null_array(&DataType::Null, indices.len()))
             }
         }
+        DataType::Union(fields, UnionMode::Sparse) => {
+            let mut field_type_ids = Vec::with_capacity(fields.len());
+            let mut children = Vec::with_capacity(fields.len());
+            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
+            let type_ids = take_native(values.type_ids(), indices).into_inner();
+            for (type_id, field) in fields.iter() {
+                let values = values.child(type_id);
+                let values = take_impl(values, indices)?;
+                let field = (**field).clone();
+                children.push((field, values));
+                field_type_ids.push(type_id);
+            }
+            let array = UnionArray::try_new(field_type_ids.as_slice(), type_ids, None, children)?;
+            Ok(Arc::new(array))
+        }
         t => unimplemented!("Take not supported for data type {:?}", t)
     }
 }
@@ -2013,4 +2028,41 @@ mod tests {
         let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
         assert_eq!(&values, &[Some("foo"), None, None, None])
     }
+
+    #[test]
+    fn test_take_union() {
+        let structs = create_test_struct(vec![
+            Some((Some(true), Some(42))),
+            Some((Some(false), Some(28))),
+            Some((Some(false), Some(19))),
+            Some((Some(true), Some(31))),
+            None,
+        ]);
+        let strings =
+            StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
+        let type_ids = Buffer::from_slice_ref(vec![1i8; 5]);
+
+        let children: Vec<(Field, Arc<dyn Array>)> = vec![
+            (
+                Field::new("f1", structs.data_type().clone(), true),
+                Arc::new(structs),
+            ),
+            (
+                Field::new("f2", strings.data_type().clone(), true),
+                Arc::new(strings),
+            ),
+        ];
+        let array = UnionArray::try_new(&[0, 1], type_ids, None, children).unwrap();
+
+        let indices = vec![0, 3, 1, 0, 2, 4];
+        let index = UInt32Array::from(indices.clone());
+        let actual = take(&array, &index, None).unwrap();
+        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
+        let strings = actual.child(1);
+        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
+
+        let actual = strings.iter().collect::<Vec<_>>();
+        let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
+        assert_eq!(expected, actual);
+    }
 }

Reply via email to