This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new e9a7fe5764 Add `FilterPredicate::filter_record_batch` (#8693)
e9a7fe5764 is described below
commit e9a7fe576449e86541103423ca40c8c47ff3ec39
Author: Pepijn Van Eeckhoudt <[email protected]>
AuthorDate: Fri Oct 24 19:47:15 2025 +0200
Add `FilterPredicate::filter_record_batch` (#8693)
# Which issue does this PR close?
- Closes #8692.
# Rationale for this change
Explained in the linked issue (#8692).
# What changes are included in this PR?
- Adds `FilterPredicate::filter_record_batch`
- Adapts the free function `filter_record_batch` to use the new function
- Uses `new_unchecked` to create the filtered result. The rationale for
this is identical to #8583
# Are these changes tested?
Covered by existing tests for `filter_record_batch`
# Are there any user-facing changes?
No
---------
Co-authored-by: Martin Grigorov <[email protected]>
---
arrow-select/src/filter.rs | 52 ++++++++++++++++++++++++++++++++++++++--------
1 file changed, 43 insertions(+), 9 deletions(-)
diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs
index dace2bab72..5c21a4adca 100644
--- a/arrow-select/src/filter.rs
+++ b/arrow-select/src/filter.rs
@@ -122,6 +122,12 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
/// Returns a filtered `values` [`Array`] where the corresponding elements of
/// `predicate` are `true`.
///
+/// If multiple arrays (or record batches) need to be filtered using the same predicate array,
+/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
+/// calling [FilterPredicate::filter_record_batch].
+/// In contrast to this function, it is then the responsibility of the caller
+/// to use [FilterBuilder::optimize] if appropriate.
+///
/// # See also
/// * [`FilterBuilder`] for more control over the filtering process.
/// * [`filter_record_batch`] to filter a [`RecordBatch`]
@@ -168,25 +174,28 @@ fn multiple_arrays(data_type: &DataType) -> bool {
/// `predicate` are true.
///
/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
+///
+/// If multiple record batches (or arrays) need to be filtered using the same predicate array,
+/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
+/// calling [FilterPredicate::filter_record_batch].
+/// In contrast to this function, it is then the responsibility of the caller
+/// to use [FilterBuilder::optimize] if appropriate.
pub fn filter_record_batch(
record_batch: &RecordBatch,
predicate: &BooleanArray,
) -> Result<RecordBatch, ArrowError> {
let mut filter_builder = FilterBuilder::new(predicate);
- if record_batch.num_columns() > 1 {
- // Only optimize if filtering more than one column
+ let num_cols = record_batch.num_columns();
+ if num_cols > 1
+ || (num_cols > 0 && multiple_arrays(record_batch.schema_ref().field(0).data_type()))
+ {
+ // Only optimize if filtering more than one column or if the column contains multiple internal arrays
// Otherwise, the overhead of optimization can be more than the benefit
filter_builder = filter_builder.optimize();
}
let filter = filter_builder.build();
- let filtered_arrays = record_batch
- .columns()
- .iter()
- .map(|a| filter_array(a, &filter))
- .collect::<Result<Vec<_>, _>>()?;
- let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
- RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
+ filter.filter_record_batch(record_batch)
}
/// A builder to construct [`FilterPredicate`]
@@ -300,6 +309,31 @@ impl FilterPredicate {
filter_array(values, self)
}
+ /// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this
+ /// [`FilterPredicate`].
+ ///
+ /// This is the equivalent of calling [filter] on each column of the [`RecordBatch`].
+ pub fn filter_record_batch(
+ &self,
+ record_batch: &RecordBatch,
+ ) -> Result<RecordBatch, ArrowError> {
+ let filtered_arrays = record_batch
+ .columns()
+ .iter()
+ .map(|a| filter_array(a, self))
+ .collect::<Result<Vec<_>, _>>()?;
+
+ // SAFETY: we know that the set of filtered arrays will match the schema of the original
+ // record batch
+ unsafe {
+ Ok(RecordBatch::new_unchecked(
+ record_batch.schema(),
+ filtered_arrays,
+ self.count,
+ ))
+ }
+ }
+
/// Number of rows being selected based on this [`FilterPredicate`]
pub fn count(&self) -> usize {
self.count