This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 06d147a  Add batch operations to stddev (#1547)
06d147a is described below

commit 06d147aace7a923752cef2fcf98f4732de31a6a6
Author: Lin Ma <[email protected]>
AuthorDate: Tue Jan 11 13:37:48 2022 -0800

    Add batch operations to stddev (#1547)
    
    * Initial implementation of variance
    
    * get simple f64 type tests working
    
    * add math functions to ScalarValue, some tests
    
    * add to expressions and tests
    
    * add more tests
    
    * add test for ScalarValue add
    
    * add tests for scalar arithmetic
    
    * add test, finish variance
    
    * fix warnings
    
    * add more sql tests
    
    * add stddev and tests
    
    * add the hooks and expression
    
    * add more tests
    
    * fix lint and clipy
    
    * address comments and fix test errors
    
    * address comments
    
    * add population and sample for variance and stddev
    
    * address more comments
    
    * fmt
    
    * add test for less than 2 values
    
    * fix inconsistency in the merge logic
    
    * fix lint and clipy
    
    * use batch operations
    
    * remove unused code
    
    * lint and clipy
    
    * fix s typo
    
    * clipy fix
    
    * fix lint
---
 datafusion/src/physical_plan/expressions/stddev.rs |  15 +-
 .../src/physical_plan/expressions/variance.rs      | 198 +++++++++++++++------
 2 files changed, 156 insertions(+), 57 deletions(-)

diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs
index d6e28f1..2c85b90 100644
--- a/datafusion/src/physical_plan/expressions/stddev.rs
+++ b/datafusion/src/physical_plan/expressions/stddev.rs
@@ -25,8 +25,7 @@ use crate::physical_plan::{
     expressions::variance::VarianceAccumulator, Accumulator, AggregateExpr, PhysicalExpr,
 };
 use crate::scalar::ScalarValue;
-use arrow::datatypes::DataType;
-use arrow::datatypes::Field;
+use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
 
 use super::{format_state_name, StatsType};
 
@@ -216,8 +215,8 @@ impl Accumulator for StddevAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
         Ok(vec![
             ScalarValue::from(self.variance.get_count()),
-            self.variance.get_mean(),
-            self.variance.get_m2(),
+            ScalarValue::from(self.variance.get_mean()),
+            ScalarValue::from(self.variance.get_m2()),
         ])
     }
 
@@ -229,6 +228,14 @@ impl Accumulator for StddevAccumulator {
         self.variance.merge(states)
     }
 
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.variance.update_batch(values)
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.variance.merge_batch(states)
+    }
+
     fn evaluate(&self) -> Result<ScalarValue> {
         let variance = self.variance.evaluate()?;
         match variance {
diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs
index 3f592b0..7516440 100644
--- a/datafusion/src/physical_plan/expressions/variance.rs
+++ b/datafusion/src/physical_plan/expressions/variance.rs
@@ -23,8 +23,13 @@ use std::sync::Arc;
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
 use crate::scalar::ScalarValue;
-use arrow::datatypes::DataType;
-use arrow::datatypes::Field;
+use arrow::array::Float64Array;
+use arrow::{
+    array::{ArrayRef, UInt64Array},
+    compute::cast,
+    datatypes::DataType,
+    datatypes::Field,
+};
 
 use super::{format_state_name, StatsType};
 
@@ -209,8 +214,8 @@ impl AggregateExpr for VariancePop {
 
 #[derive(Debug)]
 pub struct VarianceAccumulator {
-    m2: ScalarValue,
-    mean: ScalarValue,
+    m2: f64,
+    mean: f64,
     count: u64,
     stats_type: StatsType,
 }
@@ -219,9 +224,9 @@ impl VarianceAccumulator {
     /// Creates a new `VarianceAccumulator`
     pub fn try_new(s_type: StatsType) -> Result<Self> {
         Ok(Self {
-            m2: ScalarValue::from(0 as f64),
-            mean: ScalarValue::from(0 as f64),
-            count: 0,
+            m2: 0_f64,
+            mean: 0_f64,
+            count: 0_u64,
             stats_type: s_type,
         })
     }
@@ -230,12 +235,12 @@ impl VarianceAccumulator {
         self.count
     }
 
-    pub fn get_mean(&self) -> ScalarValue {
-        self.mean.clone()
+    pub fn get_mean(&self) -> f64 {
+        self.mean
     }
 
-    pub fn get_m2(&self) -> ScalarValue {
-        self.m2.clone()
+    pub fn get_m2(&self) -> f64 {
+        self.m2
     }
 }
 
@@ -243,80 +248,174 @@ impl Accumulator for VarianceAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
         Ok(vec![
             ScalarValue::from(self.count),
-            self.mean.clone(),
-            self.m2.clone(),
+            ScalarValue::from(self.mean),
+            ScalarValue::from(self.m2),
         ])
     }
 
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values = &cast(&values[0], &DataType::Float64)?;
+        let arr = values.as_any().downcast_ref::<Float64Array>().unwrap();
+
+        for i in 0..arr.len() {
+            let value = arr.value(i);
+
+            if value == 0_f64 && values.is_null(i) {
+                continue;
+            }
+            let new_count = self.count + 1;
+            let delta1 = value - self.mean;
+            let new_mean = delta1 / new_count as f64 + self.mean;
+            let delta2 = value - new_mean;
+            let new_m2 = self.m2 + delta1 * delta2;
+
+            self.count += 1;
+            self.mean = new_mean;
+            self.m2 = new_m2;
+        }
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        let counts = states[0].as_any().downcast_ref::<UInt64Array>().unwrap();
+        let means = states[1].as_any().downcast_ref::<Float64Array>().unwrap();
+        let m2s = states[2].as_any().downcast_ref::<Float64Array>().unwrap();
+
+        for i in 0..counts.len() {
+            let c = counts.value(i);
+            if c == 0_u64 {
+                continue;
+            }
+            let new_count = self.count + c;
+            let new_mean = self.mean * self.count as f64 / new_count as f64
+                + means.value(i) * c as f64 / new_count as f64;
+            let delta = self.mean - means.value(i);
+            let new_m2 = self.m2
+                + m2s.value(i)
+                + delta * delta * self.count as f64 * c as f64 / new_count as f64;
+
+            self.count = new_count;
+            self.mean = new_mean;
+            self.m2 = new_m2;
+        }
+        Ok(())
+    }
+
     fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
         let values = &values[0];
         let is_empty = values.is_null();
+        let mean = ScalarValue::from(self.mean);
+        let m2 = ScalarValue::from(self.m2);
 
         if !is_empty {
             let new_count = self.count + 1;
-            let delta1 = ScalarValue::add(values, &self.mean.arithmetic_negate())?;
+            let delta1 = ScalarValue::add(values, &mean.arithmetic_negate())?;
             let new_mean = ScalarValue::add(
                 &ScalarValue::div(&delta1, &ScalarValue::from(new_count as f64))?,
-                &self.mean,
+                &mean,
             )?;
             let delta2 = ScalarValue::add(values, &new_mean.arithmetic_negate())?;
             let tmp = ScalarValue::mul(&delta1, &delta2)?;
 
-            let new_m2 = ScalarValue::add(&self.m2, &tmp)?;
+            let new_m2 = ScalarValue::add(&m2, &tmp)?;
             self.count += 1;
-            self.mean = new_mean;
-            self.m2 = new_m2;
+
+            if let ScalarValue::Float64(Some(c)) = new_mean {
+                self.mean = c;
+            } else {
+                unreachable!()
+            };
+            if let ScalarValue::Float64(Some(m)) = new_m2 {
+                self.m2 = m;
+            } else {
+                unreachable!()
+            };
         }
 
         Ok(())
     }
 
     fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
-        let count = &states[0];
-        let mean = &states[1];
-        let m2 = &states[2];
+        let count;
+        let mean;
+        let m2;
         let mut new_count: u64 = self.count;
 
-        // counts are summed
-        if let ScalarValue::UInt64(Some(c)) = count {
-            if *c == 0_u64 {
-                return Ok(());
-            }
+        if let ScalarValue::UInt64(Some(c)) = states[0] {
+            count = c;
+        } else {
+            unreachable!()
+        };
 
-            if self.count == 0 {
-                self.count = *c;
-                self.mean = mean.clone();
-                self.m2 = m2.clone();
-                return Ok(());
-            }
-            new_count += c
+        if count == 0_u64 {
+            return Ok(());
+        }
+
+        if let ScalarValue::Float64(Some(m)) = states[1] {
+            mean = m;
         } else {
             unreachable!()
         };
+        if let ScalarValue::Float64(Some(n)) = states[2] {
+            m2 = n;
+        } else {
+            unreachable!()
+        };
+
+        if self.count == 0 {
+            self.count = count;
+            self.mean = mean;
+            self.m2 = m2;
+            return Ok(());
+        }
 
-        let new_mean = ScalarValue::div(
-            &ScalarValue::add(&self.mean, mean)?,
-            &ScalarValue::from(2_f64),
+        new_count += count;
+
+        let mean1 = ScalarValue::from(self.mean);
+        let mean2 = ScalarValue::from(mean);
+
+        let new_mean = ScalarValue::add(
+            &ScalarValue::div(
+                &ScalarValue::mul(&mean1, &ScalarValue::from(self.count))?,
+                &ScalarValue::from(new_count as f64),
+            )?,
+            &ScalarValue::div(
+                &ScalarValue::mul(&mean2, &ScalarValue::from(count))?,
+                &ScalarValue::from(new_count as f64),
+            )?,
         )?;
-        let delta = ScalarValue::add(&mean.arithmetic_negate(), &self.mean)?;
+
+        let delta = ScalarValue::add(&mean2.arithmetic_negate(), &mean1)?;
         let delta_sqrt = ScalarValue::mul(&delta, &delta)?;
         let new_m2 = ScalarValue::add(
             &ScalarValue::add(
                 &ScalarValue::mul(
                     &delta_sqrt,
                     &ScalarValue::div(
-                        &ScalarValue::mul(&ScalarValue::from(self.count), count)?,
+                        &ScalarValue::mul(
+                            &ScalarValue::from(self.count),
+                            &ScalarValue::from(count),
+                        )?,
                         &ScalarValue::from(new_count as f64),
                     )?,
                 )?,
-                &self.m2,
+                &ScalarValue::from(self.m2),
             )?,
-            m2,
+            &ScalarValue::from(m2),
         )?;
 
         self.count = new_count;
-        self.mean = new_mean;
-        self.m2 = new_m2;
+        if let ScalarValue::Float64(Some(c)) = new_mean {
+            self.mean = c;
+        } else {
+            unreachable!()
+        };
+        if let ScalarValue::Float64(Some(m)) = new_m2 {
+            self.m2 = m;
+        } else {
+            unreachable!()
+        };
 
         Ok(())
     }
@@ -339,17 +438,10 @@ impl Accumulator for VarianceAccumulator {
             ));
         }
 
-        match self.m2 {
-            ScalarValue::Float64(e) => {
-                if self.count == 0 {
-                    Ok(ScalarValue::Float64(None))
-                } else {
-                    Ok(ScalarValue::Float64(e.map(|f| f / count as f64)))
-                }
-            }
-            _ => Err(DataFusionError::Internal(
-                "M2 should be f64 for variance".to_string(),
-            )),
+        if self.count == 0 {
+            Ok(ScalarValue::Float64(None))
+        } else {
+            Ok(ScalarValue::Float64(Some(self.m2 / count as f64)))
         }
     }
 }

Reply via email to