This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 70744d59d5 Convert variance sample to udaf (#10713)
70744d59d5 is described below
commit 70744d59d5c1caca6313b77a1baa4025c2558fdd
Author: Yue Yin <[email protected]>
AuthorDate: Wed Jun 5 01:18:42 2024 -0400
Convert variance sample to udaf (#10713)
* Without migrating tests
* Should fail VAR(DISTINCT) but doesn't
* Pass all other tests.
* Return error for var(distinct)
* Migrate tests
* Fix tests
* Lint
* Fix tests
* Fix use
---
datafusion/expr/src/aggregate_function.rs | 7 -
datafusion/expr/src/type_coercion/aggregates.rs | 2 +-
datafusion/functions-aggregate/src/lib.rs | 3 +
datafusion/functions-aggregate/src/variance.rs | 263 +++++++++++++++++++++
datafusion/physical-expr/src/aggregate/build_in.rs | 85 +------
datafusion/physical-expr/src/aggregate/variance.rs | 75 ------
datafusion/physical-expr/src/expressions/mod.rs | 2 +-
datafusion/proto/proto/datafusion.proto | 2 +-
datafusion/proto/src/generated/pbjson.rs | 3 -
datafusion/proto/src/generated/prost.rs | 4 +-
datafusion/proto/src/logical_plan/from_proto.rs | 1 -
datafusion/proto/src/logical_plan/to_proto.rs | 2 -
datafusion/proto/src/physical_plan/to_proto.rs | 4 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 6 +-
datafusion/sqllogictest/Cargo.toml | 10 +-
datafusion/sqllogictest/test_files/aggregate.slt | 48 ++++
.../sqllogictest/test_files/sort_merge_join.slt | 1 -
17 files changed, 338 insertions(+), 180 deletions(-)
diff --git a/datafusion/expr/src/aggregate_function.rs
b/datafusion/expr/src/aggregate_function.rs
index fb5a8db550..8f683cabe6 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -49,8 +49,6 @@ pub enum AggregateFunction {
ArrayAgg,
/// N'th value in a group according to some ordering
NthValue,
- /// Variance (Sample)
- Variance,
/// Variance (Population)
VariancePop,
/// Standard Deviation (Sample)
@@ -111,7 +109,6 @@ impl AggregateFunction {
ApproxDistinct => "APPROX_DISTINCT",
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
- Variance => "VAR",
VariancePop => "VAR_POP",
Stddev => "STDDEV",
StddevPop => "STDDEV_POP",
@@ -169,9 +166,7 @@ impl FromStr for AggregateFunction {
"stddev" => AggregateFunction::Stddev,
"stddev_pop" => AggregateFunction::StddevPop,
"stddev_samp" => AggregateFunction::Stddev,
- "var" => AggregateFunction::Variance,
"var_pop" => AggregateFunction::VariancePop,
- "var_samp" => AggregateFunction::Variance,
"regr_slope" => AggregateFunction::RegrSlope,
"regr_intercept" => AggregateFunction::RegrIntercept,
"regr_count" => AggregateFunction::RegrCount,
@@ -235,7 +230,6 @@ impl AggregateFunction {
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
Ok(DataType::Boolean)
}
- AggregateFunction::Variance =>
variance_return_type(&coerced_data_types[0]),
AggregateFunction::VariancePop => {
variance_return_type(&coerced_data_types[0])
}
@@ -315,7 +309,6 @@ impl AggregateFunction {
}
AggregateFunction::Avg
| AggregateFunction::Sum
- | AggregateFunction::Variance
| AggregateFunction::VariancePop
| AggregateFunction::Stddev
| AggregateFunction::StddevPop
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs
b/datafusion/expr/src/type_coercion/aggregates.rs
index 6bd204c53c..b7004e200d 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -173,7 +173,7 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
- AggregateFunction::Variance | AggregateFunction::VariancePop => {
+ AggregateFunction::VariancePop => {
if !is_variance_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
diff --git a/datafusion/functions-aggregate/src/lib.rs
b/datafusion/functions-aggregate/src/lib.rs
index cb8ef65420..ff02d25ad0 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -59,6 +59,7 @@ pub mod covariance;
pub mod first_last;
pub mod median;
pub mod sum;
+pub mod variance;
use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;
@@ -74,6 +75,7 @@ pub mod expr_fn {
pub use super::first_last::last_value;
pub use super::median::median;
pub use super::sum::sum;
+ pub use super::variance::var_sample;
}
/// Returns all default aggregate functions
@@ -85,6 +87,7 @@ pub fn all_default_aggregate_functions() ->
Vec<Arc<AggregateUDF>> {
sum::sum_udaf(),
covariance::covar_pop_udaf(),
median::median_udaf(),
+ variance::var_samp_udaf(),
]
}
diff --git a/datafusion/functions-aggregate/src/variance.rs
b/datafusion/functions-aggregate/src/variance.rs
new file mode 100644
index 0000000000..b5d467d0e7
--- /dev/null
+++ b/datafusion/functions-aggregate/src/variance.rs
@@ -0,0 +1,263 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`VarianceSample`]: variance sample aggregations.
+
+use std::fmt::Debug;
+
+use arrow::{
+ array::{ArrayRef, Float64Array, UInt64Array},
+ compute::kernels::cast,
+ datatypes::{DataType, Field},
+};
+
+use datafusion_common::{
+ downcast_value, not_impl_err, plan_err, DataFusionError, Result,
ScalarValue,
+};
+use datafusion_expr::{
+ function::{AccumulatorArgs, StateFieldsArgs},
+ utils::format_state_name,
+ Accumulator, AggregateUDFImpl, Signature, Volatility,
+};
+use datafusion_physical_expr_common::aggregate::stats::StatsType;
+
+// Generates the `var_sample` expression-builder function (re-exported from
+// `expr_fn`) and the `var_samp_udaf()` accessor used by
+// `all_default_aggregate_functions` to register this UDAF.
+// NOTE(review): the string literal is presumably the doc text of the generated
+// items — confirm against the macro definition.
+make_udaf_expr_and_func!(
+ VarianceSample,
+ var_sample,
+ expression,
+ "Computes the sample variance.",
+ var_samp_udaf
+);
+
+/// Sample variance aggregate function, registered under the name `var`
+/// with the aliases `var_sample` and `var_samp`.
+pub struct VarianceSample {
+ /// Accepts exactly one numeric argument (built in `VarianceSample::new`).
+ signature: Signature,
+ /// Additional names this function can be invoked by.
+ aliases: Vec<String>,
+}
+
+impl Debug for VarianceSample {
+    /// Formats the UDAF by its name and signature (aliases are not shown).
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        let mut builder = f.debug_struct("VarianceSample");
+        builder.field("name", &self.name());
+        builder.field("signature", &self.signature);
+        builder.finish()
+    }
+}
+
+/// The default instance is the same as the one built by [`VarianceSample::new`].
+impl Default for VarianceSample {
+    fn default() -> Self {
+        VarianceSample::new()
+    }
+}
+
+impl VarianceSample {
+    /// Creates the sample-variance UDAF: one numeric argument, immutable
+    /// volatility, and the extra names `var_sample` and `var_samp`.
+    pub fn new() -> Self {
+        let aliases = ["var_sample", "var_samp"]
+            .into_iter()
+            .map(String::from)
+            .collect();
+        let signature = Signature::numeric(1, Volatility::Immutable);
+        Self { signature, aliases }
+    }
+}
+
+impl AggregateUDFImpl for VarianceSample {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "var"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// Sample variance is always reported as `Float64`; a non-numeric first
+    /// argument is rejected with a planning error.
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if arg_types[0].is_numeric() {
+            Ok(DataType::Float64)
+        } else {
+            plan_err!("Variance requires numeric input types")
+        }
+    }
+
+    /// Intermediate state: the running `count`, `mean` and `m2` (sum of
+    /// squared deviations from the mean), in that order.
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        let state_field = |suffix: &str, data_type: DataType| {
+            Field::new(format_state_name(args.name, suffix), data_type, true)
+        };
+        Ok(vec![
+            state_field("count", DataType::UInt64),
+            state_field("mean", DataType::Float64),
+            state_field("m2", DataType::Float64),
+        ])
+    }
+
+    /// Builds a sample-statistics [`VarianceAccumulator`].
+    ///
+    /// `DISTINCT` is not supported and yields a not-implemented error.
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        if acc_args.is_distinct {
+            return not_impl_err!("VAR(DISTINCT) aggregations are not available");
+        }
+        let acc = VarianceAccumulator::try_new(StatsType::Sample)?;
+        Ok(Box::new(acc))
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+/// An accumulator that computes variance; whether the sample or population
+/// statistic is produced is selected by its [`StatsType`].
+///
+/// The algorithm is an online, numerically stable implementation based on:
+/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of
+/// squares and products". Technometrics. 4 (3): 419–420. doi:10.2307/1266577.
+/// JSTOR 1266577.
+///
+/// The algorithm has been analyzed in:
+/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing
+/// Sample Means and Variances". Journal of the American Statistical
+/// Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
+#[derive(Debug)]
+pub struct VarianceAccumulator {
+ /// Running sum of squared deviations from the mean.
+ m2: f64,
+ /// Running mean of the accumulated values.
+ mean: f64,
+ /// Number of non-null values currently accumulated.
+ count: u64,
+ /// Whether `evaluate` divides by `count` (population) or `count - 1` (sample).
+ stats_type: StatsType,
+}
+
+impl VarianceAccumulator {
+    /// Creates an empty `VarianceAccumulator` for the given statistic type.
+    ///
+    /// This constructor always returns `Ok`.
+    pub fn try_new(s_type: StatsType) -> Result<Self> {
+        let empty = Self {
+            count: 0,
+            mean: 0.0,
+            m2: 0.0,
+            stats_type: s_type,
+        };
+        Ok(empty)
+    }
+
+    /// Number of non-null values currently accumulated.
+    pub fn get_count(&self) -> u64 {
+        self.count
+    }
+
+    /// Current running mean.
+    pub fn get_mean(&self) -> f64 {
+        self.mean
+    }
+
+    /// Current sum of squared deviations from the mean.
+    pub fn get_m2(&self) -> f64 {
+        self.m2
+    }
+}
+
+impl Accumulator for VarianceAccumulator {
+    /// Returns the intermediate state as `[count, mean, m2]`, the same order
+    /// in which `merge_batch` reads partial states back.
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.count),
+            ScalarValue::from(self.mean),
+            ScalarValue::from(self.m2),
+        ])
+    }
+
+    /// Folds every non-null value of the first input column (cast to
+    /// `Float64`) into the running statistics using Welford's update.
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values = &cast(&values[0], &DataType::Float64)?;
+        let arr = downcast_value!(values, Float64Array).iter().flatten();
+
+        for value in arr {
+            let new_count = self.count + 1;
+            let delta1 = value - self.mean;
+            let new_mean = delta1 / new_count as f64 + self.mean;
+            let delta2 = value - new_mean;
+            let new_m2 = self.m2 + delta1 * delta2;
+
+            self.count = new_count;
+            self.mean = new_mean;
+            self.m2 = new_m2;
+        }
+
+        Ok(())
+    }
+
+    /// Removes every non-null value of the first input column from the
+    /// running statistics; the inverse of `update_batch`. Presumably invoked
+    /// when values leave a sliding aggregation window (this accumulator
+    /// reports `supports_retract_batch`).
+    ///
+    /// When the last value is removed, the state is reset to empty. Applying
+    /// the inverse Welford step there would divide by a zero count and
+    /// poison `mean`/`m2` with NaN/infinity for every later update.
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values = &cast(&values[0], &DataType::Float64)?;
+        let arr = downcast_value!(values, Float64Array).iter().flatten();
+
+        for value in arr {
+            let new_count = self.count - 1;
+            if new_count == 0 {
+                // Nothing left: return to the freshly-constructed state.
+                self.count = 0;
+                self.mean = 0_f64;
+                self.m2 = 0_f64;
+                continue;
+            }
+            let delta1 = self.mean - value;
+            let new_mean = delta1 / new_count as f64 + self.mean;
+            let delta2 = new_mean - value;
+            let new_m2 = self.m2 - delta1 * delta2;
+
+            self.count = new_count;
+            self.mean = new_mean;
+            self.m2 = new_m2;
+        }
+
+        Ok(())
+    }
+
+    /// Combines partial `[count, mean, m2]` states (as produced by `state`)
+    /// into this accumulator.
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        let counts = downcast_value!(states[0], UInt64Array);
+        let means = downcast_value!(states[1], Float64Array);
+        let m2s = downcast_value!(states[2], Float64Array);
+
+        for i in 0..counts.len() {
+            let c = counts.value(i);
+            // An empty partial state contributes nothing; skipping it also
+            // avoids a 0/0 division when this accumulator is empty as well.
+            if c == 0_u64 {
+                continue;
+            }
+            let new_count = self.count + c;
+            let new_mean = self.mean * self.count as f64 / new_count as f64
+                + means.value(i) * c as f64 / new_count as f64;
+            let delta = self.mean - means.value(i);
+            let new_m2 = self.m2
+                + m2s.value(i)
+                + delta * delta * self.count as f64 * c as f64 / new_count as f64;
+
+            self.count = new_count;
+            self.mean = new_mean;
+            self.m2 = new_m2;
+        }
+        Ok(())
+    }
+
+    /// Returns the variance as `Float64`, or NULL when it is undefined: no
+    /// values at all, or a single value for the sample statistic. A single
+    /// value has a population variance of `0.0`.
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        // Divisor: `n` for population variance, `n - 1` for sample variance.
+        let count = match self.stats_type {
+            StatsType::Population => self.count,
+            StatsType::Sample => {
+                if self.count > 0 {
+                    self.count - 1
+                } else {
+                    self.count
+                }
+            }
+        };
+
+        Ok(ScalarValue::Float64(match self.count {
+            0 => None,
+            1 => {
+                if let StatsType::Population = self.stats_type {
+                    Some(0.0)
+                } else {
+                    None
+                }
+            }
+            _ => Some(self.m2 / count as f64),
+        }))
+    }
+
+    /// Size in bytes of the accumulator value itself (`size_of_val(self)`).
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+
+    fn supports_retract_batch(&self) -> bool {
+        true
+    }
+}
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs
b/datafusion/physical-expr/src/aggregate/build_in.rs
index 813a394d69..07409dd1f4 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -160,14 +160,6 @@ pub fn create_aggregate_expr(
(AggregateFunction::Avg, true) => {
return not_impl_err!("AVG(DISTINCT) aggregations are not
available");
}
- (AggregateFunction::Variance, false) =>
Arc::new(expressions::Variance::new(
- input_phy_exprs[0].clone(),
- name,
- data_type,
- )),
- (AggregateFunction::Variance, true) => {
- return not_impl_err!("VAR(DISTINCT) aggregations are not
available");
- }
(AggregateFunction::VariancePop, false) => Arc::new(
expressions::VariancePop::new(input_phy_exprs[0].clone(), name,
data_type),
),
@@ -367,12 +359,13 @@ pub fn create_aggregate_expr(
#[cfg(test)]
mod tests {
use arrow::datatypes::{DataType, Field};
+ use expressions::{StddevPop, VariancePop};
use super::*;
use crate::expressions::{
try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont,
ArrayAgg, Avg,
BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg,
DistinctCount,
- Max, Min, Stddev, Variance,
+ Max, Min, Stddev,
};
use datafusion_common::{plan_err, DataFusionError, ScalarValue};
@@ -719,44 +712,6 @@ mod tests {
Ok(())
}
- #[test]
- fn test_variance_expr() -> Result<()> {
- let funcs = vec![AggregateFunction::Variance];
- let data_types = vec![
- DataType::UInt32,
- DataType::UInt64,
- DataType::Int32,
- DataType::Int64,
- DataType::Float32,
- DataType::Float64,
- ];
- for fun in funcs {
- for data_type in &data_types {
- let input_schema =
- Schema::new(vec![Field::new("c1", data_type.clone(),
true)]);
- let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
- expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
- )];
- let result_agg_phy_exprs = create_physical_agg_expr_for_test(
- &fun,
- false,
- &input_phy_exprs[0..1],
- &input_schema,
- "c1",
- )?;
- if fun == AggregateFunction::Variance {
- assert!(result_agg_phy_exprs.as_any().is::<Variance>());
- assert_eq!("c1", result_agg_phy_exprs.name());
- assert_eq!(
- Field::new("c1", DataType::Float64, true),
- result_agg_phy_exprs.field().unwrap()
- )
- }
- }
- }
- Ok(())
- }
-
#[test]
fn test_var_pop_expr() -> Result<()> {
let funcs = vec![AggregateFunction::VariancePop];
@@ -782,8 +737,8 @@ mod tests {
&input_schema,
"c1",
)?;
- if fun == AggregateFunction::Variance {
- assert!(result_agg_phy_exprs.as_any().is::<Variance>());
+ if fun == AggregateFunction::VariancePop {
+ assert!(result_agg_phy_exprs.as_any().is::<VariancePop>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new("c1", DataType::Float64, true),
@@ -820,7 +775,7 @@ mod tests {
&input_schema,
"c1",
)?;
- if fun == AggregateFunction::Variance {
+ if fun == AggregateFunction::Stddev {
assert!(result_agg_phy_exprs.as_any().is::<Stddev>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
@@ -858,8 +813,8 @@ mod tests {
&input_schema,
"c1",
)?;
- if fun == AggregateFunction::Variance {
- assert!(result_agg_phy_exprs.as_any().is::<Stddev>());
+ if fun == AggregateFunction::StddevPop {
+ assert!(result_agg_phy_exprs.as_any().is::<StddevPop>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new("c1", DataType::Float64, true),
@@ -987,32 +942,6 @@ mod tests {
assert!(observed.is_err());
}
- #[test]
- fn test_variance_return_type() -> Result<()> {
- let observed =
AggregateFunction::Variance.return_type(&[DataType::Float32])?;
- assert_eq!(DataType::Float64, observed);
-
- let observed =
AggregateFunction::Variance.return_type(&[DataType::Float64])?;
- assert_eq!(DataType::Float64, observed);
-
- let observed =
AggregateFunction::Variance.return_type(&[DataType::Int32])?;
- assert_eq!(DataType::Float64, observed);
-
- let observed =
AggregateFunction::Variance.return_type(&[DataType::UInt32])?;
- assert_eq!(DataType::Float64, observed);
-
- let observed =
AggregateFunction::Variance.return_type(&[DataType::Int64])?;
- assert_eq!(DataType::Float64, observed);
-
- Ok(())
- }
-
- #[test]
- fn test_variance_no_utf8() {
- let observed =
AggregateFunction::Variance.return_type(&[DataType::Utf8]);
- assert!(observed.is_err());
- }
-
#[test]
fn test_stddev_return_type() -> Result<()> {
let observed =
AggregateFunction::Stddev.return_type(&[DataType::Float32])?;
diff --git a/datafusion/physical-expr/src/aggregate/variance.rs
b/datafusion/physical-expr/src/aggregate/variance.rs
index 7ae917409a..3db3c0e3ae 100644
--- a/datafusion/physical-expr/src/aggregate/variance.rs
+++ b/datafusion/physical-expr/src/aggregate/variance.rs
@@ -35,13 +35,6 @@ use datafusion_common::downcast_value;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::Accumulator;
-/// VAR and VAR_SAMP aggregate expression
-#[derive(Debug)]
-pub struct Variance {
- name: String,
- expr: Arc<dyn PhysicalExpr>,
-}
-
/// VAR_POP aggregate expression
#[derive(Debug)]
pub struct VariancePop {
@@ -49,74 +42,6 @@ pub struct VariancePop {
expr: Arc<dyn PhysicalExpr>,
}
-impl Variance {
- /// Create a new VARIANCE aggregate function
- pub fn new(
- expr: Arc<dyn PhysicalExpr>,
- name: impl Into<String>,
- data_type: DataType,
- ) -> Self {
- // the result of variance just support FLOAT64 data type.
- assert!(matches!(data_type, DataType::Float64));
- Self {
- name: name.into(),
- expr,
- }
- }
-}
-
-impl AggregateExpr for Variance {
- /// Return a reference to Any that can be used for downcasting
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn field(&self) -> Result<Field> {
- Ok(Field::new(&self.name, DataType::Float64, true))
- }
-
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
- }
-
- fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
- }
-
- fn state_fields(&self) -> Result<Vec<Field>> {
- Ok(vec![
- Field::new(
- format_state_name(&self.name, "count"),
- DataType::UInt64,
- true,
- ),
- Field::new(
- format_state_name(&self.name, "mean"),
- DataType::Float64,
- true,
- ),
- Field::new(format_state_name(&self.name, "m2"), DataType::Float64,
true),
- ])
- }
-
- fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- vec![self.expr.clone()]
- }
-
- fn name(&self) -> &str {
- &self.name
- }
-}
-
-impl PartialEq<dyn Any> for Variance {
- fn eq(&self, other: &dyn Any) -> bool {
- down_cast_any_ref(other)
- .downcast_ref::<Self>()
- .map(|x| self.name == x.name && self.expr.eq(&x.expr))
- .unwrap_or(false)
- }
-}
-
impl VariancePop {
/// Create a new VAR_POP aggregate function
pub fn new(
diff --git a/datafusion/physical-expr/src/expressions/mod.rs
b/datafusion/physical-expr/src/expressions/mod.rs
index 1e9644f75a..324699af5b 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -60,7 +60,7 @@ pub use crate::aggregate::stddev::{Stddev, StddevPop};
pub use crate::aggregate::string_agg::StringAgg;
pub use crate::aggregate::sum::Sum;
pub use crate::aggregate::sum_distinct::DistinctSum;
-pub use crate::aggregate::variance::{Variance, VariancePop};
+pub use crate::aggregate::variance::VariancePop;
pub use crate::window::cume_dist::{cume_dist, CumeDist};
pub use crate::window::lead_lag::{lag, lead, WindowShift};
pub use crate::window::nth_value::NthValue;
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index fa95194696..f8d229f48d 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -479,7 +479,7 @@ enum AggregateFunction {
COUNT = 4;
APPROX_DISTINCT = 5;
ARRAY_AGG = 6;
- VARIANCE = 7;
+ // VARIANCE = 7;
VARIANCE_POP = 8;
// COVARIANCE = 9;
// COVARIANCE_POP = 10;
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index b0e77eb69e..6de030679c 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -539,7 +539,6 @@ impl serde::Serialize for AggregateFunction {
Self::Count => "COUNT",
Self::ApproxDistinct => "APPROX_DISTINCT",
Self::ArrayAgg => "ARRAY_AGG",
- Self::Variance => "VARIANCE",
Self::VariancePop => "VARIANCE_POP",
Self::Stddev => "STDDEV",
Self::StddevPop => "STDDEV_POP",
@@ -582,7 +581,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"COUNT",
"APPROX_DISTINCT",
"ARRAY_AGG",
- "VARIANCE",
"VARIANCE_POP",
"STDDEV",
"STDDEV_POP",
@@ -654,7 +652,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"COUNT" => Ok(AggregateFunction::Count),
"APPROX_DISTINCT" => Ok(AggregateFunction::ApproxDistinct),
"ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg),
- "VARIANCE" => Ok(AggregateFunction::Variance),
"VARIANCE_POP" => Ok(AggregateFunction::VariancePop),
"STDDEV" => Ok(AggregateFunction::Stddev),
"STDDEV_POP" => Ok(AggregateFunction::StddevPop),
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 6d8a0c3057..e397f35459 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1923,7 +1923,7 @@ pub enum AggregateFunction {
Count = 4,
ApproxDistinct = 5,
ArrayAgg = 6,
- Variance = 7,
+ /// VARIANCE = 7;
VariancePop = 8,
/// COVARIANCE = 9;
/// COVARIANCE_POP = 10;
@@ -1966,7 +1966,6 @@ impl AggregateFunction {
AggregateFunction::Count => "COUNT",
AggregateFunction::ApproxDistinct => "APPROX_DISTINCT",
AggregateFunction::ArrayAgg => "ARRAY_AGG",
- AggregateFunction::Variance => "VARIANCE",
AggregateFunction::VariancePop => "VARIANCE_POP",
AggregateFunction::Stddev => "STDDEV",
AggregateFunction::StddevPop => "STDDEV_POP",
@@ -2005,7 +2004,6 @@ impl AggregateFunction {
"COUNT" => Some(Self::Count),
"APPROX_DISTINCT" => Some(Self::ApproxDistinct),
"ARRAY_AGG" => Some(Self::ArrayAgg),
- "VARIANCE" => Some(Self::Variance),
"VARIANCE_POP" => Some(Self::VariancePop),
"STDDEV" => Some(Self::Stddev),
"STDDEV_POP" => Some(Self::StddevPop),
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index e2a2f875ea..f8a78bdbdc 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -149,7 +149,6 @@ impl From<protobuf::AggregateFunction> for
AggregateFunction {
protobuf::AggregateFunction::Count => Self::Count,
protobuf::AggregateFunction::ApproxDistinct =>
Self::ApproxDistinct,
protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg,
- protobuf::AggregateFunction::Variance => Self::Variance,
protobuf::AggregateFunction::VariancePop => Self::VariancePop,
protobuf::AggregateFunction::Stddev => Self::Stddev,
protobuf::AggregateFunction::StddevPop => Self::StddevPop,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index d2783305f6..15d0d6dd49 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -120,7 +120,6 @@ impl From<&AggregateFunction> for
protobuf::AggregateFunction {
AggregateFunction::Count => Self::Count,
AggregateFunction::ApproxDistinct => Self::ApproxDistinct,
AggregateFunction::ArrayAgg => Self::ArrayAgg,
- AggregateFunction::Variance => Self::Variance,
AggregateFunction::VariancePop => Self::VariancePop,
AggregateFunction::Stddev => Self::Stddev,
AggregateFunction::StddevPop => Self::StddevPop,
@@ -418,7 +417,6 @@ pub fn serialize_expr(
AggregateFunction::BoolOr =>
protobuf::AggregateFunction::BoolOr,
AggregateFunction::Avg => protobuf::AggregateFunction::Avg,
AggregateFunction::Count =>
protobuf::AggregateFunction::Count,
- AggregateFunction::Variance =>
protobuf::AggregateFunction::Variance,
AggregateFunction::VariancePop => {
protobuf::AggregateFunction::VariancePop
}
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs
b/datafusion/proto/src/physical_plan/to_proto.rs
index 0714636141..834f59abb1 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -29,7 +29,7 @@ use datafusion::physical_plan::expressions::{
DistinctCount, DistinctSum, Grouping, InListExpr, IsNotNullExpr,
IsNullExpr, Literal,
Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile,
OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, Stddev,
StddevPop,
- StringAgg, Sum, TryCastExpr, Variance, VariancePop, WindowShift,
+ StringAgg, Sum, TryCastExpr, VariancePop, WindowShift,
};
use datafusion::physical_plan::udaf::AggregateFunctionExpr;
use datafusion::physical_plan::windows::{BuiltInWindowExpr,
PlainAggregateWindowExpr};
@@ -281,8 +281,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) ->
Result<AggrFn> {
protobuf::AggregateFunction::Max
} else if aggr_expr.downcast_ref::<Avg>().is_some() {
protobuf::AggregateFunction::Avg
- } else if aggr_expr.downcast_ref::<Variance>().is_some() {
- protobuf::AggregateFunction::Variance
} else if aggr_expr.downcast_ref::<VariancePop>().is_some() {
protobuf::AggregateFunction::VariancePop
} else if aggr_expr.downcast_ref::<Stddev>().is_some() {
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 14d7227480..deae97fecc 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -31,8 +31,9 @@ use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::execution::FunctionRegistry;
-use datafusion::functions_aggregate::expr_fn::{covar_pop, covar_samp,
first_value};
-use datafusion::functions_aggregate::median::median;
+use datafusion::functions_aggregate::expr_fn::{
+ covar_pop, covar_samp, first_value, median, var_sample,
+};
use datafusion::prelude::*;
use datafusion::test_util::{TestTableFactory, TestTableProvider};
use datafusion_common::config::{FormatOptions, TableOptions};
@@ -651,6 +652,7 @@ async fn roundtrip_expr_api() -> Result<()> {
covar_pop(lit(1.5), lit(2.2)),
sum(lit(1)),
median(lit(2)),
+ var_sample(lit(2.2)),
];
// ensure expressions created with the expr api can be round tripped
diff --git a/datafusion/sqllogictest/Cargo.toml
b/datafusion/sqllogictest/Cargo.toml
index c652c8041f..3b1f0dfd6d 100644
--- a/datafusion/sqllogictest/Cargo.toml
+++ b/datafusion/sqllogictest/Cargo.toml
@@ -40,7 +40,7 @@ bigdecimal = { workspace = true }
bytes = { workspace = true, optional = true }
chrono = { workspace = true, optional = true }
clap = { version = "4.4.8", features = ["derive", "env"] }
-datafusion = { workspace = true, default-features = true }
+datafusion = { workspace = true, default-features = true, features = ["avro"] }
datafusion-common = { workspace = true, default-features = true }
datafusion-common-runtime = { workspace = true, default-features = true }
futures = { workspace = true }
@@ -60,7 +60,13 @@ tokio-postgres = { version = "0.7.7", optional = true }
[features]
avro = ["datafusion/avro"]
-postgres = ["bytes", "chrono", "tokio-postgres", "postgres-types",
"postgres-protocol"]
+postgres = [
+ "bytes",
+ "chrono",
+ "tokio-postgres",
+ "postgres-types",
+ "postgres-protocol",
+]
[dev-dependencies]
env_logger = { workspace = true }
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index 98e64b025b..56ec034257 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -2338,6 +2338,18 @@ select covar_pop(c1, c2), arrow_typeof(covar_pop(c1,
c2)) from t;
statement ok
drop table t;
+# variance_f64_1
+statement ok
+create table t (c double) as values (1), (2), (3), (4), (5);
+
+query RT
+select var(c), arrow_typeof(var(c)) from t;
+----
+2.5 Float64
+
+statement ok
+drop table t;
+
# aggregate stddev f64_1
statement ok
create table t (c1 double) as values (1), (2);
@@ -2494,6 +2506,18 @@ select var(c1), arrow_typeof(var(c1)) from t;
statement ok
drop table t;
+# variance_f64_2
+statement ok
+create table t (c double) as values (1.1), (2), (3);
+
+query RT
+select var(c), arrow_typeof(var(c)) from t;
+----
+0.903333333333 Float64
+
+statement ok
+drop table t;
+
# aggregate variance f64_4
statement ok
create table t (c1 double) as values (1.1), (2), (3);
@@ -2506,6 +2530,30 @@ select var(c1), arrow_typeof(var(c1)) from t;
statement ok
drop table t;
+# variance_1_input
+statement ok
+create table t (a double not null) as values (1);
+
+query RT
+select var(a), arrow_typeof(var(a)) from t;
+----
+NULL Float64
+
+statement ok
+drop table t;
+
+# variance_i32_all_nulls
+statement ok
+create table t (a int) as values (null), (null);
+
+query RT
+select var(a), arrow_typeof(var(a)) from t;
+----
+NULL Float64
+
+statement ok
+drop table t;
+
# aggregate variance i32
statement ok
create table t (c1 int) as values (1), (2), (3), (4), (5);
diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt
b/datafusion/sqllogictest/test_files/sort_merge_join.slt
index ce738c7a6f..1fd8b0a346 100644
--- a/datafusion/sqllogictest/test_files/sort_merge_join.slt
+++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt
@@ -518,4 +518,3 @@ set datafusion.optimizer.prefer_hash_join = true;
statement ok
set datafusion.execution.batch_size = 8192;
-
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]