This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 831a0804bf Specialize filter for structs and sparse unions (#6304)
831a0804bf is described below

commit 831a0804bf652bd7ae2773394ca21dd43e78b09b
Author: gstvg <[email protected]>
AuthorDate: Sat Aug 31 10:08:48 2024 -0300

    Specialize filter for structs and sparse unions (#6304)
    
    * specialize filter for structs and sparse unions
    
    * fix: move nested function to top level
    
    * fix: clarify optimization cases
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 arrow-array/src/cast.rs    |  16 +++++
 arrow-select/src/filter.rs | 141 ++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 156 insertions(+), 1 deletion(-)

diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs
index 7b4b1d6eca..cda179b78c 100644
--- a/arrow-array/src/cast.rs
+++ b/arrow-array/src/cast.rs
@@ -815,6 +815,14 @@ pub trait AsArray: private::Sealed {
         self.as_struct_opt().expect("struct array")
     }
 
+    /// Downcast this to a [`UnionArray`] returning `None` if not possible
+    fn as_union_opt(&self) -> Option<&UnionArray>;
+
+    /// Downcast this to a [`UnionArray`] panicking if not possible
+    fn as_union(&self) -> &UnionArray {
+        self.as_union_opt().expect("union array")
+    }
+
     /// Downcast this to a [`GenericListArray`] returning `None` if not possible
     fn as_list_opt<O: OffsetSizeTrait>(&self) -> Option<&GenericListArray<O>>;
 
@@ -888,6 +896,10 @@ impl AsArray for dyn Array + '_ {
         self.as_any().downcast_ref()
     }
 
+    fn as_union_opt(&self) -> Option<&UnionArray> {
+        self.as_any().downcast_ref()
+    }
+
     fn as_list_opt<O: OffsetSizeTrait>(&self) -> Option<&GenericListArray<O>> {
         self.as_any().downcast_ref()
     }
@@ -939,6 +951,10 @@ impl AsArray for ArrayRef {
         self.as_ref().as_struct_opt()
     }
 
+    fn as_union_opt(&self) -> Option<&UnionArray> {
+        self.as_any().downcast_ref()
+    }
+
     fn as_list_opt<O: OffsetSizeTrait>(&self) -> Option<&GenericListArray<O>> {
         self.as_ref().as_list_opt()
     }
diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs
index c51f44a977..e07b03d1f2 100644
--- a/arrow-select/src/filter.rs
+++ b/arrow-select/src/filter.rs
@@ -169,10 +169,29 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
 /// assert_eq!(c, &Int32Array::from(vec![5, 8]));
 /// ```
 pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
-    let predicate = FilterBuilder::new(predicate).build();
+    let mut filter_builder = FilterBuilder::new(predicate);
+
+    if multiple_arrays(values.data_type()) {
+        // Only optimize if filtering more than one array
+        // Otherwise, the overhead of optimization can be more than the benefit
+        filter_builder = filter_builder.optimize();
+    }
+
+    let predicate = filter_builder.build();
+
     filter_array(values, &predicate)
 }
 
+fn multiple_arrays(data_type: &DataType) -> bool {
+    match data_type {
+        DataType::Struct(fields) => {
+            fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
+        }
+        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
+        _ => false,
+    }
+}
+
 /// Returns a filtered [RecordBatch] where the corresponding elements of
 /// `predicate` are true.
 ///
@@ -365,6 +384,12 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<Array
                 values => Ok(Arc::new(filter_dict(values, predicate))),
                 t => unimplemented!("Filter not supported for dictionary type {:?}", t)
             }
+            DataType::Struct(_) => {
+                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
+            }
+            DataType::Union(_, UnionMode::Sparse) => {
+                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
+            }
             _ => {
                 let data = values.to_data();
                 // fallback to using MutableArrayData
@@ -789,6 +814,49 @@ where
     DictionaryArray::from(unsafe { builder.build_unchecked() })
 }
 
+/// `filter` implementation for structs
+fn filter_struct(
+    array: &StructArray,
+    predicate: &FilterPredicate,
+) -> Result<StructArray, ArrowError> {
+    let columns = array
+        .columns()
+        .iter()
+        .map(|column| filter_array(column, predicate))
+        .collect::<Result<_, _>>()?;
+
+    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
+        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
+
+        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
+    } else {
+        None
+    };
+
+    Ok(unsafe { StructArray::new_unchecked(array.fields().clone(), columns, nulls) })
+}
+
+/// `filter` implementation for sparse unions
+fn filter_sparse_union(
+    array: &UnionArray,
+    predicate: &FilterPredicate,
+) -> Result<UnionArray, ArrowError> {
+    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
+        unreachable!()
+    };
+
+    let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
+
+    let children = fields
+        .iter()
+        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
+        .collect::<Result<_, _>>()?;
+
+    Ok(unsafe {
+        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use arrow_array::builder::*;
@@ -1878,4 +1946,75 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn test_filter_struct() {
+        let predicate = BooleanArray::from(vec![true, false, true, false]);
+
+        let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
+        let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
+
+        let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
+        let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
+
+        let null_mask = NullBuffer::from(vec![true, false, false, true]);
+        let null_mask_filtered = NullBuffer::from(vec![true, false]);
+
+        let a_field = Field::new("a", DataType::Utf8, false);
+        let b_field = Field::new("b", DataType::Int32, false);
+
+        let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
+        let expected =
+            StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
+
+        let result = filter(&array, &predicate).unwrap();
+
+        assert_eq!(result.to_data(), expected.to_data());
+
+        let array = StructArray::new(
+            vec![a_field.clone()].into(),
+            vec![a.clone()],
+            Some(null_mask.clone()),
+        );
+        let expected = StructArray::new(
+            vec![a_field.clone()].into(),
+            vec![a_filtered.clone()],
+            Some(null_mask_filtered.clone()),
+        );
+
+        let result = filter(&array, &predicate).unwrap();
+
+        assert_eq!(result.to_data(), expected.to_data());
+
+        let array = StructArray::new(
+            vec![a_field.clone(), b_field.clone()].into(),
+            vec![a.clone(), b.clone()],
+            None,
+        );
+        let expected = StructArray::new(
+            vec![a_field.clone(), b_field.clone()].into(),
+            vec![a_filtered.clone(), b_filtered.clone()],
+            None,
+        );
+
+        let result = filter(&array, &predicate).unwrap();
+
+        assert_eq!(result.to_data(), expected.to_data());
+
+        let array = StructArray::new(
+            vec![a_field.clone(), b_field.clone()].into(),
+            vec![a.clone(), b.clone()],
+            Some(null_mask.clone()),
+        );
+
+        let expected = StructArray::new(
+            vec![a_field.clone(), b_field.clone()].into(),
+            vec![a_filtered.clone(), b_filtered.clone()],
+            Some(null_mask_filtered.clone()),
+        );
+
+        let result = filter(&array, &predicate).unwrap();
+
+        assert_eq!(result.to_data(), expected.to_data());
+    }
 }

Reply via email to