This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a0fccbf886 Move `Covariance` (Sample) `covar` / `covar_samp` to be a User Defined Aggregate Function (#10372)
a0fccbf886 is described below
commit a0fccbf886346fde5dfbda136149ec98bbd6e952
Author: Jay Zhan <[email protected]>
AuthorDate: Mon May 6 17:46:06 2024 +0800
Move `Covariance` (Sample) `covar` / `covar_samp` to be a User Defined Aggregate Function (#10372)
* introduce CovarianceSample
Signed-off-by: jayzhan211 <[email protected]>
* rewrite macro
Signed-off-by: jayzhan211 <[email protected]>
* rm old statstype
Signed-off-by: jayzhan211 <[email protected]>
* register
Signed-off-by: jayzhan211 <[email protected]>
* state field
Signed-off-by: jayzhan211 <[email protected]>
* rm builtin
Signed-off-by: jayzhan211 <[email protected]>
* address comments
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/core/src/physical_planner.rs | 1 +
datafusion/expr/src/aggregate_function.rs | 11 +-
datafusion/expr/src/type_coercion/aggregates.rs | 2 +-
datafusion/functions-aggregate/src/covariance.rs | 318 +++++++++++++++++++++
datafusion/functions-aggregate/src/first_last.rs | 4 +-
datafusion/functions-aggregate/src/lib.rs | 7 +-
datafusion/functions-aggregate/src/macros.rs | 68 +++--
.../physical-expr-common/src/aggregate/mod.rs | 1 +
.../src/aggregate/stats.rs | 1 +
datafusion/physical-expr/src/aggregate/build_in.rs | 154 +---------
.../physical-expr/src/aggregate/covariance.rs | 174 -----------
datafusion/physical-expr/src/aggregate/stats.rs | 9 +-
datafusion/physical-expr/src/expressions/mod.rs | 2 +-
datafusion/proto/proto/datafusion.proto | 2 +-
datafusion/proto/src/generated/pbjson.rs | 3 -
datafusion/proto/src/generated/prost.rs | 4 +-
datafusion/proto/src/logical_plan/from_proto.rs | 1 -
datafusion/proto/src/logical_plan/to_proto.rs | 4 -
datafusion/proto/src/physical_plan/to_proto.rs | 14 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 2 +
datafusion/sqllogictest/test_files/functions.slt | 2 +-
21 files changed, 393 insertions(+), 391 deletions(-)
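In user-facing terms, `covar_samp` and its `covar` alias are now resolved through the `AggregateUDF` machinery instead of the removed built-in `AggregateFunction::Covariance` variant. A minimal sketch of the new fluent-API call, mirroring the round-trip test added in roundtrip_logical_plan.rs below (the wrapper function here is illustrative only):

    use datafusion::functions_aggregate::covariance::covar_samp;
    use datafusion::prelude::*;

    // Builds an Expr equivalent to SQL `covar_samp(1.5, 2.2)`:
    // distinct = false, no FILTER clause, no ORDER BY, default null treatment.
    fn sample_covariance_expr() -> Expr {
        covar_samp(lit(1.5), lit(2.2), false, None, None, None)
    }

SQL queries are unaffected, since "covar" is kept as an alias of the new `covar_samp` UDAF (see the `aliases` vector in covariance.rs below).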
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 391ded84ea..dfcda553af 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -1901,6 +1901,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
let ignore_nulls = null_treatment
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
== NullTreatment::IgnoreNulls;
+
let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
let physical_sort_exprs = match order_by {
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index 3dc9c3a01c..af8a682eff 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -63,8 +63,6 @@ pub enum AggregateFunction {
Stddev,
/// Standard Deviation (Population)
StddevPop,
- /// Covariance (Sample)
- Covariance,
/// Covariance (Population)
CovariancePop,
/// Correlation
@@ -128,7 +126,6 @@ impl AggregateFunction {
VariancePop => "VAR_POP",
Stddev => "STDDEV",
StddevPop => "STDDEV_POP",
- Covariance => "COVAR",
CovariancePop => "COVAR_POP",
Correlation => "CORR",
RegrSlope => "REGR_SLOPE",
@@ -184,9 +181,7 @@ impl FromStr for AggregateFunction {
"string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
- "covar" => AggregateFunction::Covariance,
"covar_pop" => AggregateFunction::CovariancePop,
- "covar_samp" => AggregateFunction::Covariance,
"stddev" => AggregateFunction::Stddev,
"stddev_pop" => AggregateFunction::StddevPop,
"stddev_samp" => AggregateFunction::Stddev,
@@ -260,9 +255,6 @@ impl AggregateFunction {
AggregateFunction::VariancePop => {
variance_return_type(&coerced_data_types[0])
}
- AggregateFunction::Covariance => {
- covariance_return_type(&coerced_data_types[0])
- }
AggregateFunction::CovariancePop => {
covariance_return_type(&coerced_data_types[0])
}
@@ -357,8 +349,7 @@ impl AggregateFunction {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
- AggregateFunction::Covariance
- | AggregateFunction::CovariancePop
+ AggregateFunction::CovariancePop
| AggregateFunction::Correlation
| AggregateFunction::RegrSlope
| AggregateFunction::RegrIntercept
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index 5ffdc8f947..39726d7d0e 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -183,7 +183,7 @@ pub fn coerce_types(
}
Ok(vec![Float64, Float64])
}
- AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
+ AggregateFunction::CovariancePop => {
if !is_covariance_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs
new file mode 100644
index 0000000000..130b193996
--- /dev/null
+++ b/datafusion/functions-aggregate/src/covariance.rs
@@ -0,0 +1,318 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`CovarianceSample`]: covariance sample aggregations.
+
+use std::fmt::Debug;
+
+use arrow::{
+ array::{ArrayRef, Float64Array, UInt64Array},
+ compute::kernels::cast,
+ datatypes::{DataType, Field},
+};
+
+use datafusion_common::{
+ downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result,
+ ScalarValue,
+};
+use datafusion_expr::{
+ function::AccumulatorArgs, type_coercion::aggregates::NUMERICS,
+ utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility,
+};
+use datafusion_physical_expr_common::aggregate::stats::StatsType;
+
+make_udaf_expr_and_func!(
+ CovarianceSample,
+ covar_samp,
+ y x,
+ "Computes the sample covariance.",
+ covar_samp_udaf
+);
+
+pub struct CovarianceSample {
+ signature: Signature,
+ aliases: Vec<String>,
+}
+
+impl Debug for CovarianceSample {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ f.debug_struct("CovarianceSample")
+ .field("name", &self.name())
+ .field("signature", &self.signature)
+ .finish()
+ }
+}
+
+impl Default for CovarianceSample {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl CovarianceSample {
+ pub fn new() -> Self {
+ Self {
+ aliases: vec![String::from("covar")],
+ signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
+ }
+ }
+}
+
+impl AggregateUDFImpl for CovarianceSample {
+ fn as_any(&self) -> &dyn std::any::Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "covar_samp"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ if !arg_types[0].is_numeric() {
+ return plan_err!("Covariance requires numeric input types");
+ }
+
+ Ok(DataType::Float64)
+ }
+
+ fn state_fields(
+ &self,
+ name: &str,
+ _value_type: DataType,
+ _ordering_fields: Vec<Field>,
+ ) -> Result<Vec<Field>> {
+ Ok(vec![
+ Field::new(format_state_name(name, "count"), DataType::UInt64, true),
+ Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
+ Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
+ Field::new(
+ format_state_name(name, "algo_const"),
+ DataType::Float64,
+ true,
+ ),
+ ])
+ }
+
+ fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
+ }
+
+ fn aliases(&self) -> &[String] {
+ &self.aliases
+ }
+}
+
+/// An accumulator to compute covariance
+/// The algorithm used is an online implementation and numerically stable. It is derived from the following paper
+/// for calculating variance:
+/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
+/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
+///
+/// The algorithm has been analyzed here:
+/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
+/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
+///
+/// Though it is not covered in the original paper but is based on the same idea, as a result the algorithm is online,
+/// parallelizable and numerically stable.
+
+#[derive(Debug)]
+pub struct CovarianceAccumulator {
+ algo_const: f64,
+ mean1: f64,
+ mean2: f64,
+ count: u64,
+ stats_type: StatsType,
+}
+
+impl CovarianceAccumulator {
+ /// Creates a new `CovarianceAccumulator`
+ pub fn try_new(s_type: StatsType) -> Result<Self> {
+ Ok(Self {
+ algo_const: 0_f64,
+ mean1: 0_f64,
+ mean2: 0_f64,
+ count: 0_u64,
+ stats_type: s_type,
+ })
+ }
+
+ pub fn get_count(&self) -> u64 {
+ self.count
+ }
+
+ pub fn get_mean1(&self) -> f64 {
+ self.mean1
+ }
+
+ pub fn get_mean2(&self) -> f64 {
+ self.mean2
+ }
+
+ pub fn get_algo_const(&self) -> f64 {
+ self.algo_const
+ }
+}
+
+impl Accumulator for CovarianceAccumulator {
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![
+ ScalarValue::from(self.count),
+ ScalarValue::from(self.mean1),
+ ScalarValue::from(self.mean2),
+ ScalarValue::from(self.algo_const),
+ ])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let values1 = &cast(&values[0], &DataType::Float64)?;
+ let values2 = &cast(&values[1], &DataType::Float64)?;
+
+ let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
+ let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
+
+ for i in 0..values1.len() {
+ let value1 = if values1.is_valid(i) {
+ arr1.next()
+ } else {
+ None
+ };
+ let value2 = if values2.is_valid(i) {
+ arr2.next()
+ } else {
+ None
+ };
+
+ if value1.is_none() || value2.is_none() {
+ continue;
+ }
+
+ let value1 = unwrap_or_internal_err!(value1);
+ let value2 = unwrap_or_internal_err!(value2);
+ let new_count = self.count + 1;
+ let delta1 = value1 - self.mean1;
+ let new_mean1 = delta1 / new_count as f64 + self.mean1;
+ let delta2 = value2 - self.mean2;
+ let new_mean2 = delta2 / new_count as f64 + self.mean2;
+ let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
+
+ self.count += 1;
+ self.mean1 = new_mean1;
+ self.mean2 = new_mean2;
+ self.algo_const = new_c;
+ }
+
+ Ok(())
+ }
+
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let values1 = &cast(&values[0], &DataType::Float64)?;
+ let values2 = &cast(&values[1], &DataType::Float64)?;
+ let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
+ let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
+
+ for i in 0..values1.len() {
+ let value1 = if values1.is_valid(i) {
+ arr1.next()
+ } else {
+ None
+ };
+ let value2 = if values2.is_valid(i) {
+ arr2.next()
+ } else {
+ None
+ };
+
+ if value1.is_none() || value2.is_none() {
+ continue;
+ }
+
+ let value1 = unwrap_or_internal_err!(value1);
+ let value2 = unwrap_or_internal_err!(value2);
+
+ let new_count = self.count - 1;
+ let delta1 = self.mean1 - value1;
+ let new_mean1 = delta1 / new_count as f64 + self.mean1;
+ let delta2 = self.mean2 - value2;
+ let new_mean2 = delta2 / new_count as f64 + self.mean2;
+ let new_c = self.algo_const - delta1 * (new_mean2 - value2);
+
+ self.count -= 1;
+ self.mean1 = new_mean1;
+ self.mean2 = new_mean2;
+ self.algo_const = new_c;
+ }
+
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ let counts = downcast_value!(states[0], UInt64Array);
+ let means1 = downcast_value!(states[1], Float64Array);
+ let means2 = downcast_value!(states[2], Float64Array);
+ let cs = downcast_value!(states[3], Float64Array);
+
+ for i in 0..counts.len() {
+ let c = counts.value(i);
+ if c == 0_u64 {
+ continue;
+ }
+ let new_count = self.count + c;
+ let new_mean1 = self.mean1 * self.count as f64 / new_count as f64
+ + means1.value(i) * c as f64 / new_count as f64;
+ let new_mean2 = self.mean2 * self.count as f64 / new_count as f64
+ + means2.value(i) * c as f64 / new_count as f64;
+ let delta1 = self.mean1 - means1.value(i);
+ let delta2 = self.mean2 - means2.value(i);
+ let new_c = self.algo_const
+ + cs.value(i)
+ delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64;
+
+ self.count = new_count;
+ self.mean1 = new_mean1;
+ self.mean2 = new_mean2;
+ self.algo_const = new_c;
+ }
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> Result<ScalarValue> {
+ let count = match self.stats_type {
+ StatsType::Population => self.count,
+ StatsType::Sample => {
+ if self.count > 0 {
+ self.count - 1
+ } else {
+ self.count
+ }
+ }
+ };
+
+ if count == 0 {
+ Ok(ScalarValue::Float64(None))
+ } else {
+ Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
+ }
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+}
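For orientation, the update in `CovarianceAccumulator::update_batch` above is the Welford-style recurrence cited in the doc comment. A compact, self-contained illustration of the same recurrence (not part of the patch; the helper name is made up):

    /// Illustrative only: online sample covariance using the same recurrence
    /// as CovarianceAccumulator::update_batch and evaluate() above.
    fn sample_covariance(pairs: &[(f64, f64)]) -> Option<f64> {
        let (mut count, mut mean_x, mut mean_y, mut c) = (0u64, 0.0_f64, 0.0_f64, 0.0_f64);
        for &(x, y) in pairs {
            count += 1;
            let delta_x = x - mean_x;         // delta against the *old* mean
            mean_x += delta_x / count as f64;
            let delta_y = y - mean_y;
            mean_y += delta_y / count as f64;
            c += delta_x * (y - mean_y);      // uses the *updated* mean_y
        }
        // StatsType::Sample divides by count - 1 and yields NULL for fewer
        // than two rows, matching evaluate() above.
        if count < 2 { None } else { Some(c / (count - 1) as f64) }
    }

For example, `sample_covariance(&[(1.0, 4.0), (2.0, 5.0), (3.0, 6.0)])` returns `Some(1.0)`, the value the pre-existing `covariance_f64_2` test (removed further down) asserted for COVAR.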
diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs
index 8dc4cee87a..e3b685e903 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -39,12 +39,12 @@ use datafusion_physical_expr_common::expressions;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use datafusion_physical_expr_common::utils::reverse_order_bys;
-use sqlparser::ast::NullTreatment;
+
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;
-make_udaf_function!(
+make_udaf_expr_and_func!(
FirstValue,
first_value,
"Returns the first value in a group of values.",
diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs
index 8016b76889..d4e4d3a5f3 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -55,6 +55,7 @@
#[macro_use]
pub mod macros;
+pub mod covariance;
pub mod first_last;
use datafusion_common::Result;
@@ -65,12 +66,16 @@ use std::sync::Arc;
/// Fluent-style API for creating `Expr`s
pub mod expr_fn {
+ pub use super::covariance::covar_samp;
pub use super::first_last::first_value;
}
/// Registers all enabled packages with a [`FunctionRegistry`]
pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
- let functions: Vec<Arc<AggregateUDF>> = vec![first_last::first_value_udaf()];
+ let functions: Vec<Arc<AggregateUDF>> = vec![
+ first_last::first_value_udaf(),
+ covariance::covar_samp_udaf(),
+ ];
functions.into_iter().try_for_each(|udf| {
let existing_udaf = registry.register_udaf(udf)?;
diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs
index 04f9fecb8b..27fc623a18 100644
--- a/datafusion/functions-aggregate/src/macros.rs
+++ b/datafusion/functions-aggregate/src/macros.rs
@@ -15,33 +15,59 @@
// specific language governing permissions and limitations
// under the License.
-macro_rules! make_udaf_function {
+macro_rules! make_udaf_expr_and_func {
+ ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
+ // "fluent expr_fn" style function
+ #[doc = $DOC]
+ pub fn $EXPR_FN(
+ $($arg: datafusion_expr::Expr,)*
+ distinct: bool,
+ filter: Option<Box<datafusion_expr::Expr>>,
+ order_by: Option<Vec<datafusion_expr::Expr>>,
+ null_treatment: Option<sqlparser::ast::NullTreatment>
+ ) -> datafusion_expr::Expr {
+ datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
+ $AGGREGATE_UDF_FN(),
+ vec![$($arg),*],
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
+ ))
+ }
+ create_func!($UDAF, $AGGREGATE_UDF_FN);
+ };
($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
- paste::paste! {
- // "fluent expr_fn" style function
- #[doc = $DOC]
- pub fn $EXPR_FN(
- args: Vec<Expr>,
- distinct: bool,
- filter: Option<Box<Expr>>,
- order_by: Option<Vec<Expr>>,
- null_treatment: Option<NullTreatment>
- ) -> Expr {
- Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
- $AGGREGATE_UDF_FN(),
- args,
- distinct,
- filter,
- order_by,
- null_treatment,
- ))
- }
+ // "fluent expr_fn" style function
+ #[doc = $DOC]
+ pub fn $EXPR_FN(
+ args: Vec<datafusion_expr::Expr>,
+ distinct: bool,
+ filter: Option<Box<datafusion_expr::Expr>>,
+ order_by: Option<Vec<datafusion_expr::Expr>>,
+ null_treatment: Option<sqlparser::ast::NullTreatment>
+ ) -> datafusion_expr::Expr {
+ datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
+ $AGGREGATE_UDF_FN(),
+ args,
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
+ ))
+ }
+ create_func!($UDAF, $AGGREGATE_UDF_FN);
+ };
+}
+macro_rules! create_func {
+ ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
+ paste::paste! {
/// Singleton instance of [$UDAF], ensures the UDAF is only created once
/// named STATIC_$(UDAF). For example `STATIC_FirstValue`
#[allow(non_upper_case_globals)]
static [< STATIC_ $UDAF >]: std::sync::OnceLock<std::sync::Arc<datafusion_expr::AggregateUDF>> =
- std::sync::OnceLock::new();
+ std::sync::OnceLock::new();
/// AggregateFunction that returns a [AggregateUDF] for [$UDAF]
///
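The new space-separated-argument arm of `make_udaf_expr_and_func!`, as invoked above in covariance.rs with `y x`, expands to roughly the following (simplified sketch, not the literal expansion):

    // Fluent expr_fn wrapper taking the two input expressions positionally.
    pub fn covar_samp(
        y: datafusion_expr::Expr,
        x: datafusion_expr::Expr,
        distinct: bool,
        filter: Option<Box<datafusion_expr::Expr>>,
        order_by: Option<Vec<datafusion_expr::Expr>>,
        null_treatment: Option<sqlparser::ast::NullTreatment>,
    ) -> datafusion_expr::Expr {
        datafusion_expr::Expr::AggregateFunction(
            datafusion_expr::expr::AggregateFunction::new_udf(
                covar_samp_udaf(),
                vec![y, x],
                distinct,
                filter,
                order_by,
                null_treatment,
            ),
        )
    }

plus, via `create_func!`, a `covar_samp_udaf()` accessor that lazily initializes a single `Arc<AggregateUDF>` for `CovarianceSample` in the `OnceLock` static shown here.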
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs
index 448af63417..d2e3414fbf 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+pub mod stats;
pub mod utils;
use arrow::datatypes::{DataType, Field, Schema};
diff --git a/datafusion/physical-expr/src/aggregate/stats.rs b/datafusion/physical-expr-common/src/aggregate/stats.rs
similarity index 95%
copy from datafusion/physical-expr/src/aggregate/stats.rs
copy to datafusion/physical-expr-common/src/aggregate/stats.rs
index 98baaccffe..6a11ebe36c 100644
--- a/datafusion/physical-expr/src/aggregate/stats.rs
+++ b/datafusion/physical-expr-common/src/aggregate/stats.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+/// TODO: Move this to functions-aggregate module
/// Enum used for differentiating population and sample for statistical functions
#[derive(Debug, Clone, Copy)]
pub enum StatsType {
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index 57ed35b0b7..36af875473 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -181,15 +181,6 @@ pub fn create_aggregate_expr(
(AggregateFunction::VariancePop, true) => {
return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
}
- (AggregateFunction::Covariance, false) => Arc::new(expressions::Covariance::new(
- input_phy_exprs[0].clone(),
- input_phy_exprs[1].clone(),
- name,
- data_type,
- )),
- (AggregateFunction::Covariance, true) => {
- return not_impl_err!("COVAR(DISTINCT) aggregations are not available");
- }
(AggregateFunction::CovariancePop, false) => {
Arc::new(expressions::CovariancePop::new(
input_phy_exprs[0].clone(),
@@ -428,8 +419,8 @@ mod tests {
use crate::expressions::{
try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg,
- BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Correlation, Count, Covariance,
- DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance,
+ BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount,
+ Max, Min, Stddev, Sum, Variance,
};
use super::*;
@@ -950,147 +941,6 @@ mod tests {
Ok(())
}
- #[test]
- fn test_covar_expr() -> Result<()> {
- let funcs = vec![AggregateFunction::Covariance];
- let data_types = vec![
- DataType::UInt32,
- DataType::UInt64,
- DataType::Int32,
- DataType::Int64,
- DataType::Float32,
- DataType::Float64,
- ];
- for fun in funcs {
- for data_type in &data_types {
- let input_schema = Schema::new(vec![
- Field::new("c1", data_type.clone(), true),
- Field::new("c2", data_type.clone(), true),
- ]);
- let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
- Arc::new(
- expressions::Column::new_with_schema("c1", &input_schema)
- .unwrap(),
- ),
- Arc::new(
- expressions::Column::new_with_schema("c2", &input_schema)
- .unwrap(),
- ),
- ];
- let result_agg_phy_exprs = create_physical_agg_expr_for_test(
- &fun,
- false,
- &input_phy_exprs[0..2],
- &input_schema,
- "c1",
- )?;
- if fun == AggregateFunction::Covariance {
- assert!(result_agg_phy_exprs.as_any().is::<Covariance>());
- assert_eq!("c1", result_agg_phy_exprs.name());
- assert_eq!(
- Field::new("c1", DataType::Float64, true),
- result_agg_phy_exprs.field().unwrap()
- )
- }
- }
- }
- Ok(())
- }
-
- #[test]
- fn test_covar_pop_expr() -> Result<()> {
- let funcs = vec![AggregateFunction::CovariancePop];
- let data_types = vec![
- DataType::UInt32,
- DataType::UInt64,
- DataType::Int32,
- DataType::Int64,
- DataType::Float32,
- DataType::Float64,
- ];
- for fun in funcs {
- for data_type in &data_types {
- let input_schema = Schema::new(vec![
- Field::new("c1", data_type.clone(), true),
- Field::new("c2", data_type.clone(), true),
- ]);
- let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
- Arc::new(
- expressions::Column::new_with_schema("c1", &input_schema)
- .unwrap(),
- ),
- Arc::new(
- expressions::Column::new_with_schema("c2", &input_schema)
- .unwrap(),
- ),
- ];
- let result_agg_phy_exprs = create_physical_agg_expr_for_test(
- &fun,
- false,
- &input_phy_exprs[0..2],
- &input_schema,
- "c1",
- )?;
- if fun == AggregateFunction::Covariance {
- assert!(result_agg_phy_exprs.as_any().is::<Covariance>());
- assert_eq!("c1", result_agg_phy_exprs.name());
- assert_eq!(
- Field::new("c1", DataType::Float64, true),
- result_agg_phy_exprs.field().unwrap()
- )
- }
- }
- }
- Ok(())
- }
-
- #[test]
- fn test_corr_expr() -> Result<()> {
- let funcs = vec![AggregateFunction::Correlation];
- let data_types = vec![
- DataType::UInt32,
- DataType::UInt64,
- DataType::Int32,
- DataType::Int64,
- DataType::Float32,
- DataType::Float64,
- ];
- for fun in funcs {
- for data_type in &data_types {
- let input_schema = Schema::new(vec![
- Field::new("c1", data_type.clone(), true),
- Field::new("c2", data_type.clone(), true),
- ]);
- let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
- Arc::new(
- expressions::Column::new_with_schema("c1", &input_schema)
- .unwrap(),
- ),
- Arc::new(
- expressions::Column::new_with_schema("c2", &input_schema)
- .unwrap(),
- ),
- ];
- let result_agg_phy_exprs = create_physical_agg_expr_for_test(
- &fun,
- false,
- &input_phy_exprs[0..2],
- &input_schema,
- "c1",
- )?;
- if fun == AggregateFunction::Covariance {
- assert!(result_agg_phy_exprs.as_any().is::<Correlation>());
- assert_eq!("c1", result_agg_phy_exprs.name());
- assert_eq!(
- Field::new("c1", DataType::Float64, true),
- result_agg_phy_exprs.field().unwrap()
- )
- }
- }
- }
- Ok(())
- }
-
#[test]
fn test_median_expr() -> Result<()> {
let funcs = vec![AggregateFunction::ApproxMedian];
diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs
index ba9bdbc8ae..272f1d8be2 100644
--- a/datafusion/physical-expr/src/aggregate/covariance.rs
+++ b/datafusion/physical-expr/src/aggregate/covariance.rs
@@ -36,14 +36,6 @@ use crate::aggregate::stats::StatsType;
use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::format_state_name;
-/// COVAR and COVAR_SAMP aggregate expression
-#[derive(Debug)]
-pub struct Covariance {
- name: String,
- expr1: Arc<dyn PhysicalExpr>,
- expr2: Arc<dyn PhysicalExpr>,
-}
-
/// COVAR_POP aggregate expression
#[derive(Debug)]
pub struct CovariancePop {
@@ -52,83 +44,6 @@ pub struct CovariancePop {
expr2: Arc<dyn PhysicalExpr>,
}
-impl Covariance {
- /// Create a new COVAR aggregate function
- pub fn new(
- expr1: Arc<dyn PhysicalExpr>,
- expr2: Arc<dyn PhysicalExpr>,
- name: impl Into<String>,
- data_type: DataType,
- ) -> Self {
- // the result of covariance just support FLOAT64 data type.
- assert!(matches!(data_type, DataType::Float64));
- Self {
- name: name.into(),
- expr1,
- expr2,
- }
- }
-}
-
-impl AggregateExpr for Covariance {
- /// Return a reference to Any that can be used for downcasting
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn field(&self) -> Result<Field> {
- Ok(Field::new(&self.name, DataType::Float64, true))
- }
-
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
- }
-
- fn state_fields(&self) -> Result<Vec<Field>> {
- Ok(vec![
- Field::new(
- format_state_name(&self.name, "count"),
- DataType::UInt64,
- true,
- ),
- Field::new(
- format_state_name(&self.name, "mean1"),
- DataType::Float64,
- true,
- ),
- Field::new(
- format_state_name(&self.name, "mean2"),
- DataType::Float64,
- true,
- ),
- Field::new(
- format_state_name(&self.name, "algo_const"),
- DataType::Float64,
- true,
- ),
- ])
- }
-
- fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- vec![self.expr1.clone(), self.expr2.clone()]
- }
-
- fn name(&self) -> &str {
- &self.name
- }
-}
-
-impl PartialEq<dyn Any> for Covariance {
- fn eq(&self, other: &dyn Any) -> bool {
- down_cast_any_ref(other)
- .downcast_ref::<Self>()
- .map(|x| {
- self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2)
- })
- .unwrap_or(false)
- }
-}
-
impl CovariancePop {
/// Create a new COVAR_POP aggregate function
pub fn new(
@@ -429,36 +344,6 @@ mod tests {
)
}
- #[test]
- fn covariance_f64_2() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64]));
- let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64]));
-
- generic_test_op2!(
- a,
- b,
- DataType::Float64,
- DataType::Float64,
- Covariance,
- ScalarValue::from(1_f64)
- )
- }
-
- #[test]
- fn covariance_f64_4() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64]));
- let b: ArrayRef = Arc::new(Float64Array::from(vec![4.1_f64, 5_f64, 6_f64]));
-
- generic_test_op2!(
- a,
- b,
- DataType::Float64,
- DataType::Float64,
- Covariance,
- ScalarValue::from(0.9033333333333335_f64)
- )
- }
-
#[test]
fn covariance_f64_5() -> Result<()> {
let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64]));
@@ -580,50 +465,6 @@ mod tests {
)
}
- #[test]
- fn covariance_i32_with_nulls_3() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![
- Some(1),
- None,
- Some(2),
- None,
- Some(3),
- None,
- ]));
- let b: ArrayRef = Arc::new(Int32Array::from(vec![
- Some(4),
- Some(9),
- Some(5),
- Some(8),
- Some(6),
- None,
- ]));
-
- generic_test_op2!(
- a,
- b,
- DataType::Int32,
- DataType::Int32,
- Covariance,
- ScalarValue::from(1_f64)
- )
- }
-
- #[test]
- fn covariance_i32_all_nulls() -> Result<()> {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
- let b: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
-
- generic_test_op2!(
- a,
- b,
- DataType::Int32,
- DataType::Int32,
- Covariance,
- ScalarValue::Float64(None)
- )
- }
-
#[test]
fn covariance_pop_i32_all_nulls() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
@@ -639,21 +480,6 @@ mod tests {
)
}
- #[test]
- fn covariance_1_input() -> Result<()> {
- let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64]));
- let b: ArrayRef = Arc::new(Float64Array::from(vec![2_f64]));
-
- generic_test_op2!(
- a,
- b,
- DataType::Float64,
- DataType::Float64,
- Covariance,
- ScalarValue::Float64(None)
- )
- }
-
#[test]
fn covariance_pop_1_input() -> Result<()> {
let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64]));
diff --git a/datafusion/physical-expr/src/aggregate/stats.rs b/datafusion/physical-expr/src/aggregate/stats.rs
index 98baaccffe..d9338f5a96 100644
--- a/datafusion/physical-expr/src/aggregate/stats.rs
+++ b/datafusion/physical-expr/src/aggregate/stats.rs
@@ -15,11 +15,4 @@
// specific language governing permissions and limitations
// under the License.
-/// Enum used for differentiating population and sample for statistical functions
-#[derive(Debug, Clone, Copy)]
-pub enum StatsType {
- /// Population
- Population,
- /// Sample
- Sample,
-}
+pub use datafusion_physical_expr_common::aggregate::stats::StatsType;
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index 55ebd9ed8c..0cd2ac2c9e 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -52,7 +52,7 @@ pub use crate::aggregate::build_in::create_aggregate_expr;
pub use crate::aggregate::correlation::Correlation;
pub use crate::aggregate::count::Count;
pub use crate::aggregate::count_distinct::DistinctCount;
-pub use crate::aggregate::covariance::{Covariance, CovariancePop};
+pub use crate::aggregate::covariance::CovariancePop;
pub use crate::aggregate::grouping::Grouping;
pub use crate::aggregate::median::Median;
pub use crate::aggregate::min_max::{Max, Min};
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 9e4ea8e712..c057ab8acd 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -548,7 +548,7 @@ enum AggregateFunction {
ARRAY_AGG = 6;
VARIANCE = 7;
VARIANCE_POP = 8;
- COVARIANCE = 9;
+ // COVARIANCE = 9;
COVARIANCE_POP = 10;
STDDEV = 11;
STDDEV_POP = 12;
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index b5779d25c6..994703c5fc 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -430,7 +430,6 @@ impl serde::Serialize for AggregateFunction {
Self::ArrayAgg => "ARRAY_AGG",
Self::Variance => "VARIANCE",
Self::VariancePop => "VARIANCE_POP",
- Self::Covariance => "COVARIANCE",
Self::CovariancePop => "COVARIANCE_POP",
Self::Stddev => "STDDEV",
Self::StddevPop => "STDDEV_POP",
@@ -478,7 +477,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"ARRAY_AGG",
"VARIANCE",
"VARIANCE_POP",
- "COVARIANCE",
"COVARIANCE_POP",
"STDDEV",
"STDDEV_POP",
@@ -555,7 +553,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg),
"VARIANCE" => Ok(AggregateFunction::Variance),
"VARIANCE_POP" => Ok(AggregateFunction::VariancePop),
- "COVARIANCE" => Ok(AggregateFunction::Covariance),
"COVARIANCE_POP" => Ok(AggregateFunction::CovariancePop),
"STDDEV" => Ok(AggregateFunction::Stddev),
"STDDEV_POP" => Ok(AggregateFunction::StddevPop),
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index c822ac1301..fc23a9ea05 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2834,7 +2834,7 @@ pub enum AggregateFunction {
ArrayAgg = 6,
Variance = 7,
VariancePop = 8,
- Covariance = 9,
+ /// COVARIANCE = 9;
CovariancePop = 10,
Stddev = 11,
StddevPop = 12,
@@ -2881,7 +2881,6 @@ impl AggregateFunction {
AggregateFunction::ArrayAgg => "ARRAY_AGG",
AggregateFunction::Variance => "VARIANCE",
AggregateFunction::VariancePop => "VARIANCE_POP",
- AggregateFunction::Covariance => "COVARIANCE",
AggregateFunction::CovariancePop => "COVARIANCE_POP",
AggregateFunction::Stddev => "STDDEV",
AggregateFunction::StddevPop => "STDDEV_POP",
@@ -2925,7 +2924,6 @@ impl AggregateFunction {
"ARRAY_AGG" => Some(Self::ArrayAgg),
"VARIANCE" => Some(Self::Variance),
"VARIANCE_POP" => Some(Self::VariancePop),
- "COVARIANCE" => Some(Self::Covariance),
"COVARIANCE_POP" => Some(Self::CovariancePop),
"STDDEV" => Some(Self::Stddev),
"STDDEV_POP" => Some(Self::StddevPop),
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 83b232da9d..35d4c6409b 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -428,7 +428,6 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg,
protobuf::AggregateFunction::Variance => Self::Variance,
protobuf::AggregateFunction::VariancePop => Self::VariancePop,
- protobuf::AggregateFunction::Covariance => Self::Covariance,
protobuf::AggregateFunction::CovariancePop => Self::CovariancePop,
protobuf::AggregateFunction::Stddev => Self::Stddev,
protobuf::AggregateFunction::StddevPop => Self::StddevPop,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index b2236847ac..dcec2a3b85 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -369,7 +369,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
AggregateFunction::ArrayAgg => Self::ArrayAgg,
AggregateFunction::Variance => Self::Variance,
AggregateFunction::VariancePop => Self::VariancePop,
- AggregateFunction::Covariance => Self::Covariance,
AggregateFunction::CovariancePop => Self::CovariancePop,
AggregateFunction::Stddev => Self::Stddev,
AggregateFunction::StddevPop => Self::StddevPop,
@@ -674,9 +673,6 @@ pub fn serialize_expr(
AggregateFunction::VariancePop => {
protobuf::AggregateFunction::VariancePop
}
- AggregateFunction::Covariance => {
- protobuf::AggregateFunction::Covariance
- }
AggregateFunction::CovariancePop => {
protobuf::AggregateFunction::CovariancePop
}
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index c7df6ebf58..a0a0ee7205 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -25,12 +25,12 @@ use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr};
use datafusion::physical_plan::expressions::{
ApproxDistinct, ApproxMedian, ApproxPercentileCont, ApproxPercentileContWithWeight,
ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr,
- CastExpr, Column, Correlation, Count, Covariance, CovariancePop, CumeDist,
- DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping,
- InListExpr, IsNotNullExpr, IsNullExpr, LastValue, Literal, Max, Median, Min,
- NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank,
- RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr,
- Variance, VariancePop, WindowShift,
+ CastExpr, Column, Correlation, Count, CovariancePop, CumeDist, DistinctArrayAgg,
+ DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, InListExpr,
+ IsNotNullExpr, IsNullExpr, LastValue, Literal, Max, Median, Min, NegativeExpr,
+ NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr,
+ RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, Variance,
+ VariancePop, WindowShift,
};
use datafusion::physical_plan::udaf::AggregateFunctionExpr;
use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr};
@@ -292,8 +292,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> {
protobuf::AggregateFunction::Variance
} else if aggr_expr.downcast_ref::<VariancePop>().is_some() {
protobuf::AggregateFunction::VariancePop
- } else if aggr_expr.downcast_ref::<Covariance>().is_some() {
- protobuf::AggregateFunction::Covariance
} else if aggr_expr.downcast_ref::<CovariancePop>().is_some() {
protobuf::AggregateFunction::CovariancePop
} else if aggr_expr.downcast_ref::<Stddev>().is_some() {
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 65985f8680..3800b672b5 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -30,6 +30,7 @@ use datafusion::datasource::provider::TableProviderFactory;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
+use datafusion::functions_aggregate::covariance::covar_samp;
use datafusion::functions_aggregate::expr_fn::first_value;
use datafusion::prelude::*;
use datafusion::test_util::{TestTableFactory, TestTableProvider};
@@ -614,6 +615,7 @@ async fn roundtrip_expr_api() -> Result<()> {
),
array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)),
first_value(vec![lit(1)], false, None, None, None),
+ covar_samp(lit(1.5), lit(2.2), false, None, None, None),
];
// ensure expressions created with the expr api can be round tripped
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index bc8f6a2687..d03b33d0c8 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -495,7 +495,7 @@ statement error Did you mean 'STDDEV'?
SELECT STDEV(v1) from test;
# Aggregate function
-statement error Did you mean 'COVAR'?
+statement error DataFusion error: Error during planning: Invalid function 'covaria'.\nDid you mean 'covar'?
SELECT COVARIA(1,1);
# Window function
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]