This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a9561a0f06 Add regr_slope() aggregate function (#7135)
a9561a0f06 is described below
commit a9561a0f06c25f370dc39df08d057db85c4e0c7a
Author: Yongting You <[email protected]>
AuthorDate: Tue Aug 1 13:33:23 2023 -0700
Add regr_slope() aggregate function (#7135)
---
.../tests/sqllogictests/test_files/aggregate.slt | 158 ++++++++++
datafusion/expr/src/aggregate_function.rs | 13 +-
datafusion/expr/src/type_coercion/aggregates.rs | 35 +--
datafusion/physical-expr/src/aggregate/build_in.rs | 11 +
datafusion/physical-expr/src/aggregate/mod.rs | 1 +
.../physical-expr/src/aggregate/regr_slope.rs | 331 +++++++++++++++++++++
datafusion/physical-expr/src/expressions/mod.rs | 1 +
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/generated/pbjson.rs | 3 +
datafusion/proto/src/generated/prost.rs | 3 +
datafusion/proto/src/logical_plan/from_proto.rs | 1 +
datafusion/proto/src/logical_plan/to_proto.rs | 4 +
docs/source/user-guide/sql/aggregate_functions.md | 17 ++
13 files changed, 550 insertions(+), 29 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index 2f6f44c56b..0e3f337071 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -2290,3 +2290,161 @@ true
false
true
NULL
+
+
+
+#
+# regr_slope() tests
+#
+
+# invalid input
+statement error
+select regr_slope();
+
+statement error
+select regr_slope(*);
+
+statement error
+select regr_slope(*) from aggregate_test_100;
+
+statement error
+select regr_slope(1);
+
+statement error
+select regr_slope(1,2,3);
+
+statement error
+select regr_slope(1, 'foo');
+
+statement error
+select regr_slope('foo', 1);
+
+statement error
+select regr_slope('foo', 'bar');
+
+
+
+# regr_slope() NULL result
+query R
+select regr_slope(1,1);
+----
+NULL
+
+query R
+select regr_slope(1, NULL);
+----
+NULL
+
+query R
+select regr_slope(NULL, 1);
+----
+NULL
+
+query R
+select regr_slope(NULL, NULL);
+----
+NULL
+
+query R
+select regr_slope(column2, column1) from (values (1,2), (1,4), (1,6));
+----
+NULL
+
+
+
+# regr_slope() basic tests
+query R
+select regr_slope(column2, column1) from (values (1,2), (2,4), (3,6));
+----
+2
+
+query R
+select regr_slope(c12, c11) from aggregate_test_100;
+----
+0.051534002628
+
+
+
+# regr_slope() ignore NULLs
+query R
+select regr_slope(column2, column1) from (values (1,NULL), (2,4), (3,6));
+----
+2
+
+query R
+select regr_slope(column2, column1) from (values (1,NULL), (NULL,4), (3,6));
+----
+NULL
+
+query R
+select regr_slope(column2, column1) from (values (1,NULL), (NULL,4),
(NULL,NULL));
+----
+NULL
+
+query TR rowsort
+select column3, regr_slope(column2, column1)
+from (values (1,2,'a'), (2,4,'a'), (1,3,'b'), (3,9,'b'), (1,10,'c'),
(NULL,100,'c'))
+group by column3;
+----
+a 2
+b 3
+c NULL
+
+
+
+# regr_slope() testing merge_batch() from RegrSlopeAccumulator's internal
implementation
+statement ok
+set datafusion.execution.batch_size = 1;
+
+query R
+select regr_slope(c12, c11) from aggregate_test_100;
+----
+0.051534002628
+
+statement ok
+set datafusion.execution.batch_size = 2;
+
+query R
+select regr_slope(c12, c11) from aggregate_test_100;
+----
+0.051534002628
+
+statement ok
+set datafusion.execution.batch_size = 3;
+
+query R
+select regr_slope(c12, c11) from aggregate_test_100;
+----
+0.051534002628
+
+statement ok
+set datafusion.execution.batch_size = 8192;
+
+
+
+# regr_slope testing retract_batch() from RegrSlopeAccumulator's internal
implementation
+query R
+select regr_slope(column2, column1)
+over (order by column1 rows between 2 preceding and current row)
+from (values (1,2), (2,4), (3,6), (4,12), (5,15), (6, 18));
+----
+NULL
+2
+2
+4
+4.5
+3
+
+query R
+select regr_slope(column2, column1)
+over (order by column1 rows between 2 preceding and current row)
+from (values (1,2), (2,4), (3,6), (3, NULL), (4, NULL), (5,15), (6,18), (7,
21));
+----
+NULL
+2
+2
+2
+NULL
+NULL
+3
+3
diff --git a/datafusion/expr/src/aggregate_function.rs
b/datafusion/expr/src/aggregate_function.rs
index add1262237..ac0ac3079e 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -61,6 +61,8 @@ pub enum AggregateFunction {
CovariancePop,
/// Correlation
Correlation,
+ /// Slope from linear regression
+ RegrSlope,
/// Approximate continuous percentile function
ApproxPercentileCont,
/// Approximate continuous percentile function with weight
@@ -102,6 +104,7 @@ impl AggregateFunction {
Covariance => "COVARIANCE",
CovariancePop => "COVARIANCE_POP",
Correlation => "CORRELATION",
+ RegrSlope => "REGR_SLOPE",
ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
ApproxPercentileContWithWeight =>
"APPROX_PERCENTILE_CONT_WITH_WEIGHT",
ApproxMedian => "APPROX_MEDIAN",
@@ -152,6 +155,7 @@ impl FromStr for AggregateFunction {
"var" => AggregateFunction::Variance,
"var_pop" => AggregateFunction::VariancePop,
"var_samp" => AggregateFunction::Variance,
+ "regr_slope" => AggregateFunction::RegrSlope,
// approximate
"approx_distinct" => AggregateFunction::ApproxDistinct,
"approx_median" => AggregateFunction::ApproxMedian,
@@ -228,6 +232,7 @@ impl AggregateFunction {
}
AggregateFunction::Stddev =>
stddev_return_type(&coerced_data_types[0]),
AggregateFunction::StddevPop =>
stddev_return_type(&coerced_data_types[0]),
+ AggregateFunction::RegrSlope => Ok(DataType::Float64),
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
AggregateFunction::ArrayAgg =>
Ok(DataType::List(Arc::new(Field::new(
"item",
@@ -311,10 +316,10 @@ impl AggregateFunction {
| AggregateFunction::LastValue => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
- AggregateFunction::Covariance | AggregateFunction::CovariancePop
=> {
- Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
- }
- AggregateFunction::Correlation => {
+ AggregateFunction::Covariance
+ | AggregateFunction::CovariancePop
+ | AggregateFunction::Correlation
+ | AggregateFunction::RegrSlope => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ApproxPercentileCont => {
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs
b/datafusion/expr/src/type_coercion/aggregates.rs
index dec2eb7f12..95ca6ab718 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -148,7 +148,7 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
- AggregateFunction::Variance => {
+ AggregateFunction::Variance | AggregateFunction::VariancePop => {
if !is_variance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
@@ -157,16 +157,7 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
- AggregateFunction::VariancePop => {
- if !is_variance_support_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::Covariance => {
+ AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
if !is_covariance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
@@ -175,16 +166,7 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
- AggregateFunction::CovariancePop => {
- if !is_covariance_support_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::Stddev => {
+ AggregateFunction::Stddev | AggregateFunction::StddevPop => {
if !is_stddev_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
@@ -193,8 +175,8 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
- AggregateFunction::StddevPop => {
- if !is_stddev_support_arg_type(&input_types[0]) {
+ AggregateFunction::Correlation => {
+ if !is_correlation_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
@@ -202,8 +184,11 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
- AggregateFunction::Correlation => {
- if !is_correlation_support_arg_type(&input_types[0]) {
+ AggregateFunction::RegrSlope => {
+ let valid_types = [NUMERICS.to_vec(),
vec![DataType::Null]].concat();
+ let input_types_valid = // number of input already checked before
+ valid_types.contains(&input_types[0]) &&
valid_types.contains(&input_types[1]);
+ if !input_types_valid {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs
b/datafusion/physical-expr/src/aggregate/build_in.rs
index 4dc7c824ed..45c98b0187 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -248,6 +248,17 @@ pub fn create_aggregate_expr(
"CORR(DISTINCT) aggregations are not available".to_string(),
));
}
+ (AggregateFunction::RegrSlope, false) =>
Arc::new(expressions::RegrSlope::new(
+ input_phy_exprs[0].clone(),
+ input_phy_exprs[1].clone(),
+ name,
+ rt_type,
+ )),
+ (AggregateFunction::RegrSlope, true) => {
+ return Err(DataFusionError::NotImplemented(
+ "REGR_SLOPE(DISTINCT) aggregations are not
available".to_string(),
+ ));
+ }
(AggregateFunction::ApproxPercentileCont, false) => {
if input_phy_exprs.len() == 2 {
Arc::new(expressions::ApproxPercentileCont::new(
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs
b/datafusion/physical-expr/src/aggregate/mod.rs
index 5490b87576..0d0abca062 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -49,6 +49,7 @@ pub mod build_in;
pub(crate) mod groups_accumulator;
mod hyperloglog;
pub mod moving_min_max;
+pub(crate) mod regr_slope;
pub(crate) mod stats;
pub(crate) mod stddev;
pub(crate) mod sum;
diff --git a/datafusion/physical-expr/src/aggregate/regr_slope.rs
b/datafusion/physical-expr/src/aggregate/regr_slope.rs
new file mode 100644
index 0000000000..fce9627b04
--- /dev/null
+++ b/datafusion/physical-expr/src/aggregate/regr_slope.rs
@@ -0,0 +1,331 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can be evaluated at runtime during query
execution
+
+use std::any::Any;
+use std::sync::Arc;
+
+use crate::{AggregateExpr, PhysicalExpr};
+use arrow::array::Float64Array;
+use arrow::{
+ array::{ArrayRef, UInt64Array},
+ compute::cast,
+ datatypes::DataType,
+ datatypes::Field,
+};
+use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue};
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::Accumulator;
+
+use crate::aggregate::utils::down_cast_any_ref;
+use crate::expressions::format_state_name;
+
+/// regr_slope aggregate expression
+/// Returns the slope of the linear regression line for non-null pairs in
aggregate columns
+/// Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y =
k*X + b) using minimal
+/// RSS fitting.
+#[derive(Debug)]
+pub struct RegrSlope {
+ name: String,
+ expr_y: Arc<dyn PhysicalExpr>,
+ expr_x: Arc<dyn PhysicalExpr>,
+}
+
+impl RegrSlope {
+ pub fn new(
+ expr_y: Arc<dyn PhysicalExpr>,
+ expr_x: Arc<dyn PhysicalExpr>,
+ name: impl Into<String>,
+ return_type: DataType,
+ ) -> Self {
+ // the result of regr_slope only support FLOAT64 data type.
+ assert!(matches!(return_type, DataType::Float64));
+ Self {
+ name: name.into(),
+ expr_y,
+ expr_x,
+ }
+ }
+}
+
+impl AggregateExpr for RegrSlope {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn field(&self) -> Result<Field> {
+ Ok(Field::new(&self.name, DataType::Float64, true))
+ }
+
+ fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(RegrSlopeAccumulator::try_new()?))
+ }
+
+ fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(RegrSlopeAccumulator::try_new()?))
+ }
+
+ fn state_fields(&self) -> Result<Vec<Field>> {
+ Ok(vec![
+ Field::new(
+ format_state_name(&self.name, "count"),
+ DataType::UInt64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "mean_x"),
+ DataType::Float64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "mean_y"),
+ DataType::Float64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "m2_x"),
+ DataType::Float64,
+ true,
+ ),
+ Field::new(
+ format_state_name(&self.name, "algo_const"),
+ DataType::Float64,
+ true,
+ ),
+ ])
+ }
+
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![self.expr_y.clone(), self.expr_x.clone()]
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+}
+
+impl PartialEq<dyn Any> for RegrSlope {
+ fn eq(&self, other: &dyn Any) -> bool {
+ down_cast_any_ref(other)
+ .downcast_ref::<Self>()
+ .map(|x| {
+ self.name == x.name
+ && self.expr_y.eq(&x.expr_y)
+ && self.expr_x.eq(&x.expr_x)
+ })
+ .unwrap_or(false)
+ }
+}
+
+// regr_slope(y, x) is calculated using cov_pop(x, y)/var_pop(x)
+// Reference of online algorithms for calculating variance:
+//
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+#[derive(Debug)]
+pub struct RegrSlopeAccumulator {
+ count: u64,
+ mean_x: f64,
+ mean_y: f64,
+ m2_x: f64,
+ algo_const: f64,
+}
+
+impl RegrSlopeAccumulator {
+ /// Creates a new `RegrSlopeAccumulator`
+ pub fn try_new() -> Result<Self> {
+ Ok(Self {
+ count: 0_u64,
+ mean_x: 0_f64,
+ mean_y: 0_f64,
+ m2_x: 0_f64,
+ algo_const: 0_f64,
+ })
+ }
+}
+
+impl Accumulator for RegrSlopeAccumulator {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![
+ ScalarValue::from(self.count),
+ ScalarValue::from(self.mean_x),
+ ScalarValue::from(self.mean_y),
+ ScalarValue::from(self.m2_x),
+ ScalarValue::from(self.algo_const),
+ ])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ // regr_slope(Y, X) calculates k in y = k*x + b
+ let values_y = &cast(&values[0], &DataType::Float64)?;
+ let values_x = &cast(&values[1], &DataType::Float64)?;
+
+ let mut arr_y = downcast_value!(values_y,
Float64Array).iter().flatten();
+ let mut arr_x = downcast_value!(values_x,
Float64Array).iter().flatten();
+
+ for i in 0..values_y.len() {
+            // skip when either x or y is NULL
+ let value_y = if values_y.is_valid(i) {
+ arr_y.next()
+ } else {
+ None
+ };
+ let value_x = if values_x.is_valid(i) {
+ arr_x.next()
+ } else {
+ None
+ };
+ if value_y.is_none() || value_x.is_none() {
+ continue;
+ }
+
+ // Update states for regr_slope(y,x) [using
cov_pop(x,y)/var_pop(x)]
+ let value_y = unwrap_or_internal_err!(value_y);
+ let value_x = unwrap_or_internal_err!(value_x);
+
+ self.count += 1;
+ let delta_x = value_x - self.mean_x;
+ let delta_y = value_y - self.mean_y;
+ self.mean_x += delta_x / self.count as f64;
+ let delta_x_2 = value_x - self.mean_x;
+ self.m2_x += delta_x * delta_x_2;
+ self.mean_y += delta_y / self.count as f64;
+ self.algo_const += delta_x * (value_y - self.mean_y);
+ }
+
+ Ok(())
+ }
+
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let values_y = &cast(&values[0], &DataType::Float64)?;
+ let values_x = &cast(&values[1], &DataType::Float64)?;
+
+ let mut arr_y = downcast_value!(values_y,
Float64Array).iter().flatten();
+ let mut arr_x = downcast_value!(values_x,
Float64Array).iter().flatten();
+
+ for i in 0..values_y.len() {
+            // skip when either x or y is NULL
+ let value_y = if values_y.is_valid(i) {
+ arr_y.next()
+ } else {
+ None
+ };
+ let value_x = if values_x.is_valid(i) {
+ arr_x.next()
+ } else {
+ None
+ };
+ if value_y.is_none() || value_x.is_none() {
+ continue;
+ }
+
+ // Update states for regr_slope(y,x) [using
cov_pop(x,y)/var_pop(x)]
+ let value_y = unwrap_or_internal_err!(value_y);
+ let value_x = unwrap_or_internal_err!(value_x);
+
+ if self.count > 1 {
+ self.count -= 1;
+ let delta_x = value_x - self.mean_x;
+ let delta_y = value_y - self.mean_y;
+ self.mean_x -= delta_x / self.count as f64;
+ let delta_x_2 = value_x - self.mean_x;
+ self.m2_x -= delta_x * delta_x_2;
+ self.mean_y -= delta_y / self.count as f64;
+ self.algo_const -= delta_x * (value_y - self.mean_y);
+ } else {
+ self.count = 0;
+ self.mean_x = 0.0;
+ self.m2_x = 0.0;
+ self.mean_y = 0.0;
+ self.algo_const = 0.0;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ let count_arr = downcast_value!(states[0], UInt64Array);
+ let mean_x_arr = downcast_value!(states[1], Float64Array);
+ let mean_y_arr = downcast_value!(states[2], Float64Array);
+ let m2_x_arr = downcast_value!(states[3], Float64Array);
+ let algo_const_arr = downcast_value!(states[4], Float64Array);
+
+ for i in 0..count_arr.len() {
+ let count_b = count_arr.value(i);
+ if count_b == 0_u64 {
+ continue;
+ }
+ let (count_a, mean_x_a, mean_y_a, m2_x_a, algo_const_a) = (
+ self.count,
+ self.mean_x,
+ self.mean_y,
+ self.m2_x,
+ self.algo_const,
+ );
+ let (count_b, mean_x_b, mean_y_b, m2_x_b, algo_const_b) = (
+ count_b,
+ mean_x_arr.value(i),
+ mean_y_arr.value(i),
+ m2_x_arr.value(i),
+ algo_const_arr.value(i),
+ );
+
+ // Assuming two different batches of input have calculated the
states:
+ // batch A of Y, X -> {count_a, mean_x_a, mean_y_a, m2_x_a,
algo_const_a}
+ // batch B of Y, X -> {count_b, mean_x_b, mean_y_b, m2_x_b,
algo_const_b}
+ // The merged states from A and B are {count_ab, mean_x_ab,
mean_y_ab, m2_x_ab,
+ // algo_const_ab}
+ //
+ // Reference for the algorithm to merge states:
+ //
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+ let count_ab = count_a + count_b;
+ let (count_a, count_b) = (count_a as f64, count_b as f64);
+ let d_x = mean_x_b - mean_x_a;
+ let d_y = mean_y_b - mean_y_a;
+ let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64;
+ let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64;
+ let m2_x_ab =
+ m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as
f64;
+ let algo_const_ab = algo_const_a
+ + algo_const_b
+ + d_x * d_y * count_a * count_b / count_ab as f64;
+
+ self.count = count_ab;
+ self.mean_x = mean_x_ab;
+ self.mean_y = mean_y_ab;
+ self.m2_x = m2_x_ab;
+ self.algo_const = algo_const_ab;
+ }
+ Ok(())
+ }
+
+ fn evaluate(&self) -> Result<ScalarValue> {
+ let cov_pop_x_y = self.algo_const / self.count as f64;
+ let var_pop_x = self.m2_x / self.count as f64;
+
+ // Only 0/1 point or slope is infinite
+ if self.count <= 1 || var_pop_x == 0.0 {
+ Ok(ScalarValue::Float64(None))
+ } else {
+ Ok(ScalarValue::Float64(Some(cov_pop_x_y / var_pop_x)))
+ }
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+}
diff --git a/datafusion/physical-expr/src/expressions/mod.rs
b/datafusion/physical-expr/src/expressions/mod.rs
index c660cfadcc..c56c63db7b 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -60,6 +60,7 @@ pub use crate::aggregate::grouping::Grouping;
pub use crate::aggregate::median::Median;
pub use crate::aggregate::min_max::{Max, Min};
pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator};
+pub use crate::aggregate::regr_slope::RegrSlope;
pub use crate::aggregate::stats::StatsType;
pub use crate::aggregate::stddev::{Stddev, StddevPop};
pub use crate::aggregate::sum::Sum;
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index e9ae76b25d..9694a5beb7 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -613,6 +613,7 @@ enum AggregateFunction {
// we append "_AGG" to obey name scoping rules.
FIRST_VALUE_AGG = 24;
LAST_VALUE_AGG = 25;
+ REGR_SLOPE = 26;
}
message AggregateExprNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index a5d85cc6cf..40f58b312a 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -465,6 +465,7 @@ impl serde::Serialize for AggregateFunction {
Self::BoolOr => "BOOL_OR",
Self::FirstValueAgg => "FIRST_VALUE_AGG",
Self::LastValueAgg => "LAST_VALUE_AGG",
+ Self::RegrSlope => "REGR_SLOPE",
};
serializer.serialize_str(variant)
}
@@ -502,6 +503,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"BOOL_OR",
"FIRST_VALUE_AGG",
"LAST_VALUE_AGG",
+ "REGR_SLOPE",
];
struct GeneratedVisitor;
@@ -570,6 +572,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"BOOL_OR" => Ok(AggregateFunction::BoolOr),
"FIRST_VALUE_AGG" => Ok(AggregateFunction::FirstValueAgg),
"LAST_VALUE_AGG" => Ok(AggregateFunction::LastValueAgg),
+ "REGR_SLOPE" => Ok(AggregateFunction::RegrSlope),
_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
}
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index c6f3a23ed6..7e4a5f8afd 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2578,6 +2578,7 @@ pub enum AggregateFunction {
/// we append "_AGG" to obey name scoping rules.
FirstValueAgg = 24,
LastValueAgg = 25,
+ RegrSlope = 26,
}
impl AggregateFunction {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -2614,6 +2615,7 @@ impl AggregateFunction {
AggregateFunction::BoolOr => "BOOL_OR",
AggregateFunction::FirstValueAgg => "FIRST_VALUE_AGG",
AggregateFunction::LastValueAgg => "LAST_VALUE_AGG",
+ AggregateFunction::RegrSlope => "REGR_SLOPE",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -2647,6 +2649,7 @@ impl AggregateFunction {
"BOOL_OR" => Some(Self::BoolOr),
"FIRST_VALUE_AGG" => Some(Self::FirstValueAgg),
"LAST_VALUE_AGG" => Some(Self::LastValueAgg),
+ "REGR_SLOPE" => Some(Self::RegrSlope),
_ => None,
}
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 1464f32bb3..4caff5fba0 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -549,6 +549,7 @@ impl From<protobuf::AggregateFunction> for
AggregateFunction {
protobuf::AggregateFunction::Stddev => Self::Stddev,
protobuf::AggregateFunction::StddevPop => Self::StddevPop,
protobuf::AggregateFunction::Correlation => Self::Correlation,
+ protobuf::AggregateFunction::RegrSlope => Self::RegrSlope,
protobuf::AggregateFunction::ApproxPercentileCont => {
Self::ApproxPercentileCont
}
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index df5701a282..3f4fdfeb74 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -384,6 +384,7 @@ impl From<&AggregateFunction> for
protobuf::AggregateFunction {
AggregateFunction::Stddev => Self::Stddev,
AggregateFunction::StddevPop => Self::StddevPop,
AggregateFunction::Correlation => Self::Correlation,
+ AggregateFunction::RegrSlope => Self::RegrSlope,
AggregateFunction::ApproxPercentileCont =>
Self::ApproxPercentileCont,
AggregateFunction::ApproxPercentileContWithWeight => {
Self::ApproxPercentileContWithWeight
@@ -675,6 +676,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
AggregateFunction::Correlation => {
protobuf::AggregateFunction::Correlation
}
+ AggregateFunction::RegrSlope => {
+ protobuf::AggregateFunction::RegrSlope
+ }
AggregateFunction::ApproxMedian => {
protobuf::AggregateFunction::ApproxMedian
}
diff --git a/docs/source/user-guide/sql/aggregate_functions.md
b/docs/source/user-guide/sql/aggregate_functions.md
index 132ba47e24..71168b0622 100644
--- a/docs/source/user-guide/sql/aggregate_functions.md
+++ b/docs/source/user-guide/sql/aggregate_functions.md
@@ -245,6 +245,7 @@ last_value(expression [ORDER BY expression])
- [var](#var)
- [var_pop](#var_pop)
- [var_samp](#var_samp)
+- [regr_slope](#regr_slope)
### `corr`
@@ -384,6 +385,22 @@ var_samp(expression)
- **expression**: Expression to operate on.
Can be a constant, column, or function, and any combination of arithmetic
operators.
+### `regr_slope`
+
+Returns the slope of the linear regression line for non-null pairs in
aggregate columns.
+Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X
+ b) using minimal RSS fitting.
+
+```
+regr_slope(expression1, expression2)
+```
+
+#### Arguments
+
+- **expression1**: Expression to operate on.
+ Can be a constant, column, or function, and any combination of arithmetic
operators.
+- **expression2**: Expression to operate on.
+ Can be a constant, column, or function, and any combination of arithmetic
operators.
+
## Approximate
- [approx_distinct](#approx_distinct)