This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 245def059 Implement exact median, add `AggregateState` (#3009)
245def059 is described below
commit 245def05940b1fea69d1b75df8a928efb39fc3af
Author: Andy Grove <[email protected]>
AuthorDate: Fri Aug 5 13:56:52 2022 -0600
Implement exact median, add `AggregateState` (#3009)
* Implement exact median
* revert some changes
* toml format
* add median to protobuf
* remove some unwraps
* remove some unwraps
* remove some unwraps
* fix
* clippy
* reduce code duplication
* reduce code duplication
* more tests
* move tests to simplify github diff
* Update datafusion/expr/src/accumulator.rs
Co-authored-by: Andrew Lamb <[email protected]>
* refactor to make it more obvious that empty arrays are being created
* partially address feedback
* Update datafusion/physical-expr/src/aggregate/count_distinct.rs
Co-authored-by: Andrew Lamb <[email protected]>
* add more tests
* more docs
* clippy
* avoid a clone
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion-examples/examples/simple_udaf.rs | 7 +-
datafusion/common/Cargo.toml | 1 +
.../core/src/physical_plan/aggregates/hash.rs | 6 +-
datafusion/core/tests/sql/aggregates.rs | 189 +++++++++++++++-
datafusion/core/tests/sql/mod.rs | 6 +-
datafusion/expr/src/accumulator.rs | 45 +++-
datafusion/expr/src/aggregate_function.rs | 9 +-
datafusion/expr/src/lib.rs | 2 +-
.../physical-expr/src/aggregate/approx_distinct.rs | 6 +-
.../src/aggregate/approx_percentile_cont.rs | 11 +-
.../approx_percentile_cont_with_weight.rs | 4 +-
.../physical-expr/src/aggregate/array_agg.rs | 6 +-
.../src/aggregate/array_agg_distinct.rs | 8 +-
datafusion/physical-expr/src/aggregate/average.rs | 9 +-
datafusion/physical-expr/src/aggregate/build_in.rs | 10 +
.../physical-expr/src/aggregate/correlation.rs | 24 +-
datafusion/physical-expr/src/aggregate/count.rs | 8 +-
.../physical-expr/src/aggregate/count_distinct.rs | 13 +-
.../physical-expr/src/aggregate/covariance.rs | 20 +-
datafusion/physical-expr/src/aggregate/median.rs | 244 +++++++++++++++++++++
datafusion/physical-expr/src/aggregate/min_max.rs | 10 +-
datafusion/physical-expr/src/aggregate/mod.rs | 2 +
datafusion/physical-expr/src/aggregate/stddev.rs | 18 +-
datafusion/physical-expr/src/aggregate/sum.rs | 6 +-
.../physical-expr/src/aggregate/sum_distinct.rs | 11 +-
datafusion/physical-expr/src/aggregate/utils.rs | 48 ++++
datafusion/physical-expr/src/aggregate/variance.rs | 18 +-
datafusion/physical-expr/src/expressions/mod.rs | 1 +
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/from_proto.rs | 1 +
datafusion/proto/src/lib.rs | 6 +-
datafusion/proto/src/to_proto.rs | 2 +
32 files changed, 645 insertions(+), 107 deletions(-)
diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs
index 5e0f41bc8..378d2548e 100644
--- a/datafusion-examples/examples/simple_udaf.rs
+++ b/datafusion-examples/examples/simple_udaf.rs
@@ -23,6 +23,7 @@ use datafusion::arrow::{
};
use datafusion::from_slice::FromSlice;
+use datafusion::logical_expr::AggregateState;
use datafusion::{error::Result, logical_plan::create_udaf,
physical_plan::Accumulator};
use datafusion::{logical_expr::Volatility, prelude::*, scalar::ScalarValue};
use std::sync::Arc;
@@ -107,10 +108,10 @@ impl Accumulator for GeometricMean {
// This function serializes our state to `ScalarValue`, which DataFusion
uses
// to pass this state between execution stages.
// Note that this can be arbitrary data.
- fn state(&self) -> Result<Vec<ScalarValue>> {
+ fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![
- ScalarValue::from(self.prod),
- ScalarValue::from(self.n),
+ AggregateState::Scalar(ScalarValue::from(self.prod)),
+ AggregateState::Scalar(ScalarValue::from(self.n)),
])
}
diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml
index e7873de4d..33d6af087 100644
--- a/datafusion/common/Cargo.toml
+++ b/datafusion/common/Cargo.toml
@@ -45,4 +45,5 @@ object_store = { version = "0.3", optional = true }
ordered-float = "3.0"
parquet = { version = "19.0.0", features = ["arrow"], optional = true }
pyo3 = { version = "0.16", optional = true }
+serde_json = "1.0"
sqlparser = "0.19"
diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs
index c21109495..54806d37f 100644
--- a/datafusion/core/src/physical_plan/aggregates/hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/hash.rs
@@ -428,8 +428,10 @@ fn create_batch_from_map(
AggregateMode::Partial => {
let res = ScalarValue::iter_to_array(
accumulators.group_states.iter().map(|group_state| {
- let x =
group_state.accumulator_set[x].state().unwrap();
- x[y].clone()
+ group_state.accumulator_set[x]
+ .state()
+ .and_then(|x| x[y].as_scalar().map(|v|
v.clone()))
+ .expect("unexpected accumulator state in hash
aggregate")
}),
)?;
diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs
index 02d4b3a4d..eb0e07f84 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -221,7 +221,7 @@ async fn csv_query_stddev_6() -> Result<()> {
}
#[tokio::test]
-async fn csv_query_median_1() -> Result<()> {
+async fn csv_query_approx_median_1() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT approx_median(c2) FROM aggregate_test_100";
@@ -232,7 +232,7 @@ async fn csv_query_median_1() -> Result<()> {
}
#[tokio::test]
-async fn csv_query_median_2() -> Result<()> {
+async fn csv_query_approx_median_2() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT approx_median(c6) FROM aggregate_test_100";
@@ -243,7 +243,7 @@ async fn csv_query_median_2() -> Result<()> {
}
#[tokio::test]
-async fn csv_query_median_3() -> Result<()> {
+async fn csv_query_approx_median_3() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT approx_median(c12) FROM aggregate_test_100";
@@ -253,6 +253,189 @@ async fn csv_query_median_3() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn csv_query_median_1() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT median(c2) FROM aggregate_test_100";
+ let actual = execute(&ctx, sql).await;
+ let expected = vec![vec!["3"]];
+ assert_float_eq(&expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_median_2() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT median(c6) FROM aggregate_test_100";
+ let actual = execute(&ctx, sql).await;
+ let expected = vec![vec!["1125553990140691277"]];
+ assert_float_eq(&expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_median_3() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT median(c12) FROM aggregate_test_100";
+ let actual = execute(&ctx, sql).await;
+ let expected = vec![vec!["0.5513900544385053"]];
+ assert_float_eq(&expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn median_i8() -> Result<()> {
+ median_test(
+ "median",
+ DataType::Int8,
+ Arc::new(Int8Array::from(vec![i8::MIN, i8::MIN, 100, i8::MAX])),
+ "-14",
+ )
+ .await
+}
+
+#[tokio::test]
+async fn median_i16() -> Result<()> {
+ median_test(
+ "median",
+ DataType::Int16,
+ Arc::new(Int16Array::from(vec![i16::MIN, i16::MIN, 100, i16::MAX])),
+ "-16334",
+ )
+ .await
+}
+
+#[tokio::test]
+async fn median_i32() -> Result<()> {
+ median_test(
+ "median",
+ DataType::Int32,
+ Arc::new(Int32Array::from(vec![i32::MIN, i32::MIN, 100, i32::MAX])),
+ "-1073741774",
+ )
+ .await
+}
+
+#[tokio::test]
+async fn median_i64() -> Result<()> {
+ median_test(
+ "median",
+ DataType::Int64,
+ Arc::new(Int64Array::from(vec![i64::MIN, i64::MIN, 100, i64::MAX])),
+ "-4611686018427388000",
+ )
+ .await
+}
+
+#[tokio::test]
+async fn median_u8() -> Result<()> {
+ median_test(
+ "median",
+ DataType::UInt8,
+ Arc::new(UInt8Array::from(vec![u8::MIN, u8::MIN, 100, u8::MAX])),
+ "50",
+ )
+ .await
+}
+
+#[tokio::test]
+async fn median_u16() -> Result<()> {
+ median_test(
+ "median",
+ DataType::UInt16,
+ Arc::new(UInt16Array::from(vec![u16::MIN, u16::MIN, 100, u16::MAX])),
+ "50",
+ )
+ .await
+}
+
+#[tokio::test]
+async fn median_u32() -> Result<()> {
+ median_test(
+ "median",
+ DataType::UInt32,
+ Arc::new(UInt32Array::from(vec![u32::MIN, u32::MIN, 100, u32::MAX])),
+ "50",
+ )
+ .await
+}
+
+#[tokio::test]
+async fn median_u64() -> Result<()> {
+ median_test(
+ "median",
+ DataType::UInt64,
+ Arc::new(UInt64Array::from(vec![u64::MIN, u64::MIN, 100, u64::MAX])),
+ "50",
+ )
+ .await
+}
+
+#[tokio::test]
+async fn median_f32() -> Result<()> {
+ median_test(
+ "median",
+ DataType::Float32,
+ Arc::new(Float32Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])),
+ "3.3",
+ )
+ .await
+}
+
+#[tokio::test]
+async fn median_f64() -> Result<()> {
+ median_test(
+ "median",
+ DataType::Float64,
+ Arc::new(Float64Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])),
+ "3.3",
+ )
+ .await
+}
+
+#[tokio::test]
+async fn median_f64_nan() -> Result<()> {
+ median_test(
+ "median",
+ DataType::Float64,
+ Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])),
+ "NaN", // probably not the desired behavior? - see
https://github.com/apache/arrow-datafusion/issues/3039
+ )
+ .await
+}
+
+#[tokio::test]
+async fn approx_median_f64_nan() -> Result<()> {
+ median_test(
+ "approx_median",
+ DataType::Float64,
+ Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])),
+ "NaN", // probably not the desired behavior? - see
https://github.com/apache/arrow-datafusion/issues/3039
+ )
+ .await
+}
+
+async fn median_test(
+ func: &str,
+ data_type: DataType,
+ values: ArrayRef,
+ expected: &str,
+) -> Result<()> {
+ let ctx = SessionContext::new();
+ let schema = Arc::new(Schema::new(vec![Field::new("a", data_type,
false)]));
+ let batch = RecordBatch::try_new(schema.clone(), vec![values])?;
+ let table = Arc::new(MemTable::try_new(schema, vec![vec![batch]])?);
+ ctx.register_table("t", table)?;
+ let sql = format!("SELECT {}(a) FROM t", func);
+ let actual = execute(&ctx, &sql).await;
+ let expected = vec![vec![expected.to_owned()]];
+ assert_float_eq(&expected, &actual);
+ Ok(())
+}
+
#[tokio::test]
async fn csv_query_external_table_count() {
let ctx = SessionContext::new();
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index 847f19cf8..5481161d0 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -128,7 +128,11 @@ where
l.as_ref().parse::<f64>().unwrap(),
r.as_str().parse::<f64>().unwrap(),
);
- assert!((l - r).abs() <= 2.0 * f64::EPSILON);
+ if l.is_nan() || r.is_nan() {
+ assert!(l.is_nan() && r.is_nan());
+ } else if (l - r).abs() > 2.0 * f64::EPSILON {
+ panic!("{} != {}", l, r)
+ }
});
}
diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs
index d59764957..6c146bc30 100644
--- a/datafusion/expr/src/accumulator.rs
+++ b/datafusion/expr/src/accumulator.rs
@@ -18,7 +18,7 @@
//! Accumulator module contains the trait definition for aggregation
function's accumulators.
use arrow::array::ArrayRef;
-use datafusion_common::{Result, ScalarValue};
+use datafusion_common::{DataFusionError, Result, ScalarValue};
use std::fmt::Debug;
/// An accumulator represents a stateful object that lives throughout the
evaluation of multiple rows and
@@ -26,14 +26,14 @@ use std::fmt::Debug;
///
/// An accumulator knows how to:
/// * update its state from inputs via `update_batch`
-/// * convert its internal state to a vector of scalar values
+/// * convert its internal state to a vector of aggregate values
/// * update its state from multiple accumulators' states via `merge_batch`
/// * compute the final value from its internal state via `evaluate`
pub trait Accumulator: Send + Sync + Debug {
/// Returns the state of the accumulator at the end of the accumulation.
- // in the case of an average on which we track `sum` and `n`, this
function should return a vector
- // of two values, sum and n.
- fn state(&self) -> Result<Vec<ScalarValue>>;
+ /// in the case of an average on which we track `sum` and `n`, this
function should return a vector
+ /// of two values, sum and n.
+ fn state(&self) -> Result<Vec<AggregateState>>;
/// updates the accumulator's state from a vector of arrays.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
@@ -44,3 +44,38 @@ pub trait Accumulator: Send + Sync + Debug {
/// returns its value based on its current state.
fn evaluate(&self) -> Result<ScalarValue>;
}
+
+/// Representation of internal accumulator state. Accumulators can potentially
have a mix of
+/// scalar and array values. It may be desirable to add custom aggregator
states here as well
+/// in the future (perhaps `Custom(Box<dyn Any>)`?).
+#[derive(Debug)]
+pub enum AggregateState {
+ /// Simple scalar value. Note that `ScalarValue::List` can be used to pass
multiple
+ /// values around
+ Scalar(ScalarValue),
+ /// Arrays can be used instead of `ScalarValue::List` and could
potentially have better
+ /// performance with large data sets, although this has not been verified.
It also allows
+ /// for use of arrow kernels with less overhead.
+ Array(ArrayRef),
+}
+
+impl AggregateState {
+ /// Access the aggregate state as a scalar value. An error will occur if
the
+ /// state is not a scalar value.
+ pub fn as_scalar(&self) -> Result<&ScalarValue> {
+ match &self {
+ Self::Scalar(v) => Ok(v),
+ _ => Err(DataFusionError::Internal(
+ "AggregateState is not a scalar aggregate".to_string(),
+ )),
+ }
+ }
+
+ /// Access the aggregate state as an array value.
+ pub fn to_array(&self) -> ArrayRef {
+ match &self {
+ Self::Scalar(v) => v.to_array(),
+ Self::Array(array) => array.clone(),
+ }
+ }
+}
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index 30bf0521d..09d759e56 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -62,6 +62,8 @@ pub enum AggregateFunction {
Max,
/// avg
Avg,
+ /// median
+ Median,
/// Approximate aggregate function
ApproxDistinct,
/// array_agg
@@ -107,6 +109,7 @@ impl FromStr for AggregateFunction {
"avg" => AggregateFunction::Avg,
"mean" => AggregateFunction::Avg,
"sum" => AggregateFunction::Sum,
+ "median" => AggregateFunction::Median,
"approx_distinct" => AggregateFunction::ApproxDistinct,
"array_agg" => AggregateFunction::ArrayAgg,
"var" => AggregateFunction::Variance,
@@ -175,7 +178,9 @@ pub fn return_type(
AggregateFunction::ApproxPercentileContWithWeight => {
Ok(coerced_data_types[0].clone())
}
- AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
+ AggregateFunction::ApproxMedian | AggregateFunction::Median => {
+ Ok(coerced_data_types[0].clone())
+ }
AggregateFunction::Grouping => Ok(DataType::Int32),
}
}
@@ -330,6 +335,7 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
+ AggregateFunction::Median => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
}
}
@@ -358,6 +364,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
| AggregateFunction::VariancePop
| AggregateFunction::Stddev
| AggregateFunction::StddevPop
+ | AggregateFunction::Median
| AggregateFunction::ApproxMedian => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index f71243610..90007a8bd 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -53,7 +53,7 @@ pub mod utils;
pub mod window_frame;
pub mod window_function;
-pub use accumulator::Accumulator;
+pub use accumulator::{Accumulator, AggregateState};
pub use aggregate_function::AggregateFunction;
pub use built_in_function::BuiltinScalarFunction;
pub use columnar_value::{ColumnarValue, NullColumnarValue};
diff --git a/datafusion/physical-expr/src/aggregate/approx_distinct.rs b/datafusion/physical-expr/src/aggregate/approx_distinct.rs
index c67d1c9d3..5b391ed84 100644
--- a/datafusion/physical-expr/src/aggregate/approx_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_distinct.rs
@@ -30,7 +30,7 @@ use arrow::datatypes::{
};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
use std::any::type_name;
use std::any::Any;
use std::convert::TryFrom;
@@ -232,8 +232,8 @@ macro_rules! default_accumulator_impl {
Ok(())
}
- fn state(&self) -> Result<Vec<ScalarValue>> {
- let value = ScalarValue::from(&self.hll);
+ fn state(&self) -> Result<Vec<AggregateState>> {
+ let value = AggregateState::Scalar(ScalarValue::from(&self.hll));
Ok(vec![value])
}
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
index 2315ad1d5..41c6c72db 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
@@ -29,7 +29,7 @@ use arrow::{
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
use ordered_float::OrderedFloat;
use std::{any::Any, iter, sync::Arc};
@@ -287,8 +287,13 @@ impl ApproxPercentileAccumulator {
}
impl Accumulator for ApproxPercentileAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
- Ok(self.digest.to_scalar_state())
+ fn state(&self) -> Result<Vec<AggregateState>> {
+ Ok(self
+ .digest
+ .to_scalar_state()
+ .into_iter()
+ .map(AggregateState::Scalar)
+ .collect())
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
index f9874b0a5..40a44c3a5 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
@@ -26,7 +26,7 @@ use arrow::{
use datafusion_common::Result;
use datafusion_common::ScalarValue;
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
use std::{any::Any, sync::Arc};
@@ -114,7 +114,7 @@ impl ApproxPercentileWithWeightAccumulator {
}
impl Accumulator for ApproxPercentileWithWeightAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
+ fn state(&self) -> Result<Vec<AggregateState>> {
self.approx_percentile_cont_accumulator.state()
}
diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs
index eaed89390..e7fd0937c 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg.rs
@@ -23,7 +23,7 @@ use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
use std::any::Any;
use std::sync::Arc;
@@ -143,8 +143,8 @@ impl Accumulator for ArrayAggAccumulator {
})
}
- fn state(&self) -> Result<Vec<ScalarValue>> {
- Ok(vec![self.evaluate()?])
+ fn state(&self) -> Result<Vec<AggregateState>> {
+ Ok(vec![AggregateState::Scalar(self.evaluate()?)])
}
fn evaluate(&self) -> Result<ScalarValue> {
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
index 44e24e93c..f9899379d 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
@@ -29,7 +29,7 @@ use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
/// Expression for a ARRAY_AGG(DISTINCT) aggregation.
#[derive(Debug)]
@@ -119,11 +119,11 @@ impl DistinctArrayAggAccumulator {
}
impl Accumulator for DistinctArrayAggAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
- Ok(vec![ScalarValue::List(
+ fn state(&self) -> Result<Vec<AggregateState>> {
+ Ok(vec![AggregateState::Scalar(ScalarValue::List(
Some(self.values.clone().into_iter().collect()),
Box::new(Field::new("item", self.datatype.clone(), true)),
- )])
+ ))])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 1b1d99525..a55e0e352 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -33,7 +33,7 @@ use arrow::{
};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
use datafusion_row::accessor::RowAccessor;
/// AVG aggregate expression
@@ -150,8 +150,11 @@ impl AvgAccumulator {
}
impl Accumulator for AvgAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
- Ok(vec![ScalarValue::from(self.count), self.sum.clone()])
+ fn state(&self) -> Result<Vec<AggregateState>> {
+ Ok(vec![
+ AggregateState::Scalar(ScalarValue::from(self.count)),
+ AggregateState::Scalar(self.sum.clone()),
+ ])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index 23d2a84d1..8d76e35e4 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -251,6 +251,16 @@ pub fn create_aggregate_expr(
))
}
(AggregateFunction::ApproxMedian, true) => {
+ return Err(DataFusionError::NotImplemented(
+ "APPROX_MEDIAN(DISTINCT) aggregations are not
available".to_string(),
+ ));
+ }
+ (AggregateFunction::Median, false) =>
Arc::new(expressions::Median::new(
+ coerced_phy_exprs[0].clone(),
+ name,
+ return_type,
+ )),
+ (AggregateFunction::Median, true) => {
return Err(DataFusionError::NotImplemented(
"MEDIAN(DISTINCT) aggregations are not available".to_string(),
));
diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs b/datafusion/physical-expr/src/aggregate/correlation.rs
index 94a820849..3bbea5d9b 100644
--- a/datafusion/physical-expr/src/aggregate/correlation.rs
+++ b/datafusion/physical-expr/src/aggregate/correlation.rs
@@ -25,7 +25,7 @@ use crate::{AggregateExpr, PhysicalExpr};
use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
use std::any::Any;
use std::sync::Arc;
@@ -133,14 +133,14 @@ impl CorrelationAccumulator {
}
impl Accumulator for CorrelationAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
+ fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![
- ScalarValue::from(self.covar.get_count()),
- ScalarValue::from(self.covar.get_mean1()),
- ScalarValue::from(self.stddev1.get_m2()),
- ScalarValue::from(self.covar.get_mean2()),
- ScalarValue::from(self.stddev2.get_m2()),
- ScalarValue::from(self.covar.get_algo_const()),
+ AggregateState::Scalar(ScalarValue::from(self.covar.get_count())),
+ AggregateState::Scalar(ScalarValue::from(self.covar.get_mean1())),
+ AggregateState::Scalar(ScalarValue::from(self.stddev1.get_m2())),
+ AggregateState::Scalar(ScalarValue::from(self.covar.get_mean2())),
+ AggregateState::Scalar(ScalarValue::from(self.stddev2.get_m2())),
+
AggregateState::Scalar(ScalarValue::from(self.covar.get_algo_const())),
])
}
@@ -191,6 +191,7 @@ impl Accumulator for CorrelationAccumulator {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::aggregate::utils::get_accum_scalar_values_as_arrays;
use crate::expressions::col;
use crate::expressions::tests::aggregate;
use crate::generic_test_op2;
@@ -469,12 +470,7 @@ mod tests {
.collect::<Result<Vec<_>>>()?;
accum1.update_batch(&values1)?;
accum2.update_batch(&values2)?;
- let state2 = accum2
- .state()?
- .iter()
- .map(|v| vec![v.clone()])
- .map(|x| ScalarValue::iter_to_array(x).unwrap())
- .collect::<Vec<_>>();
+ let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?;
accum1.merge_batch(&state2)?;
accum1.evaluate()
}
diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs
index 2b02d03b5..982c1dc09 100644
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ b/datafusion/physical-expr/src/aggregate/count.rs
@@ -29,7 +29,7 @@ use arrow::datatypes::DataType;
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
use datafusion_row::accessor::RowAccessor;
use crate::expressions::format_state_name;
@@ -134,8 +134,10 @@ impl Accumulator for CountAccumulator {
Ok(())
}
- fn state(&self) -> Result<Vec<ScalarValue>> {
- Ok(vec![ScalarValue::Int64(Some(self.count))])
+ fn state(&self) -> Result<Vec<AggregateState>> {
+ Ok(vec![AggregateState::Scalar(ScalarValue::Int64(Some(
+ self.count,
+ )))])
}
fn evaluate(&self) -> Result<ScalarValue> {
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index 744d9b90d..6060ddb4d 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -28,7 +28,7 @@ use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
struct DistinctScalarValues(Vec<ScalarValue>);
@@ -177,7 +177,7 @@ impl Accumulator for DistinctCountAccumulator {
self.merge(&v)
})
}
- fn state(&self) -> Result<Vec<ScalarValue>> {
+ fn state(&self) -> Result<Vec<AggregateState>> {
let mut cols_out = self
.state_data_types
.iter()
@@ -206,7 +206,7 @@ impl Accumulator for DistinctCountAccumulator {
)
});
- Ok(cols_out)
+ Ok(cols_out.into_iter().map(AggregateState::Scalar).collect())
}
fn evaluate(&self) -> Result<ScalarValue> {
@@ -223,6 +223,7 @@ impl Accumulator for DistinctCountAccumulator {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::aggregate::utils::get_accum_scalar_values;
use arrow::array::{
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array,
Int32Array,
Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array,
UInt64Array,
@@ -341,7 +342,7 @@ mod tests {
let mut accum = agg.create_accumulator()?;
accum.update_batch(arrays)?;
- Ok((accum.state()?, accum.evaluate()?))
+ Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?))
}
fn run_update(
@@ -372,7 +373,7 @@ mod tests {
accum.update_batch(&arrays)?;
- Ok((accum.state()?, accum.evaluate()?))
+ Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?))
}
fn run_merge_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>,
ScalarValue)> {
@@ -390,7 +391,7 @@ mod tests {
let mut accum = agg.create_accumulator()?;
accum.merge_batch(arrays)?;
- Ok((accum.state()?, accum.evaluate()?))
+ Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?))
}
// Used trait to create associated constant for f32 and f64
diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs
index 1df002b48..9cd319127 100644
--- a/datafusion/physical-expr/src/aggregate/covariance.rs
+++ b/datafusion/physical-expr/src/aggregate/covariance.rs
@@ -30,7 +30,7 @@ use arrow::{
};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
use crate::aggregate::stats::StatsType;
use crate::expressions::format_state_name;
@@ -237,12 +237,12 @@ impl CovarianceAccumulator {
}
impl Accumulator for CovarianceAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
+ fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![
- ScalarValue::from(self.count),
- ScalarValue::from(self.mean1),
- ScalarValue::from(self.mean2),
- ScalarValue::from(self.algo_const),
+ AggregateState::Scalar(ScalarValue::from(self.count)),
+ AggregateState::Scalar(ScalarValue::from(self.mean1)),
+ AggregateState::Scalar(ScalarValue::from(self.mean2)),
+ AggregateState::Scalar(ScalarValue::from(self.algo_const)),
])
}
@@ -352,6 +352,7 @@ impl Accumulator for CovarianceAccumulator {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::aggregate::utils::get_accum_scalar_values_as_arrays;
use crate::expressions::col;
use crate::expressions::tests::aggregate;
use crate::generic_test_op2;
@@ -644,12 +645,7 @@ mod tests {
.collect::<Result<Vec<_>>>()?;
accum1.update_batch(&values1)?;
accum2.update_batch(&values2)?;
- let state2 = accum2
- .state()?
- .iter()
- .map(|v| vec![v.clone()])
- .map(|x| ScalarValue::iter_to_array(x).unwrap())
- .collect::<Vec<_>>();
+ let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?;
accum1.merge_batch(&state2)?;
accum1.evaluate()
}
diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs
new file mode 100644
index 000000000..6b68f2ec3
--- /dev/null
+++ b/datafusion/physical-expr/src/aggregate/median.rs
@@ -0,0 +1,244 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! # Median
+
+use crate::expressions::format_state_name;
+use crate::{AggregateExpr, PhysicalExpr};
+use arrow::array::{Array, ArrayRef, PrimitiveArray, PrimitiveBuilder};
+use arrow::compute::sort;
+use arrow::datatypes::{
+ ArrowPrimitiveType, DataType, Field, Float32Type, Float64Type, Int16Type,
Int32Type,
+ Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use datafusion_common::{DataFusionError, Result, ScalarValue};
+use datafusion_expr::{Accumulator, AggregateState};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Exact `MEDIAN` aggregate expression.
+///
+/// Every input value is buffered until the final result is computed, so memory use
+/// grows with the number of input rows. When an approximate answer is acceptable,
+/// `APPROX_MEDIAN` is a much cheaper alternative.
+#[derive(Debug)]
+pub struct Median {
+    /// Output column name; also the base of the state field name
+    name: String,
+    /// Expression producing the values to aggregate
+    expr: Arc<dyn PhysicalExpr>,
+    /// Data type of the result field and of the intermediate state field
+    data_type: DataType,
+}
+
+impl Median {
+    /// Creates a `MEDIAN` aggregate over `expr`, named `name`, with result type `data_type`.
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        let name: String = name.into();
+        Self {
+            expr,
+            name,
+            data_type,
+        }
+    }
+}
+
+impl AggregateExpr for Median {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// The nullable output field, named `name` and typed `data_type`
+    fn field(&self) -> Result<Field> {
+        let nullable = true;
+        Ok(Field::new(&self.name, self.data_type.clone(), nullable))
+    }
+
+    /// Creates a fresh accumulator that has not buffered any values yet
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        let accumulator = MedianAccumulator {
+            data_type: self.data_type.clone(),
+            all_values: Vec::new(),
+        };
+        Ok(Box::new(accumulator))
+    }
+
+    /// A single nullable state field of type `data_type`
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        let state_name = format_state_name(&self.name, "median");
+        Ok(vec![Field::new(&state_name, self.data_type.clone(), true)])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![Arc::clone(&self.expr)]
+    }
+
+    fn name(&self) -> &str {
+        self.name.as_str()
+    }
+}
+
+/// Accumulator for [`Median`]: buffers every input array and computes the exact
+/// median when `evaluate` is called.
+#[derive(Debug)]
+struct MedianAccumulator {
+    /// Declared data type; used for the empty state and as a fallback in `evaluate`
+    data_type: DataType,
+    /// Every array received through `update_batch` or `merge_batch`, in arrival order
+    all_values: Vec<ArrayRef>,
+}
+
+/// Computes the exact median of the non-null values buffered in `$SELF.all_values`.
+///
+/// * `$TY` - arrow primitive type of the values (e.g. `Int32Type`)
+/// * `$SCALAR_TY` - matching `ScalarValue` variant (e.g. `Int32`)
+/// * last argument - how to average the two middle values for an even count:
+///   `integer <native type>`, `float <literal two>`, or an explicit `|lo, hi| expr`
+///
+/// Evaluates to `Ok(ScalarValue::$SCALAR_TY(..))`, and returns a typed null from the
+/// enclosing function when there are no non-null values.
+macro_rules! median {
+    // Integers: add the two middle values in i128 so the sum cannot overflow (e.g. two
+    // Int8 values of 100). The quotient is truncated toward zero exactly like
+    // `(lo + hi) / 2`, and the midpoint of two `$NATIVE` values always fits in `$NATIVE`.
+    ($SELF:ident, $TY:ty, $SCALAR_TY:ident, integer $NATIVE:ty) => {
+        median!($SELF, $TY, $SCALAR_TY, |lo, hi| {
+            ((i128::from(lo) + i128::from(hi)) / 2) as $NATIVE
+        })
+    };
+    // Floats: plain average of the two middle values
+    ($SELF:ident, $TY:ty, $SCALAR_TY:ident, float $TWO:expr) => {
+        median!($SELF, $TY, $SCALAR_TY, |lo, hi| (lo + hi) / $TWO)
+    };
+    ($SELF:ident, $TY:ty, $SCALAR_TY:ident, |$LO:ident, $HI:ident| $MID:expr) => {{
+        let combined = combine_arrays::<$TY>($SELF.all_values.as_slice())?;
+        if combined.is_empty() {
+            // No non-null input: return a null of the result type (not the untyped
+            // `ScalarValue::Null`) so the value matches the declared output field
+            return Ok(ScalarValue::$SCALAR_TY(None));
+        }
+        let sorted = sort(&combined, None)?;
+        let array = sorted
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$TY>>()
+            .ok_or_else(|| {
+                DataFusionError::Internal(
+                    "median! macro failed to cast array to expected type".to_string(),
+                )
+            })?;
+        let len = array.len();
+        let mid = len / 2;
+        let value = if len % 2 == 0 {
+            // len is even and non-zero here, so `mid - 1` cannot underflow
+            let $LO = array.value(mid - 1);
+            let $HI = array.value(mid);
+            $MID
+        } else {
+            array.value(mid)
+        };
+        Ok(ScalarValue::$SCALAR_TY(Some(value)))
+    }};
+}
+
+impl Accumulator for MedianAccumulator {
+    /// Returns exactly one state array, matching the single field declared by
+    /// `Median::state_fields`: all buffered arrays concatenated, or an empty array of
+    /// the declared type when nothing has been buffered.
+    fn state(&self) -> Result<Vec<AggregateState>> {
+        let state = match self.all_values.as_slice() {
+            [] => match self.data_type {
+                DataType::UInt8 => empty_array::<UInt8Type>(),
+                DataType::UInt16 => empty_array::<UInt16Type>(),
+                DataType::UInt32 => empty_array::<UInt32Type>(),
+                DataType::UInt64 => empty_array::<UInt64Type>(),
+                DataType::Int8 => empty_array::<Int8Type>(),
+                DataType::Int16 => empty_array::<Int16Type>(),
+                DataType::Int32 => empty_array::<Int32Type>(),
+                DataType::Int64 => empty_array::<Int64Type>(),
+                DataType::Float32 => empty_array::<Float32Type>(),
+                DataType::Float64 => empty_array::<Float64Type>(),
+                _ => {
+                    return Err(DataFusionError::Execution(
+                        "unsupported data type for median".to_string(),
+                    ))
+                }
+            },
+            [single] => AggregateState::Array(Arc::clone(single)),
+            arrays => {
+                let refs: Vec<&dyn Array> =
+                    arrays.iter().map(|array| array.as_ref()).collect();
+                AggregateState::Array(arrow::compute::concat(&refs)?)
+            }
+        };
+        Ok(vec![state])
+    }
+
+    /// Buffers the input column (MEDIAN has a single input expression).
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let array = values.first().ok_or_else(|| {
+            DataFusionError::Internal("median expects one input array".to_string())
+        })?;
+        self.all_values.push(Arc::clone(array));
+        Ok(())
+    }
+
+    /// Buffers every state array produced by other accumulators.
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.all_values.extend(states.iter().cloned());
+        Ok(())
+    }
+
+    /// Computes the median of all buffered non-null values; a typed null when there
+    /// are none.
+    fn evaluate(&self) -> Result<ScalarValue> {
+        // Use the type of the buffered input when there is any; otherwise fall back to
+        // the declared type (the old `all_values[0]` panicked when no batch was seen)
+        let data_type = self
+            .all_values
+            .first()
+            .map(|array| array.data_type())
+            .unwrap_or(&self.data_type);
+        match data_type {
+            DataType::Int8 => median!(self, Int8Type, Int8, integer i8),
+            DataType::Int16 => median!(self, Int16Type, Int16, integer i16),
+            DataType::Int32 => median!(self, Int32Type, Int32, integer i32),
+            DataType::Int64 => median!(self, Int64Type, Int64, integer i64),
+            DataType::UInt8 => median!(self, UInt8Type, UInt8, integer u8),
+            DataType::UInt16 => median!(self, UInt16Type, UInt16, integer u16),
+            DataType::UInt32 => median!(self, UInt32Type, UInt32, integer u32),
+            DataType::UInt64 => median!(self, UInt64Type, UInt64, integer u64),
+            DataType::Float32 => median!(self, Float32Type, Float32, float 2_f32),
+            DataType::Float64 => median!(self, Float64Type, Float64, float 2_f64),
+            _ => Err(DataFusionError::Execution(
+                "unsupported data type for median".to_string(),
+            )),
+        }
+    }
+}
+
+/// Wraps a zero-length `PrimitiveArray<T>` in an `AggregateState::Array`; used as the
+/// state of a median accumulator that has not buffered any input.
+fn empty_array<T: ArrowPrimitiveType>() -> AggregateState {
+    let empty: PrimitiveArray<T> = PrimitiveBuilder::<T>::new(0).finish();
+    AggregateState::Array(Arc::new(empty))
+}
+
+/// Copies the non-null values of every array in `arrays`, in order, into one new
+/// `PrimitiveArray<T>`.
+///
+/// Returns an internal error if any input is not a `PrimitiveArray<T>`.
+fn combine_arrays<T: ArrowPrimitiveType>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
+    // Capacity is the total number of non-null slots across all inputs
+    let capacity: usize = arrays
+        .iter()
+        .map(|array| array.len() - array.null_count())
+        .sum();
+    let mut builder = PrimitiveBuilder::<T>::new(capacity);
+    for array in arrays {
+        let typed = array
+            .as_any()
+            .downcast_ref::<PrimitiveArray<T>>()
+            .ok_or_else(|| {
+                DataFusionError::Internal(
+                    "combine_arrays failed to cast array to expected type".to_string(),
+                )
+            })?;
+        // `iter` yields `None` for null slots; `flatten` drops them
+        for value in typed.iter().flatten() {
+            builder.append_value(value);
+        }
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
+#[cfg(test)]
+mod test {
+    use crate::aggregate::median::combine_arrays;
+    use arrow::array::{Int32Array, UInt32Array};
+    use arrow::datatypes::{Int32Type, UInt32Type};
+    use datafusion_common::Result;
+    use std::sync::Arc;
+
+    // These tests compare collected values rather than `Debug` output, whose layout is
+    // an arrow implementation detail.
+
+    #[test]
+    fn combine_i32_array() -> Result<()> {
+        let a = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let combined = combine_arrays::<Int32Type>(&[a.clone(), a])?;
+        let combined = combined.as_any().downcast_ref::<Int32Array>().unwrap();
+        assert_eq!(
+            combined.iter().collect::<Vec<_>>(),
+            vec![Some(1), Some(2), Some(3), Some(1), Some(2), Some(3)]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn combine_u32_array() -> Result<()> {
+        let a = Arc::new(UInt32Array::from(vec![1, 2, 3]));
+        let combined = combine_arrays::<UInt32Type>(&[a.clone(), a])?;
+        let combined = combined.as_any().downcast_ref::<UInt32Array>().unwrap();
+        assert_eq!(
+            combined.iter().collect::<Vec<_>>(),
+            vec![Some(1), Some(2), Some(3), Some(1), Some(2), Some(3)]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn combine_skips_nulls() -> Result<()> {
+        let a = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
+        let b = Arc::new(Int32Array::from(vec![None, Some(2)]));
+        let combined = combine_arrays::<Int32Type>(&[a, b])?;
+        assert_eq!(combined.null_count(), 0);
+        let combined = combined.as_any().downcast_ref::<Int32Array>().unwrap();
+        assert_eq!(
+            combined.iter().collect::<Vec<_>>(),
+            vec![Some(1), Some(3), Some(2)]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn combine_no_arrays() -> Result<()> {
+        let combined = combine_arrays::<Int32Type>(&[])?;
+        assert!(combined.is_empty());
+        Ok(())
+    }
+}
diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs
b/datafusion/physical-expr/src/aggregate/min_max.rs
index bd56973b1..077f4d725 100644
--- a/datafusion/physical-expr/src/aggregate/min_max.rs
+++ b/datafusion/physical-expr/src/aggregate/min_max.rs
@@ -36,7 +36,7 @@ use arrow::{
};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
use crate::aggregate::row_accumulator::RowAccumulator;
use crate::expressions::format_state_name;
@@ -538,8 +538,8 @@ impl Accumulator for MaxAccumulator {
self.update_batch(states)
}
- fn state(&self) -> Result<Vec<ScalarValue>> {
- Ok(vec![self.max.clone()])
+ fn state(&self) -> Result<Vec<AggregateState>> {
+ Ok(vec![AggregateState::Scalar(self.max.clone())])
}
fn evaluate(&self) -> Result<ScalarValue> {
@@ -691,8 +691,8 @@ impl MinAccumulator {
}
impl Accumulator for MinAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
- Ok(vec![self.min.clone()])
+ fn state(&self) -> Result<Vec<AggregateState>> {
+ Ok(vec![AggregateState::Scalar(self.min.clone())])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs
b/datafusion/physical-expr/src/aggregate/mod.rs
index 1cbd4aeea..a8d59d714 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -37,6 +37,7 @@ pub(crate) mod count;
pub(crate) mod count_distinct;
pub(crate) mod covariance;
pub(crate) mod grouping;
+pub(crate) mod median;
#[macro_use]
pub(crate) mod min_max;
pub mod build_in;
@@ -47,6 +48,7 @@ pub(crate) mod stddev;
pub(crate) mod sum;
pub(crate) mod sum_distinct;
mod tdigest;
+pub mod utils;
pub(crate) mod variance;
/// An aggregate expression that:
diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs
b/datafusion/physical-expr/src/aggregate/stddev.rs
index 13085fee2..77f080293 100644
--- a/datafusion/physical-expr/src/aggregate/stddev.rs
+++ b/datafusion/physical-expr/src/aggregate/stddev.rs
@@ -27,7 +27,7 @@ use crate::{AggregateExpr, PhysicalExpr};
use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression
#[derive(Debug)]
@@ -180,11 +180,11 @@ impl StddevAccumulator {
}
impl Accumulator for StddevAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
+ fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![
- ScalarValue::from(self.variance.get_count()),
- ScalarValue::from(self.variance.get_mean()),
- ScalarValue::from(self.variance.get_m2()),
+
AggregateState::Scalar(ScalarValue::from(self.variance.get_count())),
+
AggregateState::Scalar(ScalarValue::from(self.variance.get_mean())),
+ AggregateState::Scalar(ScalarValue::from(self.variance.get_m2())),
])
}
@@ -216,6 +216,7 @@ impl Accumulator for StddevAccumulator {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::aggregate::utils::get_accum_scalar_values_as_arrays;
use crate::expressions::col;
use crate::expressions::tests::aggregate;
use crate::generic_test_op;
@@ -441,12 +442,7 @@ mod tests {
.collect::<Result<Vec<_>>>()?;
accum1.update_batch(&values1)?;
accum2.update_batch(&values2)?;
- let state2 = accum2
- .state()?
- .iter()
- .map(|v| vec![v.clone()])
- .map(|x| ScalarValue::iter_to_array(x).unwrap())
- .collect::<Vec<_>>();
+ let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?;
accum1.merge_batch(&state2)?;
accum1.evaluate()
}
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs
b/datafusion/physical-expr/src/aggregate/sum.rs
index 866e90f1e..b0a7de6c6 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -32,7 +32,7 @@ use arrow::{
datatypes::Field,
};
use datafusion_common::{DataFusionError, Result, ScalarValue};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
use crate::aggregate::row_accumulator::RowAccumulator;
use crate::expressions::format_state_name;
@@ -435,8 +435,8 @@ pub(crate) fn add_to_row(
}
impl Accumulator for SumAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
- Ok(vec![self.sum.clone()])
+ fn state(&self) -> Result<Vec<AggregateState>> {
+ Ok(vec![AggregateState::Scalar(self.sum.clone())])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs
b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
index a64b4b497..d939a033e 100644
--- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
@@ -29,7 +29,7 @@ use std::collections::HashSet;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
/// Expression for a SUM(DISTINCT) aggregation.
#[derive(Debug)]
@@ -128,7 +128,7 @@ impl DistinctSumAccumulator {
}
impl Accumulator for DistinctSumAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
+ fn state(&self) -> Result<Vec<AggregateState>> {
// 1. Stores aggregate state in `ScalarValue::List`
// 2. Constructs `ScalarValue::List` state from distinct numeric
stored in hash set
let state_out = {
@@ -136,10 +136,10 @@ impl Accumulator for DistinctSumAccumulator {
self.hash_values
.iter()
.for_each(|distinct_value|
distinct_values.push(distinct_value.clone()));
- vec![ScalarValue::List(
+ vec![AggregateState::Scalar(ScalarValue::List(
Some(distinct_values),
Box::new(Field::new("item", self.data_type.clone(), true)),
- )]
+ ))]
};
Ok(state_out)
}
@@ -181,6 +181,7 @@ impl Accumulator for DistinctSumAccumulator {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::aggregate::utils::get_accum_scalar_values;
use crate::expressions::col;
use crate::expressions::tests::aggregate;
use arrow::record_batch::RecordBatch;
@@ -196,7 +197,7 @@ mod tests {
let mut accum = agg.create_accumulator()?;
accum.update_batch(arrays)?;
- Ok((accum.state()?, accum.evaluate()?))
+ Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?))
}
macro_rules! generic_test_sum_distinct {
diff --git a/datafusion/physical-expr/src/aggregate/utils.rs
b/datafusion/physical-expr/src/aggregate/utils.rs
new file mode 100644
index 000000000..1cac5b98a
--- /dev/null
+++ b/datafusion/physical-expr/src/aggregate/utils.rs
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Utilities used in aggregates
+
+use arrow::array::ArrayRef;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::Accumulator;
+
+/// Returns the state of `accum` as scalar values, one per state entry.
+///
+/// # Errors
+/// Fails if `accum.state()` fails, or if any state entry is not a scalar.
+pub fn get_accum_scalar_values(accum: &dyn Accumulator) -> Result<Vec<ScalarValue>> {
+    let state = accum.state()?;
+    let mut scalars = Vec::with_capacity(state.len());
+    for entry in &state {
+        scalars.push(entry.as_scalar()?.clone());
+    }
+    Ok(scalars)
+}
+
+/// Converts each scalar state entry of `accum` into a one-element array, in the shape
+/// expected by another accumulator's `merge_batch`.
+///
+/// # Errors
+/// Fails if `accum.state()` fails, if any state entry is not a scalar, or if a scalar
+/// cannot be converted to an array.
+pub fn get_accum_scalar_values_as_arrays(
+    accum: &dyn Accumulator,
+) -> Result<Vec<ArrayRef>> {
+    let state = accum.state()?;
+    let mut arrays = Vec::with_capacity(state.len());
+    for entry in &state {
+        let scalar = entry.as_scalar()?.clone();
+        arrays.push(ScalarValue::iter_to_array(vec![scalar])?);
+    }
+    Ok(arrays)
+}
diff --git a/datafusion/physical-expr/src/aggregate/variance.rs
b/datafusion/physical-expr/src/aggregate/variance.rs
index 364936213..4ff4318e3 100644
--- a/datafusion/physical-expr/src/aggregate/variance.rs
+++ b/datafusion/physical-expr/src/aggregate/variance.rs
@@ -32,7 +32,7 @@ use arrow::{
};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
/// VAR and VAR_SAMP aggregate expression
#[derive(Debug)]
@@ -210,11 +210,11 @@ impl VarianceAccumulator {
}
impl Accumulator for VarianceAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
+ fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![
- ScalarValue::from(self.count),
- ScalarValue::from(self.mean),
- ScalarValue::from(self.m2),
+ AggregateState::Scalar(ScalarValue::from(self.count)),
+ AggregateState::Scalar(ScalarValue::from(self.mean)),
+ AggregateState::Scalar(ScalarValue::from(self.m2)),
])
}
@@ -296,6 +296,7 @@ impl Accumulator for VarianceAccumulator {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::aggregate::utils::get_accum_scalar_values_as_arrays;
use crate::expressions::col;
use crate::expressions::tests::aggregate;
use crate::generic_test_op;
@@ -522,12 +523,7 @@ mod tests {
.collect::<Result<Vec<_>>>()?;
accum1.update_batch(&values1)?;
accum2.update_batch(&values2)?;
- let state2 = accum2
- .state()?
- .iter()
- .map(|v| vec![v.clone()])
- .map(|x| ScalarValue::iter_to_array(x).unwrap())
- .collect::<Vec<_>>();
+ let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?;
accum1.merge_batch(&state2)?;
accum1.evaluate()
}
diff --git a/datafusion/physical-expr/src/expressions/mod.rs
b/datafusion/physical-expr/src/expressions/mod.rs
index 7a78f4603..6d8852e77 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -52,6 +52,7 @@ pub use crate::aggregate::count::Count;
pub use crate::aggregate::count_distinct::DistinctCount;
pub use crate::aggregate::covariance::{Covariance, CovariancePop};
pub use crate::aggregate::grouping::Grouping;
+pub use crate::aggregate::median::Median;
pub use crate::aggregate::min_max::{Max, Min};
pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator};
pub use crate::aggregate::stats::StatsType;
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index ec816a419..c9c1237a7 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -466,6 +466,7 @@ enum AggregateFunction {
APPROX_MEDIAN=15;
APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
GROUPING = 17;
+ MEDIAN=18;
}
message AggregateExprNode {
diff --git a/datafusion/proto/src/from_proto.rs
b/datafusion/proto/src/from_proto.rs
index 40ea1bd02..1f3c3955a 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -504,6 +504,7 @@ impl From<protobuf::AggregateFunction> for
AggregateFunction {
}
protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian,
protobuf::AggregateFunction::Grouping => Self::Grouping,
+ protobuf::AggregateFunction::Median => Self::Median,
}
}
}
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index 1d7847df2..88230766d 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -68,8 +68,8 @@ mod roundtrip_tests {
use datafusion_expr::expr::GroupingSet;
use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode};
use datafusion_expr::{
- col, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::Sqrt,
Expr,
- LogicalPlan, Volatility,
+ col, lit, Accumulator, AggregateFunction, AggregateState,
+ BuiltinScalarFunction::Sqrt, Expr, LogicalPlan, Volatility,
};
use prost::Message;
use std::any::Any;
@@ -986,7 +986,7 @@ mod roundtrip_tests {
struct Dummy {}
impl Accumulator for Dummy {
- fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
+ fn state(&self) -> datafusion::error::Result<Vec<AggregateState>> {
Ok(vec![])
}
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 323e2186d..b8ca81008 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -354,6 +354,7 @@ impl From<&AggregateFunction> for
protobuf::AggregateFunction {
}
AggregateFunction::ApproxMedian => Self::ApproxMedian,
AggregateFunction::Grouping => Self::Grouping,
+ AggregateFunction::Median => Self::Median,
}
}
}
@@ -540,6 +541,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
protobuf::AggregateFunction::ApproxMedian
}
AggregateFunction::Grouping =>
protobuf::AggregateFunction::Grouping,
+ AggregateFunction::Median =>
protobuf::AggregateFunction::Median,
};
let aggregate_expr = protobuf::AggregateExprNode {