This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 3773fb7fb5 Convert `approx_median` to UDAF (#10840)
3773fb7fb5 is described below

commit 3773fb7fb54419f889e7d18b73e9eb48069eb08e
Author: Jax Liu <[email protected]>
AuthorDate: Mon Jun 10 14:45:15 2024 +0800

    Convert `approx_median` to UDAF (#10840)
    
    * move tdigest to physical-expr-common
    
    * move approx_percentile_cont_accumulator to function-aggregate
    
    * implement approx_median udaf
    
    * remove approx_median aggregation function
    
    * fix sqllogictests
    
    * add removed type tests
    
    * cargo fmt and clippy
    
    * add logical roundtrip test
    
    * fix dataframe test
    
    * fix test and proto gen
    
    * update lock in datafusion-cli
    
    * fix typo
    
    * fix test and doc
    
    * fix sql_integration
    
    * cargo fmt
    
    * follow the checking style like other udaf
    
    * add comment and modified dependency
    
    * update lock and fmt
    
    * add missing test annotation
---
 datafusion-cli/Cargo.lock                          |  20 +-
 .../core/tests/dataframe/dataframe_functions.rs    |   3 +-
 datafusion/expr/src/aggregate_function.rs          |   8 +-
 datafusion/expr/src/expr_fn.rs                     |  12 -
 datafusion/expr/src/type_coercion/aggregates.rs    |  10 -
 .../functions-aggregate/src/approx_median.rs       | 129 +++++++++++
 .../src/approx_percentile_cont.rs                  | 255 +++++++++++++++++++++
 datafusion/functions-aggregate/src/lib.rs          |   5 +
 datafusion/functions-aggregate/src/sum.rs          |   2 +-
 .../physical-expr-common/src/aggregate/mod.rs      |   1 +
 .../src/aggregate/tdigest.rs                       |  46 ++--
 .../physical-expr/src/aggregate/approx_median.rs   |  99 --------
 .../src/aggregate/approx_percentile_cont.rs        | 246 +-------------------
 .../approx_percentile_cont_with_weight.rs          |   6 +-
 datafusion/physical-expr/src/aggregate/build_in.rs |  71 +-----
 datafusion/physical-expr/src/aggregate/mod.rs      |   2 -
 datafusion/physical-expr/src/expressions/mod.rs    |   1 -
 datafusion/proto/proto/datafusion.proto            |   2 +-
 datafusion/proto/src/generated/pbjson.rs           |   3 -
 datafusion/proto/src/generated/prost.rs            |   4 +-
 datafusion/proto/src/logical_plan/from_proto.rs    |   1 -
 datafusion/proto/src/logical_plan/to_proto.rs      |   4 -
 datafusion/proto/src/physical_plan/to_proto.rs     |  14 +-
 .../proto/tests/cases/roundtrip_logical_plan.rs    |   2 +
 datafusion/sql/tests/sql_integration.rs            |  12 +-
 datafusion/sqllogictest/test_files/aggregate.slt   |  10 +
 26 files changed, 471 insertions(+), 497 deletions(-)

diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 95b114ca4a..932f44d984 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -2880,9 +2880,9 @@ dependencies = [
 
 [[package]]
 name = "regex"
-version = "1.10.4"
+version = "1.10.5"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c"
+checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
 dependencies = [
  "aho-corasick",
  "memchr",
@@ -2892,9 +2892,9 @@ dependencies = [
 
 [[package]]
 name = "regex-automata"
-version = "0.4.6"
+version = "0.4.7"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea"
+checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
 dependencies = [
  "aho-corasick",
  "memchr",
@@ -2903,15 +2903,15 @@ dependencies = [
 
 [[package]]
 name = "regex-lite"
-version = "0.1.5"
+version = "0.1.6"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "30b661b2f27137bdbc16f00eda72866a92bb28af1753ffbd56744fb6e2e9cd8e"
+checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a"
 
 [[package]]
 name = "regex-syntax"
-version = "0.8.3"
+version = "0.8.4"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56"
+checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
 
 [[package]]
 name = "reqwest"
@@ -3846,9 +3846,9 @@ checksum = 
"daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
 
 [[package]]
 name = "utf8parse"
-version = "0.2.1"
+version = "0.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
+checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
 
 [[package]]
 name = "uuid"
diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs 
b/datafusion/core/tests/dataframe/dataframe_functions.rs
index 7d155bb16c..b05769a6ce 100644
--- a/datafusion/core/tests/dataframe/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -33,6 +33,7 @@ use datafusion::assert_batches_eq;
 use datafusion_common::{DFSchema, ScalarValue};
 use datafusion_expr::expr::Alias;
 use datafusion_expr::ExprSchemable;
+use datafusion_functions_aggregate::expr_fn::approx_median;
 
 fn test_schema() -> SchemaRef {
     Arc::new(Schema::new(vec![
@@ -342,7 +343,7 @@ async fn test_fn_approx_median() -> Result<()> {
 
     let expected = [
         "+-----------------------+",
-        "| APPROX_MEDIAN(test.b) |",
+        "| approx_median(test.b) |",
         "+-----------------------+",
         "| 10                    |",
         "+-----------------------+",
diff --git a/datafusion/expr/src/aggregate_function.rs 
b/datafusion/expr/src/aggregate_function.rs
index 9e4f7a50ac..6227df814f 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -71,8 +71,6 @@ pub enum AggregateFunction {
     ApproxPercentileCont,
     /// Approximate continuous percentile function with weight
     ApproxPercentileContWithWeight,
-    /// ApproxMedian
-    ApproxMedian,
     /// Grouping
     Grouping,
     /// Bit And
@@ -112,7 +110,6 @@ impl AggregateFunction {
             RegrSXY => "REGR_SXY",
             ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
             ApproxPercentileContWithWeight => 
"APPROX_PERCENTILE_CONT_WITH_WEIGHT",
-            ApproxMedian => "APPROX_MEDIAN",
             Grouping => "GROUPING",
             BitAnd => "BIT_AND",
             BitOr => "BIT_OR",
@@ -161,7 +158,6 @@ impl FromStr for AggregateFunction {
             "regr_sxy" => AggregateFunction::RegrSXY,
             // approximate
             "approx_distinct" => AggregateFunction::ApproxDistinct,
-            "approx_median" => AggregateFunction::ApproxMedian,
             "approx_percentile_cont" => 
AggregateFunction::ApproxPercentileCont,
             "approx_percentile_cont_with_weight" => {
                 AggregateFunction::ApproxPercentileContWithWeight
@@ -234,7 +230,6 @@ impl AggregateFunction {
             AggregateFunction::ApproxPercentileContWithWeight => {
                 Ok(coerced_data_types[0].clone())
             }
-            AggregateFunction::ApproxMedian => 
Ok(coerced_data_types[0].clone()),
             AggregateFunction::Grouping => Ok(DataType::Int32),
             AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
             AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
@@ -284,7 +279,8 @@ impl AggregateFunction {
             AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
                 Signature::uniform(1, vec![DataType::Boolean], 
Volatility::Immutable)
             }
-            AggregateFunction::Avg | AggregateFunction::ApproxMedian => {
+
+            AggregateFunction::Avg => {
                 Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
             }
             AggregateFunction::NthValue => Signature::any(2, 
Volatility::Immutable),
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 0360478eac..5626d343a6 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -284,18 +284,6 @@ pub fn approx_distinct(expr: Expr) -> Expr {
     ))
 }
 
-/// Calculate an approximation of the median for `expr`.
-pub fn approx_median(expr: Expr) -> Expr {
-    Expr::AggregateFunction(AggregateFunction::new(
-        aggregate_function::AggregateFunction::ApproxMedian,
-        vec![expr],
-        false,
-        None,
-        None,
-        None,
-    ))
-}
-
 /// Calculate an approximation of the specified `percentile` for `expr`.
 pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
     Expr::AggregateFunction(AggregateFunction::new(
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs 
b/datafusion/expr/src/type_coercion/aggregates.rs
index 4b4d526532..efd3c9f371 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -231,16 +231,6 @@ pub fn coerce_types(
             }
             Ok(input_types.to_vec())
         }
-        AggregateFunction::ApproxMedian => {
-            if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
-                return plan_err!(
-                    "The function {:?} does not support inputs of type {:?}.",
-                    agg_fun,
-                    input_types[0]
-                );
-            }
-            Ok(input_types.to_vec())
-        }
         AggregateFunction::NthValue => Ok(input_types.to_vec()),
         AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
         AggregateFunction::StringAgg => {
diff --git a/datafusion/functions-aggregate/src/approx_median.rs 
b/datafusion/functions-aggregate/src/approx_median.rs
new file mode 100644
index 0000000000..b8b86d3055
--- /dev/null
+++ b/datafusion/functions-aggregate/src/approx_median.rs
@@ -0,0 +1,129 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions for APPROX_MEDIAN that can be evaluated at runtime during query execution
+
+use std::any::Any;
+use std::fmt::Debug;
+
+use arrow::{datatypes::DataType, datatypes::Field};
+use arrow_schema::DataType::{Float64, UInt64};
+
+use datafusion_common::{not_impl_err, plan_err, Result};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::type_coercion::aggregates::NUMERICS;
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
+use datafusion_physical_expr_common::aggregate::utils::down_cast_any_ref;
+
+use crate::approx_percentile_cont::ApproxPercentileAccumulator;
+
+make_udaf_expr_and_func!(
+    ApproxMedian,
+    approx_median,
+    expression,
+    "Computes the approximate median of a set of numbers",
+    approx_median_udaf
+);
+
+/// APPROX_MEDIAN aggregate expression
+pub struct ApproxMedian {
+    signature: Signature,
+}
+
+impl Debug for ApproxMedian {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        f.debug_struct("ApproxMedian")
+            .field("name", &self.name())
+            .field("signature", &self.signature)
+            .finish()
+    }
+}
+
+impl Default for ApproxMedian {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ApproxMedian {
+    /// Create a new APPROX_MEDIAN aggregate function
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::uniform(1, NUMERICS.to_vec(), 
Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for ApproxMedian {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        Ok(vec![
+            Field::new(format_state_name(args.name, "max_size"), UInt64, 
false),
+            Field::new(format_state_name(args.name, "sum"), Float64, false),
+            Field::new(format_state_name(args.name, "count"), Float64, false),
+            Field::new(format_state_name(args.name, "max"), Float64, false),
+            Field::new(format_state_name(args.name, "min"), Float64, false),
+            Field::new_list(
+                format_state_name(args.name, "centroids"),
+                Field::new("item", Float64, true),
+                false,
+            ),
+        ])
+    }
+
+    fn name(&self) -> &str {
+        "approx_median"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if !arg_types[0].is_numeric() {
+            return plan_err!("ApproxMedian requires numeric input types");
+        }
+        Ok(arg_types[0].clone())
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
+        if acc_args.is_distinct {
+            return not_impl_err!(
+                "APPROX_MEDIAN(DISTINCT) aggregations are not available"
+            );
+        }
+
+        Ok(Box::new(ApproxPercentileAccumulator::new(
+            0.5_f64,
+            acc_args.input_type.clone(),
+        )))
+    }
+}
+
+impl PartialEq<dyn Any> for ApproxMedian {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| self.signature == x.signature)
+            .unwrap_or(false)
+    }
+}
diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs 
b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
new file mode 100644
index 0000000000..e75417efc6
--- /dev/null
+++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
@@ -0,0 +1,255 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{
+    array::{
+        ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, 
Int64Array,
+        Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+    },
+    datatypes::DataType,
+};
+
+use datafusion_common::{downcast_value, internal_err, DataFusionError, 
ScalarValue};
+use datafusion_expr::Accumulator;
+use datafusion_physical_expr_common::aggregate::tdigest::{
+    TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
+};
+
+#[derive(Debug)]
+pub struct ApproxPercentileAccumulator {
+    digest: TDigest,
+    percentile: f64,
+    return_type: DataType,
+}
+
+impl ApproxPercentileAccumulator {
+    pub fn new(percentile: f64, return_type: DataType) -> Self {
+        Self {
+            digest: TDigest::new(DEFAULT_MAX_SIZE),
+            percentile,
+            return_type,
+        }
+    }
+
+    pub fn new_with_max_size(
+        percentile: f64,
+        return_type: DataType,
+        max_size: usize,
+    ) -> Self {
+        Self {
+            digest: TDigest::new(max_size),
+            percentile,
+            return_type,
+        }
+    }
+
+    // public for approx_percentile_cont_with_weight
+    pub fn merge_digests(&mut self, digests: &[TDigest]) {
+        let digests = digests.iter().chain(std::iter::once(&self.digest));
+        self.digest = TDigest::merge_digests(digests)
+    }
+
+    // public for approx_percentile_cont_with_weight
+    pub fn convert_to_float(values: &ArrayRef) -> 
datafusion_common::Result<Vec<f64>> {
+        match values.data_type() {
+            DataType::Float64 => {
+                let array = downcast_value!(values, Float64Array);
+                Ok(array
+                    .values()
+                    .iter()
+                    .filter_map(|v| v.try_as_f64().transpose())
+                    .collect::<datafusion_common::Result<Vec<_>>>()?)
+            }
+            DataType::Float32 => {
+                let array = downcast_value!(values, Float32Array);
+                Ok(array
+                    .values()
+                    .iter()
+                    .filter_map(|v| v.try_as_f64().transpose())
+                    .collect::<datafusion_common::Result<Vec<_>>>()?)
+            }
+            DataType::Int64 => {
+                let array = downcast_value!(values, Int64Array);
+                Ok(array
+                    .values()
+                    .iter()
+                    .filter_map(|v| v.try_as_f64().transpose())
+                    .collect::<datafusion_common::Result<Vec<_>>>()?)
+            }
+            DataType::Int32 => {
+                let array = downcast_value!(values, Int32Array);
+                Ok(array
+                    .values()
+                    .iter()
+                    .filter_map(|v| v.try_as_f64().transpose())
+                    .collect::<datafusion_common::Result<Vec<_>>>()?)
+            }
+            DataType::Int16 => {
+                let array = downcast_value!(values, Int16Array);
+                Ok(array
+                    .values()
+                    .iter()
+                    .filter_map(|v| v.try_as_f64().transpose())
+                    .collect::<datafusion_common::Result<Vec<_>>>()?)
+            }
+            DataType::Int8 => {
+                let array = downcast_value!(values, Int8Array);
+                Ok(array
+                    .values()
+                    .iter()
+                    .filter_map(|v| v.try_as_f64().transpose())
+                    .collect::<datafusion_common::Result<Vec<_>>>()?)
+            }
+            DataType::UInt64 => {
+                let array = downcast_value!(values, UInt64Array);
+                Ok(array
+                    .values()
+                    .iter()
+                    .filter_map(|v| v.try_as_f64().transpose())
+                    .collect::<datafusion_common::Result<Vec<_>>>()?)
+            }
+            DataType::UInt32 => {
+                let array = downcast_value!(values, UInt32Array);
+                Ok(array
+                    .values()
+                    .iter()
+                    .filter_map(|v| v.try_as_f64().transpose())
+                    .collect::<datafusion_common::Result<Vec<_>>>()?)
+            }
+            DataType::UInt16 => {
+                let array = downcast_value!(values, UInt16Array);
+                Ok(array
+                    .values()
+                    .iter()
+                    .filter_map(|v| v.try_as_f64().transpose())
+                    .collect::<datafusion_common::Result<Vec<_>>>()?)
+            }
+            DataType::UInt8 => {
+                let array = downcast_value!(values, UInt8Array);
+                Ok(array
+                    .values()
+                    .iter()
+                    .filter_map(|v| v.try_as_f64().transpose())
+                    .collect::<datafusion_common::Result<Vec<_>>>()?)
+            }
+            e => internal_err!(
+                "APPROX_PERCENTILE_CONT is not expected to receive the type 
{e:?}"
+            ),
+        }
+    }
+}
+
+impl Accumulator for ApproxPercentileAccumulator {
+    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
+        Ok(self.digest.to_scalar_state().into_iter().collect())
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
+        let values = &arrow::compute::filter(&values[0], &arrow::compute::is_not_null(&values[0])?)?;
+        let sorted_values = &arrow::compute::sort(values, None)?;
+        let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
+        self.digest = self.digest.merge_sorted_f64(&sorted_values);
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
+        if self.digest.count() == 0.0 {
+            return ScalarValue::try_from(self.return_type.clone());
+        }
+        let q = self.digest.estimate_quantile(self.percentile);
+
+        // These acceptable return types MUST match the validation in
+        // ApproxPercentile::create_accumulator.
+        Ok(match &self.return_type {
+            DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
+            DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
+            DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
+            DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
+            DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
+            DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
+            DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
+            DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
+            DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
+            DataType::Float64 => ScalarValue::Float64(Some(q)),
+            v => unreachable!("unexpected return type {:?}", v),
+        })
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> 
datafusion_common::Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+
+        let states = (0..states[0].len())
+            .map(|index| {
+                states
+                    .iter()
+                    .map(|array| ScalarValue::try_from_array(array, index))
+                    .collect::<datafusion_common::Result<Vec<_>>>()
+                    .map(|state| TDigest::from_scalar_state(&state))
+            })
+            .collect::<datafusion_common::Result<Vec<_>>>()?;
+
+        self.merge_digests(&states);
+
+        Ok(())
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self) + self.digest.size()
+            - std::mem::size_of_val(&self.digest)
+            + self.return_type.size()
+            - std::mem::size_of_val(&self.return_type)
+    }
+
+    fn supports_retract_batch(&self) -> bool {
+        true
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow_schema::DataType;
+
+    use datafusion_physical_expr_common::aggregate::tdigest::TDigest;
+
+    use crate::approx_percentile_cont::ApproxPercentileAccumulator;
+
+    #[test]
+    fn test_combine_approx_percentile_accumulator() {
+        let mut digests: Vec<TDigest> = Vec::new();
+
+        // one TDigest with 50_000 values from 1 to 1_000
+        for _ in 1..=50 {
+            let t = TDigest::new(100);
+            let values: Vec<_> = (1..=1_000).map(f64::from).collect();
+            let t = t.merge_unsorted_f64(values);
+            digests.push(t)
+        }
+
+        let t1 = TDigest::merge_digests(&digests);
+        let t2 = TDigest::merge_digests(&digests);
+
+        let mut accumulator =
+            ApproxPercentileAccumulator::new_with_max_size(0.5, 
DataType::Float64, 100);
+
+        accumulator.merge_digests(&[t1]);
+        assert_eq!(accumulator.digest.count(), 50_000.0);
+        accumulator.merge_digests(&[t2]);
+        assert_eq!(accumulator.digest.count(), 100_000.0);
+    }
+}
diff --git a/datafusion/functions-aggregate/src/lib.rs 
b/datafusion/functions-aggregate/src/lib.rs
index b8a2e7032a..274ab8302e 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -62,6 +62,9 @@ pub mod stddev;
 pub mod sum;
 pub mod variance;
 
+pub mod approx_median;
+pub mod approx_percentile_cont;
+
 use datafusion_common::Result;
 use datafusion_execution::FunctionRegistry;
 use datafusion_expr::AggregateUDF;
@@ -70,6 +73,7 @@ use std::sync::Arc;
 
 /// Fluent-style API for creating `Expr`s
 pub mod expr_fn {
+    pub use super::approx_median::approx_median;
     pub use super::covariance::covar_pop;
     pub use super::covariance::covar_samp;
     pub use super::first_last::first_value;
@@ -95,6 +99,7 @@ pub fn all_default_aggregate_functions() -> 
Vec<Arc<AggregateUDF>> {
         variance::var_pop_udaf(),
         stddev::stddev_udaf(),
         stddev::stddev_pop_udaf(),
+        approx_median::approx_median_udaf(),
     ]
 }
 
diff --git a/datafusion/functions-aggregate/src/sum.rs 
b/datafusion/functions-aggregate/src/sum.rs
index 9d3fa25222..b9293bc2ca 100644
--- a/datafusion/functions-aggregate/src/sum.rs
+++ b/datafusion/functions-aggregate/src/sum.rs
@@ -46,7 +46,7 @@ make_udaf_expr_and_func!(
     Sum,
     sum,
     expression,
-    "Returns the first value in a group of values.",
+    "Returns the sum of a group of values.",
     sum_udaf
 );
 
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs 
b/datafusion/physical-expr-common/src/aggregate/mod.rs
index 2273418c60..ec02df57b8 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -17,6 +17,7 @@
 
 pub mod groups_accumulator;
 pub mod stats;
+pub mod tdigest;
 pub mod utils;
 
 use arrow::datatypes::{DataType, Field, Schema};
diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs 
b/datafusion/physical-expr-common/src/aggregate/tdigest.rs
similarity index 95%
rename from datafusion/physical-expr/src/aggregate/tdigest.rs
rename to datafusion/physical-expr-common/src/aggregate/tdigest.rs
index e3b23b91d0..5107d0ab8e 100644
--- a/datafusion/physical-expr/src/aggregate/tdigest.rs
+++ b/datafusion/physical-expr-common/src/aggregate/tdigest.rs
@@ -28,7 +28,7 @@
 //! [Facebook's Folly TDigest]: 
https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h
 
 use arrow::datatypes::DataType;
-use arrow_array::types::Float64Type;
+use arrow::datatypes::Float64Type;
 use datafusion_common::cast::as_primitive_array;
 use datafusion_common::Result;
 use datafusion_common::ScalarValue;
@@ -50,7 +50,7 @@ macro_rules! cast_scalar_f64 {
 /// This trait is implemented for each type a [`TDigest`] can operate on,
 /// allowing it to support both numerical rust types (obtained from
 /// `PrimitiveArray` instances), and [`ScalarValue`] instances.
-pub(crate) trait TryIntoF64 {
+pub trait TryIntoF64 {
     /// A fallible conversion of a possibly null `self` into a [`f64`].
     ///
     /// If `self` is null, this method must return `Ok(None)`.
@@ -84,7 +84,7 @@ impl_try_ordered_f64!(u8);
 
 /// Centroid implementation to the cluster mentioned in the paper.
 #[derive(Debug, PartialEq, Clone)]
-pub(crate) struct Centroid {
+pub struct Centroid {
     mean: f64,
     weight: f64,
 }
@@ -104,21 +104,21 @@ impl Ord for Centroid {
 }
 
 impl Centroid {
-    pub(crate) fn new(mean: f64, weight: f64) -> Self {
+    pub fn new(mean: f64, weight: f64) -> Self {
         Centroid { mean, weight }
     }
 
     #[inline]
-    pub(crate) fn mean(&self) -> f64 {
+    pub fn mean(&self) -> f64 {
         self.mean
     }
 
     #[inline]
-    pub(crate) fn weight(&self) -> f64 {
+    pub fn weight(&self) -> f64 {
         self.weight
     }
 
-    pub(crate) fn add(&mut self, sum: f64, weight: f64) -> f64 {
+    pub fn add(&mut self, sum: f64, weight: f64) -> f64 {
         let new_sum = sum + self.weight * self.mean;
         let new_weight = self.weight + weight;
         self.weight = new_weight;
@@ -138,7 +138,7 @@ impl Default for Centroid {
 
 /// T-Digest to be operated on.
 #[derive(Debug, PartialEq, Clone)]
-pub(crate) struct TDigest {
+pub struct TDigest {
     centroids: Vec<Centroid>,
     max_size: usize,
     sum: f64,
@@ -148,7 +148,7 @@ pub(crate) struct TDigest {
 }
 
 impl TDigest {
-    pub(crate) fn new(max_size: usize) -> Self {
+    pub fn new(max_size: usize) -> Self {
         TDigest {
             centroids: Vec::new(),
             max_size,
@@ -159,7 +159,7 @@ impl TDigest {
         }
     }
 
-    pub(crate) fn new_with_centroid(max_size: usize, centroid: Centroid) -> 
Self {
+    pub fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self {
         TDigest {
             centroids: vec![centroid.clone()],
             max_size,
@@ -171,27 +171,27 @@ impl TDigest {
     }
 
     #[inline]
-    pub(crate) fn count(&self) -> f64 {
+    pub fn count(&self) -> f64 {
         self.count
     }
 
     #[inline]
-    pub(crate) fn max(&self) -> f64 {
+    pub fn max(&self) -> f64 {
         self.max
     }
 
     #[inline]
-    pub(crate) fn min(&self) -> f64 {
+    pub fn min(&self) -> f64 {
         self.min
     }
 
     #[inline]
-    pub(crate) fn max_size(&self) -> usize {
+    pub fn max_size(&self) -> usize {
         self.max_size
     }
 
     /// Size in bytes including `Self`.
-    pub(crate) fn size(&self) -> usize {
+    pub fn size(&self) -> usize {
         std::mem::size_of_val(self)
             + (std::mem::size_of::<Centroid>() * self.centroids.capacity())
     }
@@ -228,14 +228,14 @@ impl TDigest {
         v.clamp(lo, hi)
     }
 
-    #[cfg(test)]
-    pub(crate) fn merge_unsorted_f64(&self, unsorted_values: Vec<f64>) -> 
TDigest {
+    // public for testing in other modules
+    pub fn merge_unsorted_f64(&self, unsorted_values: Vec<f64>) -> TDigest {
         let mut values = unsorted_values;
         values.sort_by(|a, b| a.total_cmp(b));
         self.merge_sorted_f64(&values)
     }
 
-    pub(crate) fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest {
+    pub fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest {
         #[cfg(debug_assertions)]
         debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest");
 
@@ -370,9 +370,7 @@ impl TDigest {
     }
 
     // Merge multiple T-Digests
-    pub(crate) fn merge_digests<'a>(
-        digests: impl IntoIterator<Item = &'a TDigest>,
-    ) -> TDigest {
+    pub fn merge_digests<'a>(digests: impl IntoIterator<Item = &'a TDigest>) 
-> TDigest {
         let digests = digests.into_iter().collect::<Vec<_>>();
         let n_centroids: usize = digests.iter().map(|d| 
d.centroids.len()).sum();
         if n_centroids == 0 {
@@ -465,7 +463,7 @@ impl TDigest {
     }
 
     /// To estimate the value located at `q` quantile
-    pub(crate) fn estimate_quantile(&self, q: f64) -> f64 {
+    pub fn estimate_quantile(&self, q: f64) -> f64 {
         if self.centroids.is_empty() {
             return 0.0;
         }
@@ -569,7 +567,7 @@ impl TDigest {
     /// The [`TDigest::from_scalar_state()`] method reverses this processes,
     /// consuming the output of this method and returning an unpacked
     /// [`TDigest`].
-    pub(crate) fn to_scalar_state(&self) -> Vec<ScalarValue> {
+    pub fn to_scalar_state(&self) -> Vec<ScalarValue> {
         // Gather up all the centroids
         let centroids: Vec<ScalarValue> = self
             .centroids
@@ -598,7 +596,7 @@ impl TDigest {
     /// Providing input to this method that was not obtained from
     /// [`Self::to_scalar_state()`] results in undefined behaviour and may
     /// panic.
-    pub(crate) fn from_scalar_state(state: &[ScalarValue]) -> Self {
+    pub fn from_scalar_state(state: &[ScalarValue]) -> Self {
         assert_eq!(state.len(), 6, "invalid TDigest state");
 
         let max_size = match &state[0] {
diff --git a/datafusion/physical-expr/src/aggregate/approx_median.rs b/datafusion/physical-expr/src/aggregate/approx_median.rs
deleted file mode 100644
index cbbfef5a89..0000000000
--- a/datafusion/physical-expr/src/aggregate/approx_median.rs
+++ /dev/null
@@ -1,99 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Defines physical expressions for APPROX_MEDIAN that can be evaluated 
MEDIAN at runtime during query execution
-
-use crate::aggregate::utils::down_cast_any_ref;
-use crate::expressions::{lit, ApproxPercentileCont};
-use crate::{AggregateExpr, PhysicalExpr};
-use arrow::{datatypes::DataType, datatypes::Field};
-use datafusion_common::Result;
-use datafusion_expr::Accumulator;
-use std::any::Any;
-use std::sync::Arc;
-
-/// MEDIAN aggregate expression
-#[derive(Debug)]
-pub struct ApproxMedian {
-    name: String,
-    expr: Arc<dyn PhysicalExpr>,
-    data_type: DataType,
-    approx_percentile: ApproxPercentileCont,
-}
-
-impl ApproxMedian {
-    /// Create a new APPROX_MEDIAN aggregate function
-    pub fn try_new(
-        expr: Arc<dyn PhysicalExpr>,
-        name: impl Into<String>,
-        data_type: DataType,
-    ) -> Result<Self> {
-        let name: String = name.into();
-        let approx_percentile = ApproxPercentileCont::new(
-            vec![expr.clone(), lit(0.5_f64)],
-            name.clone(),
-            data_type.clone(),
-        )?;
-        Ok(Self {
-            name,
-            expr,
-            data_type,
-            approx_percentile,
-        })
-    }
-}
-
-impl AggregateExpr for ApproxMedian {
-    /// Return a reference to Any that can be used for downcasting
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn field(&self) -> Result<Field> {
-        Ok(Field::new(&self.name, self.data_type.clone(), true))
-    }
-
-    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        self.approx_percentile.create_accumulator()
-    }
-
-    fn state_fields(&self) -> Result<Vec<Field>> {
-        self.approx_percentile.state_fields()
-    }
-
-    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        vec![self.expr.clone()]
-    }
-
-    fn name(&self) -> &str {
-        &self.name
-    }
-}
-
-impl PartialEq<dyn Any> for ApproxMedian {
-    fn eq(&self, other: &dyn Any) -> bool {
-        down_cast_any_ref(other)
-            .downcast_ref::<Self>()
-            .map(|x| {
-                self.name == x.name
-                    && self.data_type == x.data_type
-                    && self.expr.eq(&x.expr)
-                    && self.approx_percentile == x.approx_percentile
-            })
-            .unwrap_or(false)
-    }
-}
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
index 63a4c85f9e..f2068bbc92 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
@@ -15,26 +15,19 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::aggregate::tdigest::TryIntoF64;
-use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE};
-use crate::aggregate::utils::down_cast_any_ref;
-use crate::expressions::format_state_name;
-use crate::{AggregateExpr, PhysicalExpr};
-use arrow::{
-    array::{
-        ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, 
Int64Array,
-        Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
-    },
-    datatypes::{DataType, Field},
-};
+use std::{any::Any, sync::Arc};
+
+use arrow::datatypes::{DataType, Field};
 use arrow_array::RecordBatch;
 use arrow_schema::Schema;
-use datafusion_common::{
-    downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, 
Result,
-    ScalarValue,
-};
+
+use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result, 
ScalarValue};
 use datafusion_expr::{Accumulator, ColumnarValue};
-use std::{any::Any, sync::Arc};
+use 
datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator;
+
+use crate::aggregate::utils::down_cast_any_ref;
+use crate::expressions::format_state_name;
+use crate::{AggregateExpr, PhysicalExpr};
 
 /// APPROX_PERCENTILE_CONT aggregate expression
 #[derive(Debug)]
@@ -195,7 +188,7 @@ impl AggregateExpr for ApproxPercentileCont {
     }
 
     #[allow(rustdoc::private_intra_doc_links)]
-    /// See [`TDigest::to_scalar_state()`] for a description of the serialised
+    /// See 
[`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`]
 for a description of the serialised
     /// state.
     fn state_fields(&self) -> Result<Vec<Field>> {
         Ok(vec![
@@ -254,220 +247,3 @@ impl PartialEq<dyn Any> for ApproxPercentileCont {
             .unwrap_or(false)
     }
 }
-
-#[derive(Debug)]
-pub struct ApproxPercentileAccumulator {
-    digest: TDigest,
-    percentile: f64,
-    return_type: DataType,
-}
-
-impl ApproxPercentileAccumulator {
-    pub fn new(percentile: f64, return_type: DataType) -> Self {
-        Self {
-            digest: TDigest::new(DEFAULT_MAX_SIZE),
-            percentile,
-            return_type,
-        }
-    }
-
-    pub fn new_with_max_size(
-        percentile: f64,
-        return_type: DataType,
-        max_size: usize,
-    ) -> Self {
-        Self {
-            digest: TDigest::new(max_size),
-            percentile,
-            return_type,
-        }
-    }
-
-    pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
-        let digests = digests.iter().chain(std::iter::once(&self.digest));
-        self.digest = TDigest::merge_digests(digests)
-    }
-
-    pub(crate) fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
-        match values.data_type() {
-            DataType::Float64 => {
-                let array = downcast_value!(values, Float64Array);
-                Ok(array
-                    .values()
-                    .iter()
-                    .filter_map(|v| v.try_as_f64().transpose())
-                    .collect::<Result<Vec<_>>>()?)
-            }
-            DataType::Float32 => {
-                let array = downcast_value!(values, Float32Array);
-                Ok(array
-                    .values()
-                    .iter()
-                    .filter_map(|v| v.try_as_f64().transpose())
-                    .collect::<Result<Vec<_>>>()?)
-            }
-            DataType::Int64 => {
-                let array = downcast_value!(values, Int64Array);
-                Ok(array
-                    .values()
-                    .iter()
-                    .filter_map(|v| v.try_as_f64().transpose())
-                    .collect::<Result<Vec<_>>>()?)
-            }
-            DataType::Int32 => {
-                let array = downcast_value!(values, Int32Array);
-                Ok(array
-                    .values()
-                    .iter()
-                    .filter_map(|v| v.try_as_f64().transpose())
-                    .collect::<Result<Vec<_>>>()?)
-            }
-            DataType::Int16 => {
-                let array = downcast_value!(values, Int16Array);
-                Ok(array
-                    .values()
-                    .iter()
-                    .filter_map(|v| v.try_as_f64().transpose())
-                    .collect::<Result<Vec<_>>>()?)
-            }
-            DataType::Int8 => {
-                let array = downcast_value!(values, Int8Array);
-                Ok(array
-                    .values()
-                    .iter()
-                    .filter_map(|v| v.try_as_f64().transpose())
-                    .collect::<Result<Vec<_>>>()?)
-            }
-            DataType::UInt64 => {
-                let array = downcast_value!(values, UInt64Array);
-                Ok(array
-                    .values()
-                    .iter()
-                    .filter_map(|v| v.try_as_f64().transpose())
-                    .collect::<Result<Vec<_>>>()?)
-            }
-            DataType::UInt32 => {
-                let array = downcast_value!(values, UInt32Array);
-                Ok(array
-                    .values()
-                    .iter()
-                    .filter_map(|v| v.try_as_f64().transpose())
-                    .collect::<Result<Vec<_>>>()?)
-            }
-            DataType::UInt16 => {
-                let array = downcast_value!(values, UInt16Array);
-                Ok(array
-                    .values()
-                    .iter()
-                    .filter_map(|v| v.try_as_f64().transpose())
-                    .collect::<Result<Vec<_>>>()?)
-            }
-            DataType::UInt8 => {
-                let array = downcast_value!(values, UInt8Array);
-                Ok(array
-                    .values()
-                    .iter()
-                    .filter_map(|v| v.try_as_f64().transpose())
-                    .collect::<Result<Vec<_>>>()?)
-            }
-            e => internal_err!(
-                "APPROX_PERCENTILE_CONT is not expected to receive the type 
{e:?}"
-            ),
-        }
-    }
-}
-
-impl Accumulator for ApproxPercentileAccumulator {
-    fn state(&mut self) -> Result<Vec<ScalarValue>> {
-        Ok(self.digest.to_scalar_state().into_iter().collect())
-    }
-
-    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let values = &values[0];
-        let sorted_values = &arrow::compute::sort(values, None)?;
-        let sorted_values = 
ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
-        self.digest = self.digest.merge_sorted_f64(&sorted_values);
-        Ok(())
-    }
-
-    fn evaluate(&mut self) -> Result<ScalarValue> {
-        if self.digest.count() == 0.0 {
-            return ScalarValue::try_from(self.return_type.clone());
-        }
-        let q = self.digest.estimate_quantile(self.percentile);
-
-        // These acceptable return types MUST match the validation in
-        // ApproxPercentile::create_accumulator.
-        Ok(match &self.return_type {
-            DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
-            DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
-            DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
-            DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
-            DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
-            DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
-            DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
-            DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
-            DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
-            DataType::Float64 => ScalarValue::Float64(Some(q)),
-            v => unreachable!("unexpected return type {:?}", v),
-        })
-    }
-
-    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        if states.is_empty() {
-            return Ok(());
-        }
-
-        let states = (0..states[0].len())
-            .map(|index| {
-                states
-                    .iter()
-                    .map(|array| ScalarValue::try_from_array(array, index))
-                    .collect::<Result<Vec<_>>>()
-                    .map(|state| TDigest::from_scalar_state(&state))
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        self.merge_digests(&states);
-
-        Ok(())
-    }
-
-    fn size(&self) -> usize {
-        std::mem::size_of_val(self) + self.digest.size()
-            - std::mem::size_of_val(&self.digest)
-            + self.return_type.size()
-            - std::mem::size_of_val(&self.return_type)
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use crate::aggregate::approx_percentile_cont::ApproxPercentileAccumulator;
-    use crate::aggregate::tdigest::TDigest;
-    use arrow_schema::DataType;
-
-    #[test]
-    fn test_combine_approx_percentile_accumulator() {
-        let mut digests: Vec<TDigest> = Vec::new();
-
-        // one TDigest with 50_000 values from 1 to 1_000
-        for _ in 1..=50 {
-            let t = TDigest::new(100);
-            let values: Vec<_> = (1..=1_000).map(f64::from).collect();
-            let t = t.merge_unsorted_f64(values);
-            digests.push(t)
-        }
-
-        let t1 = TDigest::merge_digests(&digests);
-        let t2 = TDigest::merge_digests(&digests);
-
-        let mut accumulator =
-            ApproxPercentileAccumulator::new_with_max_size(0.5, 
DataType::Float64, 100);
-
-        accumulator.merge_digests(&[t1]);
-        assert_eq!(accumulator.digest.count(), 50_000.0);
-        accumulator.merge_digests(&[t2]);
-        assert_eq!(accumulator.digest.count(), 100_000.0);
-    }
-}
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
index 3fa715a592..07c2aff343 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
@@ -15,14 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::aggregate::approx_percentile_cont::ApproxPercentileAccumulator;
-use crate::aggregate::tdigest::{Centroid, TDigest, DEFAULT_MAX_SIZE};
 use crate::expressions::ApproxPercentileCont;
 use crate::{AggregateExpr, PhysicalExpr};
 use arrow::{
     array::ArrayRef,
     datatypes::{DataType, Field},
 };
+use 
datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator;
+use datafusion_physical_expr_common::aggregate::tdigest::{
+    Centroid, TDigest, DEFAULT_MAX_SIZE,
+};
 
 use datafusion_common::Result;
 use datafusion_common::ScalarValue;
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index f0cff53fb3..89de6ad49c 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -280,18 +280,6 @@ pub fn create_aggregate_expr(
                 "approx_percentile_cont_with_weight(DISTINCT) aggregations are 
not available"
             );
         }
-        (AggregateFunction::ApproxMedian, false) => {
-            Arc::new(expressions::ApproxMedian::try_new(
-                input_phy_exprs[0].clone(),
-                name,
-                data_type,
-            )?)
-        }
-        (AggregateFunction::ApproxMedian, true) => {
-            return not_impl_err!(
-                "APPROX_MEDIAN(DISTINCT) aggregations are not available"
-            );
-        }
         (AggregateFunction::NthValue, _) => {
             let expr = &input_phy_exprs[0];
             let Some(n) = input_phy_exprs[1]
@@ -337,9 +325,8 @@ mod tests {
 
     use super::*;
     use crate::expressions::{
-        try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, 
ArrayAgg, Avg,
-        BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, 
DistinctCount,
-        Max, Min,
+        try_cast, ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, 
BitOr,
+        BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount, Max, 
Min,
     };
 
     use datafusion_common::{plan_err, DataFusionError, ScalarValue};
@@ -686,60 +673,6 @@ mod tests {
         Ok(())
     }
 
-    #[test]
-    fn test_median_expr() -> Result<()> {
-        let funcs = vec![AggregateFunction::ApproxMedian];
-        let data_types = vec![
-            DataType::UInt32,
-            DataType::UInt64,
-            DataType::Int32,
-            DataType::Int64,
-            DataType::Float32,
-            DataType::Float64,
-        ];
-        for fun in funcs {
-            for data_type in &data_types {
-                let input_schema =
-                    Schema::new(vec![Field::new("c1", data_type.clone(), 
true)]);
-                let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = 
vec![Arc::new(
-                    expressions::Column::new_with_schema("c1", 
&input_schema).unwrap(),
-                )];
-                let result_agg_phy_exprs = create_physical_agg_expr_for_test(
-                    &fun,
-                    false,
-                    &input_phy_exprs[0..1],
-                    &input_schema,
-                    "c1",
-                )?;
-
-                if fun == AggregateFunction::ApproxMedian {
-                    
assert!(result_agg_phy_exprs.as_any().is::<ApproxMedian>());
-                    assert_eq!("c1", result_agg_phy_exprs.name());
-                    assert_eq!(
-                        Field::new("c1", data_type.clone(), true),
-                        result_agg_phy_exprs.field().unwrap()
-                    );
-                }
-            }
-        }
-        Ok(())
-    }
-
-    #[test]
-    fn test_median() -> Result<()> {
-        let observed = 
AggregateFunction::ApproxMedian.return_type(&[DataType::Utf8]);
-        assert!(observed.is_err());
-
-        let observed = 
AggregateFunction::ApproxMedian.return_type(&[DataType::Int32])?;
-        assert_eq!(DataType::Int32, observed);
-
-        let observed =
-            
AggregateFunction::ApproxMedian.return_type(&[DataType::Decimal128(10, 6)]);
-        assert!(observed.is_err());
-
-        Ok(())
-    }
-
     #[test]
     fn test_min_max() -> Result<()> {
         let observed = AggregateFunction::Min.return_type(&[DataType::Utf8])?;
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs
index 2c14c1550e..9db80f155a 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -18,10 +18,8 @@
 pub use datafusion_physical_expr_common::aggregate::AggregateExpr;
 
 mod hyperloglog;
-mod tdigest;
 
 pub(crate) mod approx_distinct;
-pub(crate) mod approx_median;
 pub(crate) mod approx_percentile_cont;
 pub(crate) mod approx_percentile_cont_with_weight;
 pub(crate) mod array_agg;
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index 476cbe3907..656cc570ca 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -38,7 +38,6 @@ pub mod helpers {
 }
 
 pub use crate::aggregate::approx_distinct::ApproxDistinct;
-pub use crate::aggregate::approx_median::ApproxMedian;
 pub use crate::aggregate::approx_percentile_cont::ApproxPercentileCont;
 pub use 
crate::aggregate::approx_percentile_cont_with_weight::ApproxPercentileContWithWeight;
 pub use crate::aggregate::array_agg::ArrayAgg;
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 0071a43bbe..9f23824b3a 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -487,7 +487,7 @@ enum AggregateFunction {
   // STDDEV_POP = 12;
   CORRELATION = 13;
   APPROX_PERCENTILE_CONT = 14;
-  APPROX_MEDIAN = 15;
+  // APPROX_MEDIAN = 15;
   APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
   GROUPING = 17;
   // MEDIAN = 18;
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index e6aded8901..28f80c5ee1 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -540,7 +540,6 @@ impl serde::Serialize for AggregateFunction {
             Self::ArrayAgg => "ARRAY_AGG",
             Self::Correlation => "CORRELATION",
             Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
-            Self::ApproxMedian => "APPROX_MEDIAN",
             Self::ApproxPercentileContWithWeight => 
"APPROX_PERCENTILE_CONT_WITH_WEIGHT",
             Self::Grouping => "GROUPING",
             Self::BitAnd => "BIT_AND",
@@ -578,7 +577,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
             "ARRAY_AGG",
             "CORRELATION",
             "APPROX_PERCENTILE_CONT",
-            "APPROX_MEDIAN",
             "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
             "GROUPING",
             "BIT_AND",
@@ -645,7 +643,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
                     "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg),
                     "CORRELATION" => Ok(AggregateFunction::Correlation),
                     "APPROX_PERCENTILE_CONT" => 
Ok(AggregateFunction::ApproxPercentileCont),
-                    "APPROX_MEDIAN" => Ok(AggregateFunction::ApproxMedian),
                     "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => 
Ok(AggregateFunction::ApproxPercentileContWithWeight),
                     "GROUPING" => Ok(AggregateFunction::Grouping),
                     "BIT_AND" => Ok(AggregateFunction::BitAnd),
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 7ec9187491..9741b2bc42 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1931,7 +1931,7 @@ pub enum AggregateFunction {
     /// STDDEV_POP = 12;
     Correlation = 13,
     ApproxPercentileCont = 14,
-    ApproxMedian = 15,
+    /// APPROX_MEDIAN = 15;
     ApproxPercentileContWithWeight = 16,
     Grouping = 17,
     /// MEDIAN = 18;
@@ -1967,7 +1967,6 @@ impl AggregateFunction {
             AggregateFunction::ArrayAgg => "ARRAY_AGG",
             AggregateFunction::Correlation => "CORRELATION",
             AggregateFunction::ApproxPercentileCont => 
"APPROX_PERCENTILE_CONT",
-            AggregateFunction::ApproxMedian => "APPROX_MEDIAN",
             AggregateFunction::ApproxPercentileContWithWeight => {
                 "APPROX_PERCENTILE_CONT_WITH_WEIGHT"
             }
@@ -2001,7 +2000,6 @@ impl AggregateFunction {
             "ARRAY_AGG" => Some(Self::ArrayAgg),
             "CORRELATION" => Some(Self::Correlation),
             "APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont),
-            "APPROX_MEDIAN" => Some(Self::ApproxMedian),
             "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => {
                 Some(Self::ApproxPercentileContWithWeight)
             }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index a77d361983..5c083fa27a 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -164,7 +164,6 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
             protobuf::AggregateFunction::ApproxPercentileContWithWeight => {
                 Self::ApproxPercentileContWithWeight
             }
-            protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian,
             protobuf::AggregateFunction::Grouping => Self::Grouping,
             protobuf::AggregateFunction::NthValueAgg => Self::NthValue,
             protobuf::AggregateFunction::StringAgg => Self::StringAgg,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 9c4c7685b3..e2259896b2 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -133,7 +133,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
             AggregateFunction::ApproxPercentileContWithWeight => {
                 Self::ApproxPercentileContWithWeight
             }
-            AggregateFunction::ApproxMedian => Self::ApproxMedian,
             AggregateFunction::Grouping => Self::Grouping,
             AggregateFunction::NthValue => Self::NthValueAgg,
             AggregateFunction::StringAgg => Self::StringAgg,
@@ -430,9 +429,6 @@ pub fn serialize_expr(
                     AggregateFunction::RegrSXX => 
protobuf::AggregateFunction::RegrSxx,
                     AggregateFunction::RegrSYY => 
protobuf::AggregateFunction::RegrSyy,
                     AggregateFunction::RegrSXY => 
protobuf::AggregateFunction::RegrSxy,
-                    AggregateFunction::ApproxMedian => {
-                        protobuf::AggregateFunction::ApproxMedian
-                    }
                     AggregateFunction::Grouping => 
protobuf::AggregateFunction::Grouping,
                     AggregateFunction::NthValue => {
                         protobuf::AggregateFunction::NthValueAgg
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index 5d07d5c0fa..19ba4a40d5 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -23,12 +23,12 @@ use datafusion::datasource::file_format::parquet::ParquetSink;
 use datafusion::physical_expr::window::{NthValueKind, 
SlidingAggregateWindowExpr};
 use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr};
 use datafusion::physical_plan::expressions::{
-    ApproxDistinct, ApproxMedian, ApproxPercentileCont, 
ApproxPercentileContWithWeight,
-    ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, 
CaseExpr,
-    CastExpr, Column, Correlation, Count, CumeDist, DistinctArrayAgg, 
DistinctBitXor,
-    DistinctCount, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, 
Max, Min,
-    NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, 
OrderSensitiveArrayAgg, Rank,
-    RankType, Regr, RegrType, RowNumber, StringAgg, TryCastExpr, WindowShift,
+    ApproxDistinct, ApproxPercentileCont, ApproxPercentileContWithWeight, 
ArrayAgg, Avg,
+    BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, 
Column,
+    Correlation, Count, CumeDist, DistinctArrayAgg, DistinctBitXor, 
DistinctCount,
+    Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, 
NegativeExpr,
+    NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, 
RankType, Regr,
+    RegrType, RowNumber, StringAgg, TryCastExpr, WindowShift,
 };
 use datafusion::physical_plan::udaf::AggregateFunctionExpr;
 use datafusion::physical_plan::windows::{BuiltInWindowExpr, 
PlainAggregateWindowExpr};
@@ -296,8 +296,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> {
         .is_some()
     {
         protobuf::AggregateFunction::ApproxPercentileContWithWeight
-    } else if aggr_expr.downcast_ref::<ApproxMedian>().is_some() {
-        protobuf::AggregateFunction::ApproxMedian
     } else if aggr_expr.downcast_ref::<StringAgg>().is_some() {
         protobuf::AggregateFunction::StringAgg
     } else if aggr_expr.downcast_ref::<NthValueAgg>().is_some() {
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index b1cad69b14..699697dd2f 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -33,6 +33,7 @@ use datafusion::datasource::TableProvider;
 use datafusion::execution::context::SessionState;
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
 use datafusion::execution::FunctionRegistry;
+use datafusion::functions_aggregate::approx_median::approx_median;
 use datafusion::functions_aggregate::expr_fn::{
     covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, 
var_pop,
     var_sample,
@@ -658,6 +659,7 @@ async fn roundtrip_expr_api() -> Result<()> {
         var_pop(lit(2.2)),
         stddev(lit(2.2)),
         stddev_pop(lit(2.2)),
+        approx_median(lit(2)),
     ];
 
     // ensure expressions created with the expr api can be round tripped
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index 6a99f9719d..7b9d39a2b5 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -37,6 +37,7 @@ use datafusion_sql::{
     planner::{ParserOptions, SqlToRel},
 };
 
+use datafusion_functions_aggregate::approx_median::approx_median_udaf;
 use rstest::rstest;
 use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
 
@@ -1649,8 +1650,8 @@ fn select_count_column() {
 #[test]
 fn select_approx_median() {
     let sql = "SELECT approx_median(age) FROM person";
-    let expected = "Projection: APPROX_MEDIAN(person.age)\
-                        \n  Aggregate: groupBy=[[]], 
aggr=[[APPROX_MEDIAN(person.age)]]\
+    let expected = "Projection: approx_median(person.age)\
+                        \n  Aggregate: groupBy=[[]], 
aggr=[[approx_median(person.age)]]\
                         \n    TableScan: person";
     quick_test(sql, expected);
 }
@@ -2581,8 +2582,8 @@ fn approx_median_window() {
     let sql =
         "SELECT order_id, APPROX_MEDIAN(qty) OVER(PARTITION BY order_id) from 
orders";
     let expected = "\
-        Projection: orders.order_id, APPROX_MEDIAN(orders.qty) PARTITION BY 
[orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\
-        \n  WindowAggr: windowExpr=[[APPROX_MEDIAN(orders.qty) PARTITION BY 
[orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
+        Projection: orders.order_id, approx_median(orders.qty) PARTITION BY 
[orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\
+        \n  WindowAggr: windowExpr=[[approx_median(orders.qty) PARTITION BY 
[orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
         \n    TableScan: orders";
     quick_test(sql, expected);
 }
@@ -2700,7 +2701,8 @@ fn logical_plan_with_dialect_and_options(
             DataType::Int32,
         ))
         .with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64))
-        .with_udaf(sum_udaf());
+        .with_udaf(sum_udaf())
+        .with_udaf(approx_median_udaf());
 
     let planner = SqlToRel::new_with_options(&context, options);
     let result = DFParser::parse_sql_with_dialect(sql, dialect);
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 9958f8ac38..a245793ebd 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -518,6 +518,11 @@ SELECT approx_median(c12) FROM aggregate_test_100
 ----
 0.555006541052
 
+# csv_query_approx_median_4
+# test with string, approx median only supports numeric
+statement error
+SELECT approx_median(c1) FROM aggregate_test_100
+
 # csv_query_median_1
 query I
 SELECT median(c2) FROM aggregate_test_100
@@ -637,6 +642,11 @@ select median(c), arrow_typeof(median(c)) from t;
 ----
 0.0003 Decimal128(10, 4)
 
+query RT
+select approx_median(c), arrow_typeof(approx_median(c)) from t;
+----
+0.00035 Float64
+
 statement ok
 drop table t;
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to