This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 48a824650 Simplify approx_median implementation (#3064)
48a824650 is described below

commit 48a824650a2175a51e2be5e33127bbca6c266b9b
Author: Andy Grove <[email protected]>
AuthorDate: Mon Aug 8 06:20:41 2022 -0600

    Simplify approx_median implementation (#3064)
---
 datafusion/core/tests/dataframe_functions.rs       | 21 ++++++++++++++++++++
 datafusion/expr/src/expr_fn.rs                     |  9 +++++++++
 .../physical-expr/src/aggregate/approx_median.rs   | 23 +++++++++++++++-------
 datafusion/physical-expr/src/aggregate/build_in.rs |  4 ++--
 datafusion/sql/src/planner.rs                      | 20 ++++---------------
 5 files changed, 52 insertions(+), 25 deletions(-)

diff --git a/datafusion/core/tests/dataframe_functions.rs b/datafusion/core/tests/dataframe_functions.rs
index b126d010c..cefdaa777 100644
--- a/datafusion/core/tests/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe_functions.rs
@@ -33,6 +33,7 @@ use datafusion::prelude::*;
 use datafusion::execution::context::SessionContext;
 
 use datafusion::assert_batches_eq;
+use datafusion_expr::approx_median;
 
 fn create_test_table() -> Result<Arc<DataFrame>> {
     let schema = Arc::new(Schema::new(vec![
@@ -152,6 +153,26 @@ async fn test_fn_btrim_with_chars() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_fn_approx_median() -> Result<()> {
+    let expr = approx_median(col("b"));
+
+    let expected = vec![
+        "+----------------------+",
+        "| APPROXMEDIAN(test.b) |",
+        "+----------------------+",
+        "| 10                   |",
+        "+----------------------+",
+    ];
+
+    let df = create_test_table()?;
+    let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
+
+    assert_batches_eq!(expected, &batches);
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn test_fn_approx_percentile_cont() -> Result<()> {
     let expr = approx_percentile_cont(col("b"), lit(0.5));
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 97bbd419e..5c95ba51d 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -166,6 +166,15 @@ pub fn approx_distinct(expr: Expr) -> Expr {
     }
 }
 
+/// Calculate an approximation of the median for `expr`.
+pub fn approx_median(expr: Expr) -> Expr {
+    Expr::AggregateFunction {
+        fun: aggregate_function::AggregateFunction::ApproxMedian,
+        distinct: false,
+        args: vec![expr],
+    }
+}
+
 /// Calculate an approximation of the specified `percentile` for `expr`.
 pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
     Expr::AggregateFunction {
diff --git a/datafusion/physical-expr/src/aggregate/approx_median.rs b/datafusion/physical-expr/src/aggregate/approx_median.rs
index 651ccbdb7..3befd7a81 100644
--- a/datafusion/physical-expr/src/aggregate/approx_median.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_median.rs
@@ -17,6 +17,7 @@
 
 //! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution
 
+use crate::expressions::{lit, ApproxPercentileCont};
 use crate::{AggregateExpr, PhysicalExpr};
 use arrow::{datatypes::DataType, datatypes::Field};
 use datafusion_common::Result;
@@ -30,20 +31,28 @@ pub struct ApproxMedian {
     name: String,
     expr: Arc<dyn PhysicalExpr>,
     data_type: DataType,
+    approx_percentile: ApproxPercentileCont,
 }
 
 impl ApproxMedian {
     /// Create a new APPROX_MEDIAN aggregate function
-    pub fn new(
+    pub fn try_new(
         expr: Arc<dyn PhysicalExpr>,
         name: impl Into<String>,
         data_type: DataType,
-    ) -> Self {
-        Self {
-            name: name.into(),
+    ) -> Result<Self> {
+        let name: String = name.into();
+        let approx_percentile = ApproxPercentileCont::new(
+            vec![expr.clone(), lit(0.5_f64)],
+            name.clone(),
+            data_type.clone(),
+        )?;
+        Ok(Self {
+            name,
             expr,
             data_type,
-        }
+            approx_percentile,
+        })
     }
 }
 
@@ -58,11 +67,11 @@ impl AggregateExpr for ApproxMedian {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        unimplemented!()
+        self.approx_percentile.create_accumulator()
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
-        unimplemented!()
+        self.approx_percentile.state_fields()
     }
 
     fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index 8d76e35e4..7bfa9e0a1 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -244,11 +244,11 @@ pub fn create_aggregate_expr(
             ));
         }
         (AggregateFunction::ApproxMedian, false) => {
-            Arc::new(expressions::ApproxMedian::new(
+            Arc::new(expressions::ApproxMedian::try_new(
                 coerced_phy_exprs[0].clone(),
                 name,
                 return_type,
-            ))
+            )?)
         }
         (AggregateFunction::ApproxMedian, true) => {
             return Err(DataFusionError::NotImplemented(
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 6c3e80e32..8c615a24a 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -43,7 +43,6 @@ use datafusion_expr::{
 };
 use hashbrown::HashMap;
 use std::collections::HashSet;
-use std::iter;
 use std::str::FromStr;
 use std::sync::Arc;
 use std::{convert::TryInto, vec};
@@ -2183,20 +2182,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     _ => self.sql_fn_arg_to_logical_expr(a, schema, &mut 
HashMap::new()),
                 })
                 .collect::<Result<Vec<Expr>>>()?,
-            AggregateFunction::ApproxMedian => function
-                .args
-                .into_iter()
-                .map(|a| self.sql_fn_arg_to_logical_expr(a, schema, &mut 
HashMap::new()))
-                .chain(iter::once(Ok(lit(0.5_f64))))
-                .collect::<Result<Vec<Expr>>>()?,
             _ => self.function_args_to_expr(function.args, schema)?,
         };
 
-        let fun = match fun {
-            AggregateFunction::ApproxMedian => 
AggregateFunction::ApproxPercentileCont,
-            _ => fun,
-        };
-
         Ok((fun, args))
     }
 
@@ -3567,8 +3555,8 @@ mod tests {
     #[test]
     fn select_approx_median() {
         let sql = "SELECT approx_median(age) FROM person";
-        let expected = "Projection: 
#APPROXPERCENTILECONT(person.age,Float64(0.5))\
-                        \n  Aggregate: groupBy=[[]], 
aggr=[[APPROXPERCENTILECONT(#person.age, Float64(0.5))]]\
+        let expected = "Projection: #APPROXMEDIAN(person.age)\
+                        \n  Aggregate: groupBy=[[]], 
aggr=[[APPROXMEDIAN(#person.age)]]\
                         \n    TableScan: person";
         quick_test(sql, expected);
     }
@@ -4360,8 +4348,8 @@ mod tests {
         let sql =
             "SELECT order_id, APPROX_MEDIAN(qty) OVER(PARTITION BY order_id) 
from orders";
         let expected = "\
-        Projection: #orders.order_id, 
#APPROXPERCENTILECONT(orders.qty,Float64(0.5)) PARTITION BY [#orders.order_id]\
-        \n  WindowAggr: windowExpr=[[APPROXPERCENTILECONT(#orders.qty, 
Float64(0.5)) PARTITION BY [#orders.order_id]]]\
+        Projection: #orders.order_id, #APPROXMEDIAN(orders.qty) PARTITION BY 
[#orders.order_id]\
+        \n  WindowAggr: windowExpr=[[APPROXMEDIAN(#orders.qty) PARTITION BY 
[#orders.order_id]]]\
         \n    TableScan: orders";
         quick_test(sql, expected);
     }

Reply via email to