This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 831a0804bf Specialize filter for structs and sparse unions (#6304)
831a0804bf is described below
commit 831a0804bf652bd7ae2773394ca21dd43e78b09b
Author: gstvg <[email protected]>
AuthorDate: Sat Aug 31 10:08:48 2024 -0300
Specialize filter for structs and sparse unions (#6304)
* specialize filter for structs and sparse unions
* fix: move nested function to top level
* fix: clarify optimization cases
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
arrow-array/src/cast.rs | 16 +++++
arrow-select/src/filter.rs | 141 ++++++++++++++++++++++++++++++++++++++++++++-
2 files changed, 156 insertions(+), 1 deletion(-)
diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs
index 7b4b1d6eca..cda179b78c 100644
--- a/arrow-array/src/cast.rs
+++ b/arrow-array/src/cast.rs
@@ -815,6 +815,14 @@ pub trait AsArray: private::Sealed {
self.as_struct_opt().expect("struct array")
}
+ /// Downcast this to a [`UnionArray`] returning `None` if not possible
+ fn as_union_opt(&self) -> Option<&UnionArray>;
+
+ /// Downcast this to a [`UnionArray`] panicking if not possible
+ fn as_union(&self) -> &UnionArray {
+ self.as_union_opt().expect("union array")
+ }
+
/// Downcast this to a [`GenericListArray`] returning `None` if not possible
fn as_list_opt<O: OffsetSizeTrait>(&self) -> Option<&GenericListArray<O>>;
@@ -888,6 +896,10 @@ impl AsArray for dyn Array + '_ {
self.as_any().downcast_ref()
}
+ fn as_union_opt(&self) -> Option<&UnionArray> {
+ self.as_any().downcast_ref()
+ }
+
fn as_list_opt<O: OffsetSizeTrait>(&self) -> Option<&GenericListArray<O>> {
self.as_any().downcast_ref()
}
@@ -939,6 +951,10 @@ impl AsArray for ArrayRef {
self.as_ref().as_struct_opt()
}
+ fn as_union_opt(&self) -> Option<&UnionArray> {
+ self.as_any().downcast_ref()
+ }
+
fn as_list_opt<O: OffsetSizeTrait>(&self) -> Option<&GenericListArray<O>> {
self.as_ref().as_list_opt()
}
diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs
index c51f44a977..e07b03d1f2 100644
--- a/arrow-select/src/filter.rs
+++ b/arrow-select/src/filter.rs
@@ -169,10 +169,29 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
/// ```
pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
- let predicate = FilterBuilder::new(predicate).build();
+ let mut filter_builder = FilterBuilder::new(predicate);
+
+ if multiple_arrays(values.data_type()) {
+ // Only optimize if filtering more than one array
+ // Otherwise, the overhead of optimization can be more than the benefit
+ filter_builder = filter_builder.optimize();
+ }
+
+ let predicate = filter_builder.build();
+
filter_array(values, &predicate)
}
+fn multiple_arrays(data_type: &DataType) -> bool {
+ match data_type {
+ DataType::Struct(fields) => {
+ fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
+ }
+ DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
+ _ => false,
+ }
+}
+
/// Returns a filtered [RecordBatch] where the corresponding elements of
/// `predicate` are true.
///
@@ -365,6 +384,12 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<Array
values => Ok(Arc::new(filter_dict(values, predicate))),
t => unimplemented!("Filter not supported for dictionary type {:?}", t)
}
+ DataType::Struct(_) => {
+ Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
+ }
+ DataType::Union(_, UnionMode::Sparse) => {
+ Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
+ }
_ => {
let data = values.to_data();
// fallback to using MutableArrayData
@@ -789,6 +814,49 @@ where
DictionaryArray::from(unsafe { builder.build_unchecked() })
}
+/// `filter` implementation for structs
+fn filter_struct(
+ array: &StructArray,
+ predicate: &FilterPredicate,
+) -> Result<StructArray, ArrowError> {
+ let columns = array
+ .columns()
+ .iter()
+ .map(|column| filter_array(column, predicate))
+ .collect::<Result<_, _>>()?;
+
+ let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
+ let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
+
+ Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
+ } else {
+ None
+ };
+
+ Ok(unsafe { StructArray::new_unchecked(array.fields().clone(), columns, nulls) })
+}
+
+/// `filter` implementation for sparse unions
+fn filter_sparse_union(
+ array: &UnionArray,
+ predicate: &FilterPredicate,
+) -> Result<UnionArray, ArrowError> {
+ let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
+ unreachable!()
+ };
+
+ let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
+
+ let children = fields
+ .iter()
+ .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
+ .collect::<Result<_, _>>()?;
+
+ Ok(unsafe {
+ UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
+ })
+}
+
#[cfg(test)]
mod tests {
use arrow_array::builder::*;
@@ -1878,4 +1946,75 @@ mod tests {
}
}
}
+
+ #[test]
+ fn test_filter_struct() {
+ let predicate = BooleanArray::from(vec![true, false, true, false]);
+
+ let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
+ let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
+
+ let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
+ let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
+
+ let null_mask = NullBuffer::from(vec![true, false, false, true]);
+ let null_mask_filtered = NullBuffer::from(vec![true, false]);
+
+ let a_field = Field::new("a", DataType::Utf8, false);
+ let b_field = Field::new("b", DataType::Int32, false);
+
+ let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
+ let expected =
+ StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
+
+ let result = filter(&array, &predicate).unwrap();
+
+ assert_eq!(result.to_data(), expected.to_data());
+
+ let array = StructArray::new(
+ vec![a_field.clone()].into(),
+ vec![a.clone()],
+ Some(null_mask.clone()),
+ );
+ let expected = StructArray::new(
+ vec![a_field.clone()].into(),
+ vec![a_filtered.clone()],
+ Some(null_mask_filtered.clone()),
+ );
+
+ let result = filter(&array, &predicate).unwrap();
+
+ assert_eq!(result.to_data(), expected.to_data());
+
+ let array = StructArray::new(
+ vec![a_field.clone(), b_field.clone()].into(),
+ vec![a.clone(), b.clone()],
+ None,
+ );
+ let expected = StructArray::new(
+ vec![a_field.clone(), b_field.clone()].into(),
+ vec![a_filtered.clone(), b_filtered.clone()],
+ None,
+ );
+
+ let result = filter(&array, &predicate).unwrap();
+
+ assert_eq!(result.to_data(), expected.to_data());
+
+ let array = StructArray::new(
+ vec![a_field.clone(), b_field.clone()].into(),
+ vec![a.clone(), b.clone()],
+ Some(null_mask.clone()),
+ );
+
+ let expected = StructArray::new(
+ vec![a_field.clone(), b_field.clone()].into(),
+ vec![a_filtered.clone(), b_filtered.clone()],
+ Some(null_mask_filtered.clone()),
+ );
+
+ let result = filter(&array, &predicate).unwrap();
+
+ assert_eq!(result.to_data(), expected.to_data());
+ }
}