This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 36a1361807 Implement GroupsAccumulator for corr(x,y) aggregate function (#13581)
36a1361807 is described below

commit 36a1361807060b5221291e5c8a7d59d7acf7954a
Author: Yongting You <[email protected]>
AuthorDate: Fri Dec 13 04:25:49 2024 +0800

    Implement GroupsAccumulator for corr(x,y) aggregate function (#13581)
    
    * Implement GroupsAccumulator for corr(x,y)
    
    * feedbacks
    
    * fix CI MSRV
    
    * review
    
    * avoid collect in accumulation
    
    * add back cast
---
 .../src/aggregate/groups_accumulator/accumulate.rs | 174 ++++++++++-
 datafusion/functions-aggregate/src/correlation.rs  | 327 ++++++++++++++++++++-
 2 files changed, 499 insertions(+), 2 deletions(-)

diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
index ac4d0e7553..e629e99e16 100644
--- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
+++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs
@@ -371,6 +371,75 @@ pub fn accumulate<T, F>(
     }
 }
 
+/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`)
+///
+/// This method assumes that for any input record index, if any of the value column
+/// is null, or it's filtered out by `opt_filter`, then the record would be ignored.
+/// (won't be accumulated by `value_fn`)
+///
+/// # Arguments
+///
+/// * `group_indices` - To which groups do the rows in `value_columns` belong
+/// * `value_columns` - The input arrays to accumulate
+/// * `opt_filter` - Optional filter array. If present, only rows where filter is `Some(true)` are included
+/// * `value_fn` - Callback function for each valid row, with parameters:
+///     * `group_idx`: The group index for the current row
+///     * `batch_idx`: The index of the current row in the input arrays
+///     * `columns`: Reference to all input arrays for accessing values
+pub fn accumulate_multiple<T, F>(
+    group_indices: &[usize],
+    value_columns: &[&PrimitiveArray<T>],
+    opt_filter: Option<&BooleanArray>,
+    mut value_fn: F,
+) where
+    T: ArrowPrimitiveType + Send,
+    F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
+{
+    // Calculate `valid_indices` to accumulate, non-valid indices are ignored.
+    // `valid_indices` is a bit mask corresponding to the `group_indices`. An index
+    // is considered valid if:
+    // 1. All columns are non-null at this index.
+    // 2. Not filtered out by `opt_filter`
+
+    // Take AND from all null buffers of `value_columns`.
+    let combined_nulls = value_columns
+        .iter()
+        .map(|arr| arr.logical_nulls())
+        .fold(None, |acc, nulls| {
+            NullBuffer::union(acc.as_ref(), nulls.as_ref())
+        });
+
+    // Take AND from previous combined nulls and `opt_filter`.
+    let valid_indices = match (combined_nulls, opt_filter) {
+        (None, None) => None,
+        (None, Some(filter)) => Some(filter.clone()),
+        (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), 
None)),
+        (Some(nulls), Some(filter)) => {
+            let combined = nulls.inner() & filter.values();
+            Some(BooleanArray::new(combined, None))
+        }
+    };
+
+    for col in value_columns.iter() {
+        debug_assert_eq!(col.len(), group_indices.len());
+    }
+
+    match valid_indices {
+        None => {
+            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
+                value_fn(group_idx, batch_idx, value_columns);
+            }
+        }
+        Some(valid_indices) => {
+            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
+                if valid_indices.value(batch_idx) {
+                    value_fn(group_idx, batch_idx, value_columns);
+                }
+            }
+        }
+    }
+}
+
 /// This function is called to update the accumulator state per row
 /// when the value is not needed (e.g. COUNT)
 ///
@@ -528,7 +597,7 @@ fn initialize_builder(
 mod test {
     use super::*;
 
-    use arrow::array::UInt32Array;
+    use arrow::array::{Int32Array, UInt32Array};
     use rand::{rngs::ThreadRng, Rng};
     use std::collections::HashSet;
 
@@ -940,4 +1009,107 @@ mod test {
                 .collect()
         }
     }
+
+    #[test]
+    fn test_accumulate_multiple_no_nulls_no_filter() {
+        let group_indices = vec![0, 1, 0, 1];
+        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
+        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
+        let value_columns = [values1, values2];
+
+        let mut accumulated = vec![];
+        accumulate_multiple(
+            &group_indices,
+            &value_columns.iter().collect::<Vec<_>>(),
+            None,
+            |group_idx, batch_idx, columns| {
+                let values = columns.iter().map(|col| 
col.value(batch_idx)).collect();
+                accumulated.push((group_idx, values));
+            },
+        );
+
+        let expected = vec![
+            (0, vec![1, 10]),
+            (1, vec![2, 20]),
+            (0, vec![3, 30]),
+            (1, vec![4, 40]),
+        ];
+        assert_eq!(accumulated, expected);
+    }
+
+    #[test]
+    fn test_accumulate_multiple_with_nulls() {
+        let group_indices = vec![0, 1, 0, 1];
+        let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
+        let values2 = Int32Array::from(vec![Some(10), Some(20), None, 
Some(40)]);
+        let value_columns = [values1, values2];
+
+        let mut accumulated = vec![];
+        accumulate_multiple(
+            &group_indices,
+            &value_columns.iter().collect::<Vec<_>>(),
+            None,
+            |group_idx, batch_idx, columns| {
+                let values = columns.iter().map(|col| 
col.value(batch_idx)).collect();
+                accumulated.push((group_idx, values));
+            },
+        );
+
+        // Only rows where both columns are non-null should be accumulated
+        let expected = vec![(0, vec![1, 10]), (1, vec![4, 40])];
+        assert_eq!(accumulated, expected);
+    }
+
+    #[test]
+    fn test_accumulate_multiple_with_filter() {
+        let group_indices = vec![0, 1, 0, 1];
+        let values1 = Int32Array::from(vec![1, 2, 3, 4]);
+        let values2 = Int32Array::from(vec![10, 20, 30, 40]);
+        let value_columns = [values1, values2];
+
+        let filter = BooleanArray::from(vec![true, false, true, false]);
+
+        let mut accumulated = vec![];
+        accumulate_multiple(
+            &group_indices,
+            &value_columns.iter().collect::<Vec<_>>(),
+            Some(&filter),
+            |group_idx, batch_idx, columns| {
+                let values = columns.iter().map(|col| 
col.value(batch_idx)).collect();
+                accumulated.push((group_idx, values));
+            },
+        );
+
+        // Only rows where filter is true should be accumulated
+        let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
+        assert_eq!(accumulated, expected);
+    }
+
+    #[test]
+    fn test_accumulate_multiple_with_nulls_and_filter() {
+        let group_indices = vec![0, 1, 0, 1];
+        let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
+        let values2 = Int32Array::from(vec![Some(10), Some(20), None, 
Some(40)]);
+        let value_columns = [values1, values2];
+
+        let filter = BooleanArray::from(vec![true, true, true, false]);
+
+        let mut accumulated = vec![];
+        accumulate_multiple(
+            &group_indices,
+            &value_columns.iter().collect::<Vec<_>>(),
+            Some(&filter),
+            |group_idx, batch_idx, columns| {
+                let values = columns.iter().map(|col| 
col.value(batch_idx)).collect();
+                accumulated.push((group_idx, values));
+            },
+        );
+
+        // Only rows where both:
+        // 1. Filter is true
+        // 2. Both columns are non-null
+        // should be accumulated
+        let expected = [(0, vec![1, 10])];
+        assert_eq!(accumulated, expected);
+    }
 }
diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs
index a0ccdb0ae7..72c1f6dbae 100644
--- a/datafusion/functions-aggregate/src/correlation.rs
+++ b/datafusion/functions-aggregate/src/correlation.rs
@@ -22,11 +22,19 @@ use std::fmt::Debug;
 use std::mem::size_of_val;
 use std::sync::Arc;
 
-use arrow::compute::{and, filter, is_not_null};
+use arrow::array::{
+    downcast_array, Array, AsArray, BooleanArray, BooleanBufferBuilder, 
Float64Array,
+    UInt64Array,
+};
+use arrow::compute::{and, filter, is_not_null, kernels::cast};
+use arrow::datatypes::{Float64Type, UInt64Type};
 use arrow::{
     array::ArrayRef,
     datatypes::{DataType, Field},
 };
+use datafusion_expr::{EmitTo, GroupsAccumulator};
+use 
datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple;
+use log::debug;
 
 use crate::covariance::CovarianceAccumulator;
 use crate::stddev::StddevAccumulator;
@@ -128,6 +136,18 @@ impl AggregateUDFImpl for Correlation {
     fn documentation(&self) -> Option<&Documentation> {
         self.doc()
     }
+
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
+        true
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        _args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        debug!("GroupsAccumulator is created for aggregate function `corr(c1, 
c2)`");
+        Ok(Box::new(CorrelationGroupsAccumulator::new()))
+    }
 }
 
 /// An accumulator to compute correlation
@@ -252,3 +272,308 @@ impl Accumulator for CorrelationAccumulator {
         Ok(())
     }
 }
+
+#[derive(Default)]
+pub struct CorrelationGroupsAccumulator {
+    // Number of elements for each group
+    // This is also used to track nulls: if a group has 0 valid values accumulated,
+    // final aggregation result will be null.
+    count: Vec<u64>,
+    // Sum of x values for each group
+    sum_x: Vec<f64>,
+    // Sum of y
+    sum_y: Vec<f64>,
+    // Sum of x*y
+    sum_xy: Vec<f64>,
+    // Sum of x^2
+    sum_xx: Vec<f64>,
+    // Sum of y^2
+    sum_yy: Vec<f64>,
+}
+
+impl CorrelationGroupsAccumulator {
+    pub fn new() -> Self {
+        Default::default()
+    }
+}
+
+/// Specialized version of `accumulate_multiple` for correlation's merge_batch
+///
+/// Note: Arrays in `state_arrays` should not have null values, because they are all
+/// intermediate states created within the accumulator, instead of inputs from
+/// outside.
+fn accumulate_correlation_states(
+    group_indices: &[usize],
+    state_arrays: (
+        &UInt64Array,  // count
+        &Float64Array, // sum_x
+        &Float64Array, // sum_y
+        &Float64Array, // sum_xy
+        &Float64Array, // sum_xx
+        &Float64Array, // sum_yy
+    ),
+    mut value_fn: impl FnMut(usize, u64, &[f64]),
+) {
+    let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays;
+
+    assert_eq!(counts.null_count(), 0);
+    assert_eq!(sum_x.null_count(), 0);
+    assert_eq!(sum_y.null_count(), 0);
+    assert_eq!(sum_xy.null_count(), 0);
+    assert_eq!(sum_xx.null_count(), 0);
+    assert_eq!(sum_yy.null_count(), 0);
+
+    let counts_values = counts.values().as_ref();
+    let sum_x_values = sum_x.values().as_ref();
+    let sum_y_values = sum_y.values().as_ref();
+    let sum_xy_values = sum_xy.values().as_ref();
+    let sum_xx_values = sum_xx.values().as_ref();
+    let sum_yy_values = sum_yy.values().as_ref();
+
+    for (idx, &group_idx) in group_indices.iter().enumerate() {
+        let row = [
+            sum_x_values[idx],
+            sum_y_values[idx],
+            sum_xy_values[idx],
+            sum_xx_values[idx],
+            sum_yy_values[idx],
+        ];
+        value_fn(group_idx, counts_values[idx], &row);
+    }
+}
+
+/// GroupsAccumulator implementation for `corr(x, y)` that computes the Pearson correlation coefficient
+/// between two numeric columns.
+///
+/// Online algorithm for correlation:
+///
+/// r = (n * sum_xy - sum_x * sum_y) / sqrt((n * sum_xx - sum_x^2) * (n * sum_yy - sum_y^2))
+/// where:
+/// n = number of observations
+/// sum_x = sum of x values
+/// sum_y = sum of y values  
+/// sum_xy = sum of (x * y)
+/// sum_xx = sum of x^2 values
+/// sum_yy = sum of y^2 values
+///
+/// Reference: <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient#For_a_sample>
+impl GroupsAccumulator for CorrelationGroupsAccumulator {
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.count.resize(total_num_groups, 0);
+        self.sum_x.resize(total_num_groups, 0.0);
+        self.sum_y.resize(total_num_groups, 0.0);
+        self.sum_xy.resize(total_num_groups, 0.0);
+        self.sum_xx.resize(total_num_groups, 0.0);
+        self.sum_yy.resize(total_num_groups, 0.0);
+
+        let array_x = &cast(&values[0], &DataType::Float64)?;
+        let array_x = downcast_array::<Float64Array>(array_x);
+        let array_y = &cast(&values[1], &DataType::Float64)?;
+        let array_y = downcast_array::<Float64Array>(array_y);
+
+        accumulate_multiple(
+            group_indices,
+            &[&array_x, &array_y],
+            opt_filter,
+            |group_index, batch_index, columns| {
+                let x = columns[0].value(batch_index);
+                let y = columns[1].value(batch_index);
+                self.count[group_index] += 1;
+                self.sum_x[group_index] += x;
+                self.sum_y[group_index] += y;
+                self.sum_xy[group_index] += x * y;
+                self.sum_xx[group_index] += x * x;
+                self.sum_yy[group_index] += y * y;
+            },
+        );
+
+        Ok(())
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        // Resize vectors to accommodate total number of groups
+        self.count.resize(total_num_groups, 0);
+        self.sum_x.resize(total_num_groups, 0.0);
+        self.sum_y.resize(total_num_groups, 0.0);
+        self.sum_xy.resize(total_num_groups, 0.0);
+        self.sum_xx.resize(total_num_groups, 0.0);
+        self.sum_yy.resize(total_num_groups, 0.0);
+
+        // Extract arrays from input values
+        let partial_counts = values[0].as_primitive::<UInt64Type>();
+        let partial_sum_x = values[1].as_primitive::<Float64Type>();
+        let partial_sum_y = values[2].as_primitive::<Float64Type>();
+        let partial_sum_xy = values[3].as_primitive::<Float64Type>();
+        let partial_sum_xx = values[4].as_primitive::<Float64Type>();
+        let partial_sum_yy = values[5].as_primitive::<Float64Type>();
+
+        assert!(opt_filter.is_none(), "aggregate filter should be applied in 
partial stage, there should be no filter in final stage");
+
+        accumulate_correlation_states(
+            group_indices,
+            (
+                partial_counts,
+                partial_sum_x,
+                partial_sum_y,
+                partial_sum_xy,
+                partial_sum_xx,
+                partial_sum_yy,
+            ),
+            |group_index, count, values| {
+                self.count[group_index] += count;
+                self.sum_x[group_index] += values[0];
+                self.sum_y[group_index] += values[1];
+                self.sum_xy[group_index] += values[2];
+                self.sum_xx[group_index] += values[3];
+                self.sum_yy[group_index] += values[4];
+            },
+        );
+
+        Ok(())
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        let n = match emit_to {
+            EmitTo::All => self.count.len(),
+            EmitTo::First(n) => n,
+        };
+
+        let mut values = Vec::with_capacity(n);
+        let mut nulls = BooleanBufferBuilder::new(n);
+
+        // Notes for `Null` handling:
+        // - If the `count` state of a group is 0, no valid records are accumulated
+        //   for this group, so the aggregation result is `Null`.
+        // - Correlation can't be calculated when a group only has 1 record, or when
+        //   the `denominator` state is 0. In these cases, the final aggregation
+        //   result should be `Null` (according to PostgreSQL's behavior).
+        //
+        // TODO: Old datafusion implementation returns 0.0 for these invalid cases.
+        // Update this to match PostgreSQL's behavior.
+        for i in 0..n {
+            if self.count[i] < 2 {
+                // TODO: Evaluate as `Null` (see notes above)
+                values.push(0.0);
+                nulls.append(false);
+                continue;
+            }
+
+            let count = self.count[i];
+            let sum_x = self.sum_x[i];
+            let sum_y = self.sum_y[i];
+            let sum_xy = self.sum_xy[i];
+            let sum_xx = self.sum_xx[i];
+            let sum_yy = self.sum_yy[i];
+
+            let mean_x = sum_x / count as f64;
+            let mean_y = sum_y / count as f64;
+
+            let numerator = sum_xy - sum_x * mean_y;
+            let denominator =
+                ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt();
+
+            if denominator == 0.0 {
+                // TODO: Evaluate as `Null` (see notes above)
+                values.push(0.0);
+                nulls.append(false);
+            } else {
+                values.push(numerator / denominator);
+                nulls.append(true);
+            }
+        }
+
+        Ok(Arc::new(Float64Array::new(
+            values.into(),
+            Some(nulls.finish().into()),
+        )))
+    }
+
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        let n = match emit_to {
+            EmitTo::All => self.count.len(),
+            EmitTo::First(n) => n,
+        };
+
+        Ok(vec![
+            Arc::new(UInt64Array::from(self.count[0..n].to_vec())),
+            Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())),
+            Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())),
+            Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())),
+            Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())),
+            Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())),
+        ])
+    }
+
+    fn size(&self) -> usize {
+        size_of_val(&self.count)
+            + size_of_val(&self.sum_x)
+            + size_of_val(&self.sum_y)
+            + size_of_val(&self.sum_xy)
+            + size_of_val(&self.sum_xx)
+            + size_of_val(&self.sum_yy)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Float64Array, UInt64Array};
+
+    #[test]
+    fn test_accumulate_correlation_states() {
+        // Test data
+        let group_indices = vec![0, 1, 0, 1];
+        let counts = UInt64Array::from(vec![1, 2, 3, 4]);
+        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
+        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
+        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
+        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
+        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);
+
+        let mut accumulated = vec![];
+        accumulate_correlation_states(
+            &group_indices,
+            (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
+            |group_idx, count, values| {
+                accumulated.push((group_idx, count, values.to_vec()));
+            },
+        );
+
+        let expected = vec![
+            (0, 1, vec![10.0, 1.0, 10.0, 100.0, 1.0]),
+            (1, 2, vec![20.0, 2.0, 40.0, 400.0, 4.0]),
+            (0, 3, vec![30.0, 3.0, 90.0, 900.0, 9.0]),
+            (1, 4, vec![40.0, 4.0, 160.0, 1600.0, 16.0]),
+        ];
+        assert_eq!(accumulated, expected);
+
+        // Test that function panics with null values
+        let counts = UInt64Array::from(vec![Some(1), None, Some(3), Some(4)]);
+        let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]);
+        let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]);
+        let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]);
+        let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]);
+        let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]);
+
+        let result = std::panic::catch_unwind(|| {
+            accumulate_correlation_states(
+                &group_indices,
+                (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy),
+                |_, _, _| {},
+            )
+        });
+        assert!(result.is_err());
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to