This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 245def059 Implement exact median, add `AggregateState` (#3009)
245def059 is described below

commit 245def05940b1fea69d1b75df8a928efb39fc3af
Author: Andy Grove <[email protected]>
AuthorDate: Fri Aug 5 13:56:52 2022 -0600

    Implement exact median, add `AggregateState` (#3009)
    
    * Implement exact median
    
    * revert some changes
    
    * toml format
    
    * add median to protobuf
    
    * remove some unwraps
    
    * remove some unwraps
    
    * remove some unwraps
    
    * fix
    
    * clippy
    
    * reduce code duplication
    
    * reduce code duplication
    
    * more tests
    
    * move tests to simplify github diff
    
    * Update datafusion/expr/src/accumulator.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * refactor to make it more obvious that empty arrays are being created
    
    * partially address feedback
    
    * Update datafusion/physical-expr/src/aggregate/count_distinct.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * add more tests
    
    * more docs
    
    * clippy
    
    * avoid a clone
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion-examples/examples/simple_udaf.rs        |   7 +-
 datafusion/common/Cargo.toml                       |   1 +
 .../core/src/physical_plan/aggregates/hash.rs      |   6 +-
 datafusion/core/tests/sql/aggregates.rs            | 189 +++++++++++++++-
 datafusion/core/tests/sql/mod.rs                   |   6 +-
 datafusion/expr/src/accumulator.rs                 |  45 +++-
 datafusion/expr/src/aggregate_function.rs          |   9 +-
 datafusion/expr/src/lib.rs                         |   2 +-
 .../physical-expr/src/aggregate/approx_distinct.rs |   6 +-
 .../src/aggregate/approx_percentile_cont.rs        |  11 +-
 .../approx_percentile_cont_with_weight.rs          |   4 +-
 .../physical-expr/src/aggregate/array_agg.rs       |   6 +-
 .../src/aggregate/array_agg_distinct.rs            |   8 +-
 datafusion/physical-expr/src/aggregate/average.rs  |   9 +-
 datafusion/physical-expr/src/aggregate/build_in.rs |  10 +
 .../physical-expr/src/aggregate/correlation.rs     |  24 +-
 datafusion/physical-expr/src/aggregate/count.rs    |   8 +-
 .../physical-expr/src/aggregate/count_distinct.rs  |  13 +-
 .../physical-expr/src/aggregate/covariance.rs      |  20 +-
 datafusion/physical-expr/src/aggregate/median.rs   | 244 +++++++++++++++++++++
 datafusion/physical-expr/src/aggregate/min_max.rs  |  10 +-
 datafusion/physical-expr/src/aggregate/mod.rs      |   2 +
 datafusion/physical-expr/src/aggregate/stddev.rs   |  18 +-
 datafusion/physical-expr/src/aggregate/sum.rs      |   6 +-
 .../physical-expr/src/aggregate/sum_distinct.rs    |  11 +-
 datafusion/physical-expr/src/aggregate/utils.rs    |  48 ++++
 datafusion/physical-expr/src/aggregate/variance.rs |  18 +-
 datafusion/physical-expr/src/expressions/mod.rs    |   1 +
 datafusion/proto/proto/datafusion.proto            |   1 +
 datafusion/proto/src/from_proto.rs                 |   1 +
 datafusion/proto/src/lib.rs                        |   6 +-
 datafusion/proto/src/to_proto.rs                   |   2 +
 32 files changed, 645 insertions(+), 107 deletions(-)

diff --git a/datafusion-examples/examples/simple_udaf.rs 
b/datafusion-examples/examples/simple_udaf.rs
index 5e0f41bc8..378d2548e 100644
--- a/datafusion-examples/examples/simple_udaf.rs
+++ b/datafusion-examples/examples/simple_udaf.rs
@@ -23,6 +23,7 @@ use datafusion::arrow::{
 };
 
 use datafusion::from_slice::FromSlice;
+use datafusion::logical_expr::AggregateState;
 use datafusion::{error::Result, logical_plan::create_udaf, 
physical_plan::Accumulator};
 use datafusion::{logical_expr::Volatility, prelude::*, scalar::ScalarValue};
 use std::sync::Arc;
@@ -107,10 +108,10 @@ impl Accumulator for GeometricMean {
     // This function serializes our state to `ScalarValue`, which DataFusion 
uses
     // to pass this state between execution stages.
     // Note that this can be arbitrary data.
-    fn state(&self) -> Result<Vec<ScalarValue>> {
+    fn state(&self) -> Result<Vec<AggregateState>> {
         Ok(vec![
-            ScalarValue::from(self.prod),
-            ScalarValue::from(self.n),
+            AggregateState::Scalar(ScalarValue::from(self.prod)),
+            AggregateState::Scalar(ScalarValue::from(self.n)),
         ])
     }
 
diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml
index e7873de4d..33d6af087 100644
--- a/datafusion/common/Cargo.toml
+++ b/datafusion/common/Cargo.toml
@@ -45,4 +45,5 @@ object_store = { version = "0.3", optional = true }
 ordered-float = "3.0"
 parquet = { version = "19.0.0", features = ["arrow"], optional = true }
 pyo3 = { version = "0.16", optional = true }
+serde_json = "1.0"
 sqlparser = "0.19"
diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs 
b/datafusion/core/src/physical_plan/aggregates/hash.rs
index c21109495..54806d37f 100644
--- a/datafusion/core/src/physical_plan/aggregates/hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/hash.rs
@@ -428,8 +428,10 @@ fn create_batch_from_map(
                 AggregateMode::Partial => {
                     let res = ScalarValue::iter_to_array(
                         accumulators.group_states.iter().map(|group_state| {
-                            let x = 
group_state.accumulator_set[x].state().unwrap();
-                            x[y].clone()
+                            group_state.accumulator_set[x]
+                                .state()
+                                .and_then(|x| x[y].as_scalar().map(|v| 
v.clone()))
+                                .expect("unexpected accumulator state in hash 
aggregate")
                         }),
                     )?;
 
diff --git a/datafusion/core/tests/sql/aggregates.rs 
b/datafusion/core/tests/sql/aggregates.rs
index 02d4b3a4d..eb0e07f84 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -221,7 +221,7 @@ async fn csv_query_stddev_6() -> Result<()> {
 }
 
 #[tokio::test]
-async fn csv_query_median_1() -> Result<()> {
+async fn csv_query_approx_median_1() -> Result<()> {
     let ctx = SessionContext::new();
     register_aggregate_csv(&ctx).await?;
     let sql = "SELECT approx_median(c2) FROM aggregate_test_100";
@@ -232,7 +232,7 @@ async fn csv_query_median_1() -> Result<()> {
 }
 
 #[tokio::test]
-async fn csv_query_median_2() -> Result<()> {
+async fn csv_query_approx_median_2() -> Result<()> {
     let ctx = SessionContext::new();
     register_aggregate_csv(&ctx).await?;
     let sql = "SELECT approx_median(c6) FROM aggregate_test_100";
@@ -243,7 +243,7 @@ async fn csv_query_median_2() -> Result<()> {
 }
 
 #[tokio::test]
-async fn csv_query_median_3() -> Result<()> {
+async fn csv_query_approx_median_3() -> Result<()> {
     let ctx = SessionContext::new();
     register_aggregate_csv(&ctx).await?;
     let sql = "SELECT approx_median(c12) FROM aggregate_test_100";
@@ -253,6 +253,189 @@ async fn csv_query_median_3() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_median_1() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT median(c2) FROM aggregate_test_100";
+    let actual = execute(&ctx, sql).await;
+    let expected = vec![vec!["3"]];
+    assert_float_eq(&expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_median_2() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT median(c6) FROM aggregate_test_100";
+    let actual = execute(&ctx, sql).await;
+    let expected = vec![vec!["1125553990140691277"]];
+    assert_float_eq(&expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_median_3() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT median(c12) FROM aggregate_test_100";
+    let actual = execute(&ctx, sql).await;
+    let expected = vec![vec!["0.5513900544385053"]];
+    assert_float_eq(&expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn median_i8() -> Result<()> {
+    median_test(
+        "median",
+        DataType::Int8,
+        Arc::new(Int8Array::from(vec![i8::MIN, i8::MIN, 100, i8::MAX])),
+        "-14",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn median_i16() -> Result<()> {
+    median_test(
+        "median",
+        DataType::Int16,
+        Arc::new(Int16Array::from(vec![i16::MIN, i16::MIN, 100, i16::MAX])),
+        "-16334",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn median_i32() -> Result<()> {
+    median_test(
+        "median",
+        DataType::Int32,
+        Arc::new(Int32Array::from(vec![i32::MIN, i32::MIN, 100, i32::MAX])),
+        "-1073741774",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn median_i64() -> Result<()> {
+    median_test(
+        "median",
+        DataType::Int64,
+        Arc::new(Int64Array::from(vec![i64::MIN, i64::MIN, 100, i64::MAX])),
+        "-4611686018427388000",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn median_u8() -> Result<()> {
+    median_test(
+        "median",
+        DataType::UInt8,
+        Arc::new(UInt8Array::from(vec![u8::MIN, u8::MIN, 100, u8::MAX])),
+        "50",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn median_u16() -> Result<()> {
+    median_test(
+        "median",
+        DataType::UInt16,
+        Arc::new(UInt16Array::from(vec![u16::MIN, u16::MIN, 100, u16::MAX])),
+        "50",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn median_u32() -> Result<()> {
+    median_test(
+        "median",
+        DataType::UInt32,
+        Arc::new(UInt32Array::from(vec![u32::MIN, u32::MIN, 100, u32::MAX])),
+        "50",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn median_u64() -> Result<()> {
+    median_test(
+        "median",
+        DataType::UInt64,
+        Arc::new(UInt64Array::from(vec![u64::MIN, u64::MIN, 100, u64::MAX])),
+        "50",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn median_f32() -> Result<()> {
+    median_test(
+        "median",
+        DataType::Float32,
+        Arc::new(Float32Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])),
+        "3.3",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn median_f64() -> Result<()> {
+    median_test(
+        "median",
+        DataType::Float64,
+        Arc::new(Float64Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])),
+        "3.3",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn median_f64_nan() -> Result<()> {
+    median_test(
+        "median",
+        DataType::Float64,
+        Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])),
+        "NaN", // probably not the desired behavior? - see 
https://github.com/apache/arrow-datafusion/issues/3039
+    )
+    .await
+}
+
+#[tokio::test]
+async fn approx_median_f64_nan() -> Result<()> {
+    median_test(
+        "approx_median",
+        DataType::Float64,
+        Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])),
+        "NaN", // probably not the desired behavior? - see 
https://github.com/apache/arrow-datafusion/issues/3039
+    )
+    .await
+}
+
+async fn median_test(
+    func: &str,
+    data_type: DataType,
+    values: ArrayRef,
+    expected: &str,
+) -> Result<()> {
+    let ctx = SessionContext::new();
+    let schema = Arc::new(Schema::new(vec![Field::new("a", data_type, 
false)]));
+    let batch = RecordBatch::try_new(schema.clone(), vec![values])?;
+    let table = Arc::new(MemTable::try_new(schema, vec![vec![batch]])?);
+    ctx.register_table("t", table)?;
+    let sql = format!("SELECT {}(a) FROM t", func);
+    let actual = execute(&ctx, &sql).await;
+    let expected = vec![vec![expected.to_owned()]];
+    assert_float_eq(&expected, &actual);
+    Ok(())
+}
+
 #[tokio::test]
 async fn csv_query_external_table_count() {
     let ctx = SessionContext::new();
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index 847f19cf8..5481161d0 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -128,7 +128,11 @@ where
                 l.as_ref().parse::<f64>().unwrap(),
                 r.as_str().parse::<f64>().unwrap(),
             );
-            assert!((l - r).abs() <= 2.0 * f64::EPSILON);
+            if l.is_nan() || r.is_nan() {
+                assert!(l.is_nan() && r.is_nan());
+            } else if (l - r).abs() > 2.0 * f64::EPSILON {
+                panic!("{} != {}", l, r)
+            }
         });
 }
 
diff --git a/datafusion/expr/src/accumulator.rs 
b/datafusion/expr/src/accumulator.rs
index d59764957..6c146bc30 100644
--- a/datafusion/expr/src/accumulator.rs
+++ b/datafusion/expr/src/accumulator.rs
@@ -18,7 +18,7 @@
 //! Accumulator module contains the trait definition for aggregation 
function's accumulators.
 
 use arrow::array::ArrayRef;
-use datafusion_common::{Result, ScalarValue};
+use datafusion_common::{DataFusionError, Result, ScalarValue};
 use std::fmt::Debug;
 
 /// An accumulator represents a stateful object that lives throughout the 
evaluation of multiple rows and
@@ -26,14 +26,14 @@ use std::fmt::Debug;
 ///
 /// An accumulator knows how to:
 /// * update its state from inputs via `update_batch`
-/// * convert its internal state to a vector of scalar values
+/// * convert its internal state to a vector of aggregate values
 /// * update its state from multiple accumulators' states via `merge_batch`
 /// * compute the final value from its internal state via `evaluate`
 pub trait Accumulator: Send + Sync + Debug {
     /// Returns the state of the accumulator at the end of the accumulation.
-    // in the case of an average on which we track `sum` and `n`, this 
function should return a vector
-    // of two values, sum and n.
-    fn state(&self) -> Result<Vec<ScalarValue>>;
+    /// in the case of an average on which we track `sum` and `n`, this 
function should return a vector
+    /// of two values, sum and n.
+    fn state(&self) -> Result<Vec<AggregateState>>;
 
     /// updates the accumulator's state from a vector of arrays.
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
@@ -44,3 +44,38 @@ pub trait Accumulator: Send + Sync + Debug {
     /// returns its value based on its current state.
     fn evaluate(&self) -> Result<ScalarValue>;
 }
+
+/// Representation of internal accumulator state. Accumulators can potentially 
have a mix of
+/// scalar and array values. It may be desirable to add custom aggregator 
states here as well
+/// in the future (perhaps `Custom(Box<dyn Any>)`?).
+#[derive(Debug)]
+pub enum AggregateState {
+    /// Simple scalar value. Note that `ScalarValue::List` can be used to pass 
multiple
+    /// values around
+    Scalar(ScalarValue),
+    /// Arrays can be used instead of `ScalarValue::List` and could 
potentially have better
+    /// performance with large data sets, although this has not been verified. 
It also allows
+    /// for use of arrow kernels with less overhead.
+    Array(ArrayRef),
+}
+
+impl AggregateState {
+    /// Access the aggregate state as a scalar value. An error will occur if 
the
+    /// state is not a scalar value.
+    pub fn as_scalar(&self) -> Result<&ScalarValue> {
+        match &self {
+            Self::Scalar(v) => Ok(v),
+            _ => Err(DataFusionError::Internal(
+                "AggregateState is not a scalar aggregate".to_string(),
+            )),
+        }
+    }
+
+    /// Access the aggregate state as an array value.
+    pub fn to_array(&self) -> ArrayRef {
+        match &self {
+            Self::Scalar(v) => v.to_array(),
+            Self::Array(array) => array.clone(),
+        }
+    }
+}
diff --git a/datafusion/expr/src/aggregate_function.rs 
b/datafusion/expr/src/aggregate_function.rs
index 30bf0521d..09d759e56 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -62,6 +62,8 @@ pub enum AggregateFunction {
     Max,
     /// avg
     Avg,
+    /// median
+    Median,
     /// Approximate aggregate function
     ApproxDistinct,
     /// array_agg
@@ -107,6 +109,7 @@ impl FromStr for AggregateFunction {
             "avg" => AggregateFunction::Avg,
             "mean" => AggregateFunction::Avg,
             "sum" => AggregateFunction::Sum,
+            "median" => AggregateFunction::Median,
             "approx_distinct" => AggregateFunction::ApproxDistinct,
             "array_agg" => AggregateFunction::ArrayAgg,
             "var" => AggregateFunction::Variance,
@@ -175,7 +178,9 @@ pub fn return_type(
         AggregateFunction::ApproxPercentileContWithWeight => {
             Ok(coerced_data_types[0].clone())
         }
-        AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
+        AggregateFunction::ApproxMedian | AggregateFunction::Median => {
+            Ok(coerced_data_types[0].clone())
+        }
         AggregateFunction::Grouping => Ok(DataType::Int32),
     }
 }
@@ -330,6 +335,7 @@ pub fn coerce_types(
             }
             Ok(input_types.to_vec())
         }
+        AggregateFunction::Median => Ok(input_types.to_vec()),
         AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
     }
 }
@@ -358,6 +364,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
         | AggregateFunction::VariancePop
         | AggregateFunction::Stddev
         | AggregateFunction::StddevPop
+        | AggregateFunction::Median
         | AggregateFunction::ApproxMedian => {
             Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
         }
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index f71243610..90007a8bd 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -53,7 +53,7 @@ pub mod utils;
 pub mod window_frame;
 pub mod window_function;
 
-pub use accumulator::Accumulator;
+pub use accumulator::{Accumulator, AggregateState};
 pub use aggregate_function::AggregateFunction;
 pub use built_in_function::BuiltinScalarFunction;
 pub use columnar_value::{ColumnarValue, NullColumnarValue};
diff --git a/datafusion/physical-expr/src/aggregate/approx_distinct.rs 
b/datafusion/physical-expr/src/aggregate/approx_distinct.rs
index c67d1c9d3..5b391ed84 100644
--- a/datafusion/physical-expr/src/aggregate/approx_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_distinct.rs
@@ -30,7 +30,7 @@ use arrow::datatypes::{
 };
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 use std::any::type_name;
 use std::any::Any;
 use std::convert::TryFrom;
@@ -232,8 +232,8 @@ macro_rules! default_accumulator_impl {
             Ok(())
         }
 
-        fn state(&self) -> Result<Vec<ScalarValue>> {
-            let value = ScalarValue::from(&self.hll);
+        fn state(&self) -> Result<Vec<AggregateState>> {
+            let value = AggregateState::Scalar(ScalarValue::from(&self.hll));
             Ok(vec![value])
         }
 
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs 
b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
index 2315ad1d5..41c6c72db 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
@@ -29,7 +29,7 @@ use arrow::{
 use datafusion_common::DataFusionError;
 use datafusion_common::Result;
 use datafusion_common::ScalarValue;
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 use ordered_float::OrderedFloat;
 use std::{any::Any, iter, sync::Arc};
 
@@ -287,8 +287,13 @@ impl ApproxPercentileAccumulator {
 }
 
 impl Accumulator for ApproxPercentileAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(self.digest.to_scalar_state())
+    fn state(&self) -> Result<Vec<AggregateState>> {
+        Ok(self
+            .digest
+            .to_scalar_state()
+            .into_iter()
+            .map(AggregateState::Scalar)
+            .collect())
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git 
a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs 
b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
index f9874b0a5..40a44c3a5 100644
--- 
a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
+++ 
b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
@@ -26,7 +26,7 @@ use arrow::{
 
 use datafusion_common::Result;
 use datafusion_common::ScalarValue;
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 
 use std::{any::Any, sync::Arc};
 
@@ -114,7 +114,7 @@ impl ApproxPercentileWithWeightAccumulator {
 }
 
 impl Accumulator for ApproxPercentileWithWeightAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
+    fn state(&self) -> Result<Vec<AggregateState>> {
         self.approx_percentile_cont_accumulator.state()
     }
 
diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs 
b/datafusion/physical-expr/src/aggregate/array_agg.rs
index eaed89390..e7fd0937c 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg.rs
@@ -23,7 +23,7 @@ use arrow::array::ArrayRef;
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 use std::any::Any;
 use std::sync::Arc;
 
@@ -143,8 +143,8 @@ impl Accumulator for ArrayAggAccumulator {
         })
     }
 
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![self.evaluate()?])
+    fn state(&self) -> Result<Vec<AggregateState>> {
+        Ok(vec![AggregateState::Scalar(self.evaluate()?)])
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs 
b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
index 44e24e93c..f9899379d 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
@@ -29,7 +29,7 @@ use crate::expressions::format_state_name;
 use crate::{AggregateExpr, PhysicalExpr};
 use datafusion_common::Result;
 use datafusion_common::ScalarValue;
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 
 /// Expression for a ARRAY_AGG(DISTINCT) aggregation.
 #[derive(Debug)]
@@ -119,11 +119,11 @@ impl DistinctArrayAggAccumulator {
 }
 
 impl Accumulator for DistinctArrayAggAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![ScalarValue::List(
+    fn state(&self) -> Result<Vec<AggregateState>> {
+        Ok(vec![AggregateState::Scalar(ScalarValue::List(
             Some(self.values.clone().into_iter().collect()),
             Box::new(Field::new("item", self.datatype.clone(), true)),
-        )])
+        ))])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/average.rs 
b/datafusion/physical-expr/src/aggregate/average.rs
index 1b1d99525..a55e0e352 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -33,7 +33,7 @@ use arrow::{
 };
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 use datafusion_row::accessor::RowAccessor;
 
 /// AVG aggregate expression
@@ -150,8 +150,11 @@ impl AvgAccumulator {
 }
 
 impl Accumulator for AvgAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![ScalarValue::from(self.count), self.sum.clone()])
+    fn state(&self) -> Result<Vec<AggregateState>> {
+        Ok(vec![
+            AggregateState::Scalar(ScalarValue::from(self.count)),
+            AggregateState::Scalar(self.sum.clone()),
+        ])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs 
b/datafusion/physical-expr/src/aggregate/build_in.rs
index 23d2a84d1..8d76e35e4 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -251,6 +251,16 @@ pub fn create_aggregate_expr(
             ))
         }
         (AggregateFunction::ApproxMedian, true) => {
+            return Err(DataFusionError::NotImplemented(
+                "APPROX_MEDIAN(DISTINCT) aggregations are not 
available".to_string(),
+            ));
+        }
+        (AggregateFunction::Median, false) => 
Arc::new(expressions::Median::new(
+            coerced_phy_exprs[0].clone(),
+            name,
+            return_type,
+        )),
+        (AggregateFunction::Median, true) => {
             return Err(DataFusionError::NotImplemented(
                 "MEDIAN(DISTINCT) aggregations are not available".to_string(),
             ));
diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs 
b/datafusion/physical-expr/src/aggregate/correlation.rs
index 94a820849..3bbea5d9b 100644
--- a/datafusion/physical-expr/src/aggregate/correlation.rs
+++ b/datafusion/physical-expr/src/aggregate/correlation.rs
@@ -25,7 +25,7 @@ use crate::{AggregateExpr, PhysicalExpr};
 use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
 use datafusion_common::Result;
 use datafusion_common::ScalarValue;
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 use std::any::Any;
 use std::sync::Arc;
 
@@ -133,14 +133,14 @@ impl CorrelationAccumulator {
 }
 
 impl Accumulator for CorrelationAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
+    fn state(&self) -> Result<Vec<AggregateState>> {
         Ok(vec![
-            ScalarValue::from(self.covar.get_count()),
-            ScalarValue::from(self.covar.get_mean1()),
-            ScalarValue::from(self.stddev1.get_m2()),
-            ScalarValue::from(self.covar.get_mean2()),
-            ScalarValue::from(self.stddev2.get_m2()),
-            ScalarValue::from(self.covar.get_algo_const()),
+            AggregateState::Scalar(ScalarValue::from(self.covar.get_count())),
+            AggregateState::Scalar(ScalarValue::from(self.covar.get_mean1())),
+            AggregateState::Scalar(ScalarValue::from(self.stddev1.get_m2())),
+            AggregateState::Scalar(ScalarValue::from(self.covar.get_mean2())),
+            AggregateState::Scalar(ScalarValue::from(self.stddev2.get_m2())),
+            
AggregateState::Scalar(ScalarValue::from(self.covar.get_algo_const())),
         ])
     }
 
@@ -191,6 +191,7 @@ impl Accumulator for CorrelationAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::aggregate::utils::get_accum_scalar_values_as_arrays;
     use crate::expressions::col;
     use crate::expressions::tests::aggregate;
     use crate::generic_test_op2;
@@ -469,12 +470,7 @@ mod tests {
             .collect::<Result<Vec<_>>>()?;
         accum1.update_batch(&values1)?;
         accum2.update_batch(&values2)?;
-        let state2 = accum2
-            .state()?
-            .iter()
-            .map(|v| vec![v.clone()])
-            .map(|x| ScalarValue::iter_to_array(x).unwrap())
-            .collect::<Vec<_>>();
+        let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?;
         accum1.merge_batch(&state2)?;
         accum1.evaluate()
     }
diff --git a/datafusion/physical-expr/src/aggregate/count.rs 
b/datafusion/physical-expr/src/aggregate/count.rs
index 2b02d03b5..982c1dc09 100644
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ b/datafusion/physical-expr/src/aggregate/count.rs
@@ -29,7 +29,7 @@ use arrow::datatypes::DataType;
 use arrow::{array::ArrayRef, datatypes::Field};
 use datafusion_common::Result;
 use datafusion_common::ScalarValue;
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 use datafusion_row::accessor::RowAccessor;
 
 use crate::expressions::format_state_name;
@@ -134,8 +134,10 @@ impl Accumulator for CountAccumulator {
         Ok(())
     }
 
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![ScalarValue::Int64(Some(self.count))])
+    fn state(&self) -> Result<Vec<AggregateState>> {
+        Ok(vec![AggregateState::Scalar(ScalarValue::Int64(Some(
+            self.count,
+        )))])
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs 
b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index 744d9b90d..6060ddb4d 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -28,7 +28,7 @@ use crate::expressions::format_state_name;
 use crate::{AggregateExpr, PhysicalExpr};
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 
 #[derive(Debug, PartialEq, Eq, Hash, Clone)]
 struct DistinctScalarValues(Vec<ScalarValue>);
@@ -177,7 +177,7 @@ impl Accumulator for DistinctCountAccumulator {
             self.merge(&v)
         })
     }
-    fn state(&self) -> Result<Vec<ScalarValue>> {
+    fn state(&self) -> Result<Vec<AggregateState>> {
         let mut cols_out = self
             .state_data_types
             .iter()
@@ -206,7 +206,7 @@ impl Accumulator for DistinctCountAccumulator {
             )
         });
 
-        Ok(cols_out)
+        Ok(cols_out.into_iter().map(AggregateState::Scalar).collect())
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
@@ -223,6 +223,7 @@ impl Accumulator for DistinctCountAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::aggregate::utils::get_accum_scalar_values;
     use arrow::array::{
         ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, 
Int32Array,
         Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array, 
UInt64Array,
@@ -341,7 +342,7 @@ mod tests {
         let mut accum = agg.create_accumulator()?;
         accum.update_batch(arrays)?;
 
-        Ok((accum.state()?, accum.evaluate()?))
+        Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?))
     }
 
     fn run_update(
@@ -372,7 +373,7 @@ mod tests {
 
         accum.update_batch(&arrays)?;
 
-        Ok((accum.state()?, accum.evaluate()?))
+        Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?))
     }
 
     fn run_merge_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>, 
ScalarValue)> {
@@ -390,7 +391,7 @@ mod tests {
         let mut accum = agg.create_accumulator()?;
         accum.merge_batch(arrays)?;
 
-        Ok((accum.state()?, accum.evaluate()?))
+        Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?))
     }
 
     // Used trait to create associated constant for f32 and f64
diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs 
b/datafusion/physical-expr/src/aggregate/covariance.rs
index 1df002b48..9cd319127 100644
--- a/datafusion/physical-expr/src/aggregate/covariance.rs
+++ b/datafusion/physical-expr/src/aggregate/covariance.rs
@@ -30,7 +30,7 @@ use arrow::{
 };
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 
 use crate::aggregate::stats::StatsType;
 use crate::expressions::format_state_name;
@@ -237,12 +237,12 @@ impl CovarianceAccumulator {
 }
 
 impl Accumulator for CovarianceAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
+    fn state(&self) -> Result<Vec<AggregateState>> {
         Ok(vec![
-            ScalarValue::from(self.count),
-            ScalarValue::from(self.mean1),
-            ScalarValue::from(self.mean2),
-            ScalarValue::from(self.algo_const),
+            AggregateState::Scalar(ScalarValue::from(self.count)),
+            AggregateState::Scalar(ScalarValue::from(self.mean1)),
+            AggregateState::Scalar(ScalarValue::from(self.mean2)),
+            AggregateState::Scalar(ScalarValue::from(self.algo_const)),
         ])
     }
 
@@ -352,6 +352,7 @@ impl Accumulator for CovarianceAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::aggregate::utils::get_accum_scalar_values_as_arrays;
     use crate::expressions::col;
     use crate::expressions::tests::aggregate;
     use crate::generic_test_op2;
@@ -644,12 +645,7 @@ mod tests {
             .collect::<Result<Vec<_>>>()?;
         accum1.update_batch(&values1)?;
         accum2.update_batch(&values2)?;
-        let state2 = accum2
-            .state()?
-            .iter()
-            .map(|v| vec![v.clone()])
-            .map(|x| ScalarValue::iter_to_array(x).unwrap())
-            .collect::<Vec<_>>();
+        let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?;
         accum1.merge_batch(&state2)?;
         accum1.evaluate()
     }
diff --git a/datafusion/physical-expr/src/aggregate/median.rs 
b/datafusion/physical-expr/src/aggregate/median.rs
new file mode 100644
index 000000000..6b68f2ec3
--- /dev/null
+++ b/datafusion/physical-expr/src/aggregate/median.rs
@@ -0,0 +1,244 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! # Median
+
+use crate::expressions::format_state_name;
+use crate::{AggregateExpr, PhysicalExpr};
+use arrow::array::{Array, ArrayRef, PrimitiveArray, PrimitiveBuilder};
+use arrow::compute::sort;
+use arrow::datatypes::{
+    ArrowPrimitiveType, DataType, Field, Float32Type, Float64Type, Int16Type, 
Int32Type,
+    Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use datafusion_common::{DataFusionError, Result, ScalarValue};
+use datafusion_expr::{Accumulator, AggregateState};
+use std::any::Any;
+use std::sync::Arc;
+
+/// MEDIAN aggregate expression. This uses a lot of memory because all values 
need to be
+/// stored in memory before a result can be computed. If an approximation is 
sufficient
+/// then APPROX_MEDIAN provides a much more efficient solution.
+#[derive(Debug)]
+pub struct Median {
+    name: String,
+    expr: Arc<dyn PhysicalExpr>,
+    data_type: DataType,
+}
+
+impl Median {
+    /// Create a new MEDIAN aggregate function
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            expr,
+            data_type,
+        }
+    }
+}
+
impl AggregateExpr for Median {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    // Output field is nullable: an empty input yields a null median.
    fn field(&self) -> Result<Field> {
        Ok(Field::new(&self.name, self.data_type.clone(), true))
    }

    // Each accumulator starts with no buffered values.
    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(MedianAccumulator {
            data_type: self.data_type.clone(),
            all_values: vec![],
        }))
    }

    // Intermediate state is a single array field holding the raw buffered
    // input values (not a summary — MEDIAN needs all values to evaluate).
    fn state_fields(&self) -> Result<Vec<Field>> {
        Ok(vec![Field::new(
            &format_state_name(&self.name, "median"),
            self.data_type.clone(),
            true,
        )])
    }

    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        vec![self.expr.clone()]
    }

    fn name(&self) -> &str {
        &self.name
    }
}
+
/// Accumulator for the exact median: buffers every input batch in
/// `all_values` and computes nothing until `evaluate` is called.
#[derive(Debug)]
struct MedianAccumulator {
    // Declared input type; used to pick the typed evaluation path.
    data_type: DataType,
    // All input arrays seen so far (both update_batch and merge_batch append here).
    all_values: Vec<ArrayRef>,
}
+
/// Compute the median of the values buffered in `$SELF.all_values`.
///
/// `$TY` is the arrow primitive type, `$SCALAR_TY` the matching
/// `ScalarValue` variant, and `$TWO` the constant two in the native type
/// (used to average the two middle values of an even-length input).
macro_rules! median {
    ($SELF:ident, $TY:ty, $SCALAR_TY:ident, $TWO:expr) => {{
        // Merge all buffered arrays into one, dropping nulls.
        let combined = combine_arrays::<$TY>($SELF.all_values.as_slice())?;
        if combined.is_empty() {
            return Ok(ScalarValue::Null);
        }
        let sorted = sort(&combined, None)?;
        let array = sorted
            .as_any()
            .downcast_ref::<PrimitiveArray<$TY>>()
            // ok_or_else: build the error message only on the failure path,
            // instead of allocating a String on every call (clippy or_fun_call)
            .ok_or_else(|| {
                DataFusionError::Internal(
                    "median! macro failed to cast array to expected type".to_string(),
                )
            })?;
        let len = sorted.len();
        let mid = len / 2;
        if len % 2 == 0 {
            // Even count: median is the mean of the two middle values.
            Ok(ScalarValue::$SCALAR_TY(Some(
                (array.value(mid - 1) + array.value(mid)) / $TWO,
            )))
        } else {
            Ok(ScalarValue::$SCALAR_TY(Some(array.value(mid))))
        }
    }};
}
+
+impl Accumulator for MedianAccumulator {
+    fn state(&self) -> Result<Vec<AggregateState>> {
+        let mut vec: Vec<AggregateState> = self
+            .all_values
+            .iter()
+            .map(|v| AggregateState::Array(v.clone()))
+            .collect();
+        if vec.is_empty() {
+            match self.data_type {
+                DataType::UInt8 => vec.push(empty_array::<UInt8Type>()),
+                DataType::UInt16 => vec.push(empty_array::<UInt16Type>()),
+                DataType::UInt32 => vec.push(empty_array::<UInt32Type>()),
+                DataType::UInt64 => vec.push(empty_array::<UInt64Type>()),
+                DataType::Int8 => vec.push(empty_array::<Int8Type>()),
+                DataType::Int16 => vec.push(empty_array::<Int16Type>()),
+                DataType::Int32 => vec.push(empty_array::<Int32Type>()),
+                DataType::Int64 => vec.push(empty_array::<Int64Type>()),
+                DataType::Float32 => vec.push(empty_array::<Float32Type>()),
+                DataType::Float64 => vec.push(empty_array::<Float64Type>()),
+                _ => {
+                    return Err(DataFusionError::Execution(
+                        "unsupported data type for median".to_string(),
+                    ))
+                }
+            }
+        }
+        Ok(vec)
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let x = values[0].clone();
+        self.all_values.extend_from_slice(&[x]);
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        for array in states {
+            self.all_values.extend_from_slice(&[array.clone()]);
+        }
+        Ok(())
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        match self.all_values[0].data_type() {
+            DataType::Int8 => median!(self, arrow::datatypes::Int8Type, Int8, 
2),
+            DataType::Int16 => median!(self, arrow::datatypes::Int16Type, 
Int16, 2),
+            DataType::Int32 => median!(self, arrow::datatypes::Int32Type, 
Int32, 2),
+            DataType::Int64 => median!(self, arrow::datatypes::Int64Type, 
Int64, 2),
+            DataType::UInt8 => median!(self, arrow::datatypes::UInt8Type, 
UInt8, 2),
+            DataType::UInt16 => median!(self, arrow::datatypes::UInt16Type, 
UInt16, 2),
+            DataType::UInt32 => median!(self, arrow::datatypes::UInt32Type, 
UInt32, 2),
+            DataType::UInt64 => median!(self, arrow::datatypes::UInt64Type, 
UInt64, 2),
+            DataType::Float32 => {
+                median!(self, arrow::datatypes::Float32Type, Float32, 2_f32)
+            }
+            DataType::Float64 => {
+                median!(self, arrow::datatypes::Float64Type, Float64, 2_f64)
+            }
+            _ => Err(DataFusionError::Execution(
+                "unsupported data type for median".to_string(),
+            )),
+        }
+    }
+}
+
+/// Create an empty array
+fn empty_array<T: ArrowPrimitiveType>() -> AggregateState {
+    AggregateState::Array(Arc::new(PrimitiveBuilder::<T>::new(0).finish()))
+}
+
+/// Combine all non-null values from provided arrays into a single array
+fn combine_arrays<T: ArrowPrimitiveType>(arrays: &[ArrayRef]) -> 
Result<ArrayRef> {
+    let len = arrays.iter().map(|a| a.len() - a.null_count()).sum();
+    let mut builder: PrimitiveBuilder<T> = PrimitiveBuilder::new(len);
+    for array in arrays {
+        let array = array
+            .as_any()
+            .downcast_ref::<PrimitiveArray<T>>()
+            .ok_or_else(|| {
+                DataFusionError::Internal(
+                    "combine_arrays failed to cast array to expected 
type".to_string(),
+                )
+            })?;
+        for i in 0..array.len() {
+            if !array.is_null(i) {
+                builder.append_value(array.value(i));
+            }
+        }
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
#[cfg(test)]
mod test {
    use crate::aggregate::median::combine_arrays;
    use arrow::array::{Int32Array, UInt32Array};
    use arrow::datatypes::{Int32Type, UInt32Type};
    use datafusion_common::Result;
    use std::sync::Arc;

    // Combining an array with itself should yield both copies, in order.
    #[test]
    fn combine_i32_array() -> Result<()> {
        let input = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let combined = combine_arrays::<Int32Type>(&[input.clone(), input])?;
        let expected =
            "PrimitiveArray<Int32>\n[\n  1,\n  2,\n  3,\n  1,\n  2,\n  3,\n]";
        assert_eq!(expected, format!("{:?}", combined));
        Ok(())
    }

    #[test]
    fn combine_u32_array() -> Result<()> {
        let input = Arc::new(UInt32Array::from(vec![1, 2, 3]));
        let combined = combine_arrays::<UInt32Type>(&[input.clone(), input])?;
        let expected =
            "PrimitiveArray<UInt32>\n[\n  1,\n  2,\n  3,\n  1,\n  2,\n  3,\n]";
        assert_eq!(expected, format!("{:?}", combined));
        Ok(())
    }
}
diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs 
b/datafusion/physical-expr/src/aggregate/min_max.rs
index bd56973b1..077f4d725 100644
--- a/datafusion/physical-expr/src/aggregate/min_max.rs
+++ b/datafusion/physical-expr/src/aggregate/min_max.rs
@@ -36,7 +36,7 @@ use arrow::{
 };
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 
 use crate::aggregate::row_accumulator::RowAccumulator;
 use crate::expressions::format_state_name;
@@ -538,8 +538,8 @@ impl Accumulator for MaxAccumulator {
         self.update_batch(states)
     }
 
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![self.max.clone()])
+    fn state(&self) -> Result<Vec<AggregateState>> {
+        Ok(vec![AggregateState::Scalar(self.max.clone())])
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
@@ -691,8 +691,8 @@ impl MinAccumulator {
 }
 
 impl Accumulator for MinAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![self.min.clone()])
+    fn state(&self) -> Result<Vec<AggregateState>> {
+        Ok(vec![AggregateState::Scalar(self.min.clone())])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs 
b/datafusion/physical-expr/src/aggregate/mod.rs
index 1cbd4aeea..a8d59d714 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -37,6 +37,7 @@ pub(crate) mod count;
 pub(crate) mod count_distinct;
 pub(crate) mod covariance;
 pub(crate) mod grouping;
+pub(crate) mod median;
 #[macro_use]
 pub(crate) mod min_max;
 pub mod build_in;
@@ -47,6 +48,7 @@ pub(crate) mod stddev;
 pub(crate) mod sum;
 pub(crate) mod sum_distinct;
 mod tdigest;
+pub mod utils;
 pub(crate) mod variance;
 
 /// An aggregate expression that:
diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs 
b/datafusion/physical-expr/src/aggregate/stddev.rs
index 13085fee2..77f080293 100644
--- a/datafusion/physical-expr/src/aggregate/stddev.rs
+++ b/datafusion/physical-expr/src/aggregate/stddev.rs
@@ -27,7 +27,7 @@ use crate::{AggregateExpr, PhysicalExpr};
 use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 
 /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression
 #[derive(Debug)]
@@ -180,11 +180,11 @@ impl StddevAccumulator {
 }
 
 impl Accumulator for StddevAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
+    fn state(&self) -> Result<Vec<AggregateState>> {
         Ok(vec![
-            ScalarValue::from(self.variance.get_count()),
-            ScalarValue::from(self.variance.get_mean()),
-            ScalarValue::from(self.variance.get_m2()),
+            
AggregateState::Scalar(ScalarValue::from(self.variance.get_count())),
+            
AggregateState::Scalar(ScalarValue::from(self.variance.get_mean())),
+            AggregateState::Scalar(ScalarValue::from(self.variance.get_m2())),
         ])
     }
 
@@ -216,6 +216,7 @@ impl Accumulator for StddevAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::aggregate::utils::get_accum_scalar_values_as_arrays;
     use crate::expressions::col;
     use crate::expressions::tests::aggregate;
     use crate::generic_test_op;
@@ -441,12 +442,7 @@ mod tests {
             .collect::<Result<Vec<_>>>()?;
         accum1.update_batch(&values1)?;
         accum2.update_batch(&values2)?;
-        let state2 = accum2
-            .state()?
-            .iter()
-            .map(|v| vec![v.clone()])
-            .map(|x| ScalarValue::iter_to_array(x).unwrap())
-            .collect::<Vec<_>>();
+        let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?;
         accum1.merge_batch(&state2)?;
         accum1.evaluate()
     }
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs 
b/datafusion/physical-expr/src/aggregate/sum.rs
index 866e90f1e..b0a7de6c6 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -32,7 +32,7 @@ use arrow::{
     datatypes::Field,
 };
 use datafusion_common::{DataFusionError, Result, ScalarValue};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 
 use crate::aggregate::row_accumulator::RowAccumulator;
 use crate::expressions::format_state_name;
@@ -435,8 +435,8 @@ pub(crate) fn add_to_row(
 }
 
 impl Accumulator for SumAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![self.sum.clone()])
+    fn state(&self) -> Result<Vec<AggregateState>> {
+        Ok(vec![AggregateState::Scalar(self.sum.clone())])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs 
b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
index a64b4b497..d939a033e 100644
--- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
@@ -29,7 +29,7 @@ use std::collections::HashSet;
 use crate::{AggregateExpr, PhysicalExpr};
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 
 /// Expression for a SUM(DISTINCT) aggregation.
 #[derive(Debug)]
@@ -128,7 +128,7 @@ impl DistinctSumAccumulator {
 }
 
 impl Accumulator for DistinctSumAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
+    fn state(&self) -> Result<Vec<AggregateState>> {
         // 1. Stores aggregate state in `ScalarValue::List`
         // 2. Constructs `ScalarValue::List` state from distinct numeric 
stored in hash set
         let state_out = {
@@ -136,10 +136,10 @@ impl Accumulator for DistinctSumAccumulator {
             self.hash_values
                 .iter()
                 .for_each(|distinct_value| 
distinct_values.push(distinct_value.clone()));
-            vec![ScalarValue::List(
+            vec![AggregateState::Scalar(ScalarValue::List(
                 Some(distinct_values),
                 Box::new(Field::new("item", self.data_type.clone(), true)),
-            )]
+            ))]
         };
         Ok(state_out)
     }
@@ -181,6 +181,7 @@ impl Accumulator for DistinctSumAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::aggregate::utils::get_accum_scalar_values;
     use crate::expressions::col;
     use crate::expressions::tests::aggregate;
     use arrow::record_batch::RecordBatch;
@@ -196,7 +197,7 @@ mod tests {
         let mut accum = agg.create_accumulator()?;
         accum.update_batch(arrays)?;
 
-        Ok((accum.state()?, accum.evaluate()?))
+        Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?))
     }
 
     macro_rules! generic_test_sum_distinct {
diff --git a/datafusion/physical-expr/src/aggregate/utils.rs 
b/datafusion/physical-expr/src/aggregate/utils.rs
new file mode 100644
index 000000000..1cac5b98a
--- /dev/null
+++ b/datafusion/physical-expr/src/aggregate/utils.rs
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Utilities used in aggregates
+
+use arrow::array::ArrayRef;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::Accumulator;
+
+/// Extract scalar values from an accumulator. This can return an error if the 
accumulator
+/// has any non-scalar values.
+pub fn get_accum_scalar_values(accum: &dyn Accumulator) -> 
Result<Vec<ScalarValue>> {
+    accum
+        .state()?
+        .iter()
+        .map(|agg| agg.as_scalar().map(|v| v.clone()))
+        .collect::<Result<Vec<_>>>()
+}
+
+/// Convert scalar values from an accumulator into arrays. This can return an 
error if the
+/// accumulator has any non-scalar values.
+pub fn get_accum_scalar_values_as_arrays(
+    accum: &dyn Accumulator,
+) -> Result<Vec<ArrayRef>> {
+    accum
+        .state()?
+        .iter()
+        .map(|v| {
+            v.as_scalar()
+                .map(|s| vec![s.clone()])
+                .and_then(ScalarValue::iter_to_array)
+        })
+        .collect::<Result<Vec<_>>>()
+}
diff --git a/datafusion/physical-expr/src/aggregate/variance.rs 
b/datafusion/physical-expr/src/aggregate/variance.rs
index 364936213..4ff4318e3 100644
--- a/datafusion/physical-expr/src/aggregate/variance.rs
+++ b/datafusion/physical-expr/src/aggregate/variance.rs
@@ -32,7 +32,7 @@ use arrow::{
 };
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, AggregateState};
 
 /// VAR and VAR_SAMP aggregate expression
 #[derive(Debug)]
@@ -210,11 +210,11 @@ impl VarianceAccumulator {
 }
 
 impl Accumulator for VarianceAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
+    fn state(&self) -> Result<Vec<AggregateState>> {
         Ok(vec![
-            ScalarValue::from(self.count),
-            ScalarValue::from(self.mean),
-            ScalarValue::from(self.m2),
+            AggregateState::Scalar(ScalarValue::from(self.count)),
+            AggregateState::Scalar(ScalarValue::from(self.mean)),
+            AggregateState::Scalar(ScalarValue::from(self.m2)),
         ])
     }
 
@@ -296,6 +296,7 @@ impl Accumulator for VarianceAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::aggregate::utils::get_accum_scalar_values_as_arrays;
     use crate::expressions::col;
     use crate::expressions::tests::aggregate;
     use crate::generic_test_op;
@@ -522,12 +523,7 @@ mod tests {
             .collect::<Result<Vec<_>>>()?;
         accum1.update_batch(&values1)?;
         accum2.update_batch(&values2)?;
-        let state2 = accum2
-            .state()?
-            .iter()
-            .map(|v| vec![v.clone()])
-            .map(|x| ScalarValue::iter_to_array(x).unwrap())
-            .collect::<Vec<_>>();
+        let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?;
         accum1.merge_batch(&state2)?;
         accum1.evaluate()
     }
diff --git a/datafusion/physical-expr/src/expressions/mod.rs 
b/datafusion/physical-expr/src/expressions/mod.rs
index 7a78f4603..6d8852e77 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -52,6 +52,7 @@ pub use crate::aggregate::count::Count;
 pub use crate::aggregate::count_distinct::DistinctCount;
 pub use crate::aggregate::covariance::{Covariance, CovariancePop};
 pub use crate::aggregate::grouping::Grouping;
+pub use crate::aggregate::median::Median;
 pub use crate::aggregate::min_max::{Max, Min};
 pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator};
 pub use crate::aggregate::stats::StatsType;
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index ec816a419..c9c1237a7 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -466,6 +466,7 @@ enum AggregateFunction {
   APPROX_MEDIAN=15;
   APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
   GROUPING = 17;
+  MEDIAN=18;
 }
 
 message AggregateExprNode {
diff --git a/datafusion/proto/src/from_proto.rs 
b/datafusion/proto/src/from_proto.rs
index 40ea1bd02..1f3c3955a 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -504,6 +504,7 @@ impl From<protobuf::AggregateFunction> for 
AggregateFunction {
             }
             protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian,
             protobuf::AggregateFunction::Grouping => Self::Grouping,
+            protobuf::AggregateFunction::Median => Self::Median,
         }
     }
 }
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index 1d7847df2..88230766d 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -68,8 +68,8 @@ mod roundtrip_tests {
     use datafusion_expr::expr::GroupingSet;
     use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode};
     use datafusion_expr::{
-        col, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::Sqrt, 
Expr,
-        LogicalPlan, Volatility,
+        col, lit, Accumulator, AggregateFunction, AggregateState,
+        BuiltinScalarFunction::Sqrt, Expr, LogicalPlan, Volatility,
     };
     use prost::Message;
     use std::any::Any;
@@ -986,7 +986,7 @@ mod roundtrip_tests {
         struct Dummy {}
 
         impl Accumulator for Dummy {
-            fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
+            fn state(&self) -> datafusion::error::Result<Vec<AggregateState>> {
                 Ok(vec![])
             }
 
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 323e2186d..b8ca81008 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -354,6 +354,7 @@ impl From<&AggregateFunction> for 
protobuf::AggregateFunction {
             }
             AggregateFunction::ApproxMedian => Self::ApproxMedian,
             AggregateFunction::Grouping => Self::Grouping,
+            AggregateFunction::Median => Self::Median,
         }
     }
 }
@@ -540,6 +541,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                         protobuf::AggregateFunction::ApproxMedian
                     }
                     AggregateFunction::Grouping => 
protobuf::AggregateFunction::Grouping,
+                    AggregateFunction::Median => 
protobuf::AggregateFunction::Median,
                 };
 
                 let aggregate_expr = protobuf::AggregateExprNode {

Reply via email to