This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 06d147a Add batch operations to stddev (#1547)
06d147a is described below
commit 06d147aace7a923752cef2fcf98f4732de31a6a6
Author: Lin Ma <[email protected]>
AuthorDate: Tue Jan 11 13:37:48 2022 -0800
Add batch operations to stddev (#1547)
* Initial implementation of variance
* get simple f64 type tests working
* add math functions to ScalarValue, some tests
* add to expressions and tests
* add more tests
* add test for ScalarValue add
* add tests for scalar arithmetic
* add test, finish variance
* fix warnings
* add more sql tests
* add stddev and tests
* add the hooks and expression
* add more tests
* fix lint and clippy
* address comments and fix test errors
* address comments
* add population and sample for variance and stddev
* address more comments
* fmt
* add test for less than 2 values
* fix inconsistency in the merge logic
* fix lint and clippy
* use batch operations
* remove unused code
* lint and clippy
* fix s typo
* clippy fix
* fix lint
---
datafusion/src/physical_plan/expressions/stddev.rs | 15 +-
.../src/physical_plan/expressions/variance.rs | 198 +++++++++++++++------
2 files changed, 156 insertions(+), 57 deletions(-)
diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs
index d6e28f1..2c85b90 100644
--- a/datafusion/src/physical_plan/expressions/stddev.rs
+++ b/datafusion/src/physical_plan/expressions/stddev.rs
@@ -25,8 +25,7 @@ use crate::physical_plan::{
expressions::variance::VarianceAccumulator, Accumulator, AggregateExpr,
PhysicalExpr,
};
use crate::scalar::ScalarValue;
-use arrow::datatypes::DataType;
-use arrow::datatypes::Field;
+use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
use super::{format_state_name, StatsType};
@@ -216,8 +215,8 @@ impl Accumulator for StddevAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.variance.get_count()),
- self.variance.get_mean(),
- self.variance.get_m2(),
+ ScalarValue::from(self.variance.get_mean()),
+ ScalarValue::from(self.variance.get_m2()),
])
}
@@ -229,6 +228,14 @@ impl Accumulator for StddevAccumulator {
self.variance.merge(states)
}
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ self.variance.update_batch(values)
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ self.variance.merge_batch(states)
+ }
+
fn evaluate(&self) -> Result<ScalarValue> {
let variance = self.variance.evaluate()?;
match variance {
diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs
index 3f592b0..7516440 100644
--- a/datafusion/src/physical_plan/expressions/variance.rs
+++ b/datafusion/src/physical_plan/expressions/variance.rs
@@ -23,8 +23,13 @@ use std::sync::Arc;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
-use arrow::datatypes::DataType;
-use arrow::datatypes::Field;
+use arrow::array::Float64Array;
+use arrow::{
+ array::{ArrayRef, UInt64Array},
+ compute::cast,
+ datatypes::DataType,
+ datatypes::Field,
+};
use super::{format_state_name, StatsType};
@@ -209,8 +214,8 @@ impl AggregateExpr for VariancePop {
#[derive(Debug)]
pub struct VarianceAccumulator {
- m2: ScalarValue,
- mean: ScalarValue,
+ m2: f64,
+ mean: f64,
count: u64,
stats_type: StatsType,
}
@@ -219,9 +224,9 @@ impl VarianceAccumulator {
/// Creates a new `VarianceAccumulator`
pub fn try_new(s_type: StatsType) -> Result<Self> {
Ok(Self {
- m2: ScalarValue::from(0 as f64),
- mean: ScalarValue::from(0 as f64),
- count: 0,
+ m2: 0_f64,
+ mean: 0_f64,
+ count: 0_u64,
stats_type: s_type,
})
}
@@ -230,12 +235,12 @@ impl VarianceAccumulator {
self.count
}
- pub fn get_mean(&self) -> ScalarValue {
- self.mean.clone()
+ pub fn get_mean(&self) -> f64 {
+ self.mean
}
- pub fn get_m2(&self) -> ScalarValue {
- self.m2.clone()
+ pub fn get_m2(&self) -> f64 {
+ self.m2
}
}
@@ -243,80 +248,174 @@ impl Accumulator for VarianceAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
- self.mean.clone(),
- self.m2.clone(),
+ ScalarValue::from(self.mean),
+ ScalarValue::from(self.m2),
])
}
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let values = &cast(&values[0], &DataType::Float64)?;
+ let arr = values.as_any().downcast_ref::<Float64Array>().unwrap();
+
+ for i in 0..arr.len() {
+ // Skip nulls by validity alone: the value stored in a null slot is unspecified.
+ if values.is_null(i) {
+ continue;
+ }
+ let value = arr.value(i);
+ let new_count = self.count + 1;
+ let delta1 = value - self.mean;
+ let new_mean = delta1 / new_count as f64 + self.mean;
+ let delta2 = value - new_mean;
+ let new_m2 = self.m2 + delta1 * delta2;
+
+ self.count += 1;
+ self.mean = new_mean;
+ self.m2 = new_m2;
+ }
+
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ let counts = states[0].as_any().downcast_ref::<UInt64Array>().unwrap();
+ let means = states[1].as_any().downcast_ref::<Float64Array>().unwrap();
+ let m2s = states[2].as_any().downcast_ref::<Float64Array>().unwrap();
+ // Fold each partial (count, mean, m2) state in with the pairwise (Chan et al.) combination.
+ for i in 0..counts.len() {
+ let c = counts.value(i);
+ if c == 0_u64 {
+ continue;
+ }
+ let new_count = self.count + c;
+ let new_mean = self.mean * self.count as f64 / new_count as f64
+ + means.value(i) * c as f64 / new_count as f64;
+ let delta = self.mean - means.value(i);
+ let new_m2 = self.m2
+ + m2s.value(i)
+ + delta * delta * self.count as f64 * c as f64 / new_count as f64;
+
+ self.count = new_count;
+ self.mean = new_mean;
+ self.m2 = new_m2;
+ }
+ Ok(())
+ }
+
fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
let values = &values[0];
let is_empty = values.is_null();
+ let mean = ScalarValue::from(self.mean);
+ let m2 = ScalarValue::from(self.m2);
if !is_empty {
let new_count = self.count + 1;
- let delta1 = ScalarValue::add(values, &self.mean.arithmetic_negate())?;
+ let delta1 = ScalarValue::add(values, &mean.arithmetic_negate())?;
let new_mean = ScalarValue::add(
&ScalarValue::div(&delta1, &ScalarValue::from(new_count as f64))?,
- &self.mean,
+ &mean,
)?;
let delta2 = ScalarValue::add(values, &new_mean.arithmetic_negate())?;
let tmp = ScalarValue::mul(&delta1, &delta2)?;
- let new_m2 = ScalarValue::add(&self.m2, &tmp)?;
+ let new_m2 = ScalarValue::add(&m2, &tmp)?;
self.count += 1;
- self.mean = new_mean;
- self.m2 = new_m2;
+
+ if let ScalarValue::Float64(Some(c)) = new_mean {
+ self.mean = c;
+ } else {
+ unreachable!()
+ };
+ if let ScalarValue::Float64(Some(m)) = new_m2 {
+ self.m2 = m;
+ } else {
+ unreachable!()
+ };
}
Ok(())
}
fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
- let count = &states[0];
- let mean = &states[1];
- let m2 = &states[2];
+ let count;
+ let mean;
+ let m2;
let mut new_count: u64 = self.count;
- // counts are summed
- if let ScalarValue::UInt64(Some(c)) = count {
- if *c == 0_u64 {
- return Ok(());
- }
+ if let ScalarValue::UInt64(Some(c)) = states[0] {
+ count = c;
+ } else {
+ unreachable!()
+ };
- if self.count == 0 {
- self.count = *c;
- self.mean = mean.clone();
- self.m2 = m2.clone();
- return Ok(());
- }
- new_count += c
+ if count == 0_u64 {
+ return Ok(());
+ }
+
+ if let ScalarValue::Float64(Some(m)) = states[1] {
+ mean = m;
} else {
unreachable!()
};
+ if let ScalarValue::Float64(Some(n)) = states[2] {
+ m2 = n;
+ } else {
+ unreachable!()
+ };
+
+ if self.count == 0 {
+ self.count = count;
+ self.mean = mean;
+ self.m2 = m2;
+ return Ok(());
+ }
- let new_mean = ScalarValue::div(
- &ScalarValue::add(&self.mean, mean)?,
- &ScalarValue::from(2_f64),
+ new_count += count;
+
+ let mean1 = ScalarValue::from(self.mean);
+ let mean2 = ScalarValue::from(mean);
+
+ let new_mean = ScalarValue::add(
+ &ScalarValue::div(
+ &ScalarValue::mul(&mean1, &ScalarValue::from(self.count))?,
+ &ScalarValue::from(new_count as f64),
+ )?,
+ &ScalarValue::div(
+ &ScalarValue::mul(&mean2, &ScalarValue::from(count))?,
+ &ScalarValue::from(new_count as f64),
+ )?,
)?;
- let delta = ScalarValue::add(&mean.arithmetic_negate(), &self.mean)?;
+
+ let delta = ScalarValue::add(&mean2.arithmetic_negate(), &mean1)?;
let delta_sqrt = ScalarValue::mul(&delta, &delta)?;
let new_m2 = ScalarValue::add(
&ScalarValue::add(
&ScalarValue::mul(
&delta_sqrt,
&ScalarValue::div(
- &ScalarValue::mul(&ScalarValue::from(self.count), count)?,
+ &ScalarValue::mul(
+ &ScalarValue::from(self.count),
+ &ScalarValue::from(count),
+ )?,
&ScalarValue::from(new_count as f64),
)?,
)?,
- &self.m2,
+ &ScalarValue::from(self.m2),
)?,
- m2,
+ &ScalarValue::from(m2),
)?;
self.count = new_count;
- self.mean = new_mean;
- self.m2 = new_m2;
+ if let ScalarValue::Float64(Some(c)) = new_mean {
+ self.mean = c;
+ } else {
+ unreachable!()
+ };
+ if let ScalarValue::Float64(Some(m)) = new_m2 {
+ self.m2 = m;
+ } else {
+ unreachable!()
+ };
Ok(())
}
@@ -339,17 +438,10 @@ impl Accumulator for VarianceAccumulator {
));
}
- match self.m2 {
- ScalarValue::Float64(e) => {
- if self.count == 0 {
- Ok(ScalarValue::Float64(None))
- } else {
- Ok(ScalarValue::Float64(e.map(|f| f / count as f64)))
- }
- }
- _ => Err(DataFusionError::Internal(
- "M2 should be f64 for variance".to_string(),
- )),
+ if self.count == 0 {
+ Ok(ScalarValue::Float64(None))
+ } else {
+ Ok(ScalarValue::Float64(Some(self.m2 / count as f64)))
}
}
}