This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 48a824650 Simplify approx_median implementation (#3064)
48a824650 is described below
commit 48a824650a2175a51e2be5e33127bbca6c266b9b
Author: Andy Grove <[email protected]>
AuthorDate: Mon Aug 8 06:20:41 2022 -0600
Simplify approx_median implementation (#3064)
---
datafusion/core/tests/dataframe_functions.rs | 21 ++++++++++++++++++++
datafusion/expr/src/expr_fn.rs | 9 +++++++++
.../physical-expr/src/aggregate/approx_median.rs | 23 +++++++++++++++-------
datafusion/physical-expr/src/aggregate/build_in.rs | 4 ++--
datafusion/sql/src/planner.rs | 20 ++++---------------
5 files changed, 52 insertions(+), 25 deletions(-)
diff --git a/datafusion/core/tests/dataframe_functions.rs b/datafusion/core/tests/dataframe_functions.rs
index b126d010c..cefdaa777 100644
--- a/datafusion/core/tests/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe_functions.rs
@@ -33,6 +33,7 @@ use datafusion::prelude::*;
use datafusion::execution::context::SessionContext;
use datafusion::assert_batches_eq;
+use datafusion_expr::approx_median;
fn create_test_table() -> Result<Arc<DataFrame>> {
let schema = Arc::new(Schema::new(vec![
@@ -152,6 +153,26 @@ async fn test_fn_btrim_with_chars() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn test_fn_approx_median() -> Result<()> {
+ let expr = approx_median(col("b"));
+
+ let expected = vec![
+ "+----------------------+",
+ "| APPROXMEDIAN(test.b) |",
+ "+----------------------+",
+ "| 10 |",
+ "+----------------------+",
+ ];
+
+ let df = create_test_table()?;
+ let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
+
+ assert_batches_eq!(expected, &batches);
+
+ Ok(())
+}
+
#[tokio::test]
async fn test_fn_approx_percentile_cont() -> Result<()> {
let expr = approx_percentile_cont(col("b"), lit(0.5));
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 97bbd419e..5c95ba51d 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -166,6 +166,15 @@ pub fn approx_distinct(expr: Expr) -> Expr {
}
}
+/// Calculate an approximation of the median for `expr`.
+pub fn approx_median(expr: Expr) -> Expr {
+ Expr::AggregateFunction {
+ fun: aggregate_function::AggregateFunction::ApproxMedian,
+ distinct: false,
+ args: vec![expr],
+ }
+}
+
/// Calculate an approximation of the specified `percentile` for `expr`.
pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
Expr::AggregateFunction {
diff --git a/datafusion/physical-expr/src/aggregate/approx_median.rs b/datafusion/physical-expr/src/aggregate/approx_median.rs
index 651ccbdb7..3befd7a81 100644
--- a/datafusion/physical-expr/src/aggregate/approx_median.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_median.rs
@@ -17,6 +17,7 @@
//! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution
+use crate::expressions::{lit, ApproxPercentileCont};
use crate::{AggregateExpr, PhysicalExpr};
use arrow::{datatypes::DataType, datatypes::Field};
use datafusion_common::Result;
@@ -30,20 +31,28 @@ pub struct ApproxMedian {
name: String,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
+ approx_percentile: ApproxPercentileCont,
}
impl ApproxMedian {
/// Create a new APPROX_MEDIAN aggregate function
- pub fn new(
+ pub fn try_new(
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
- ) -> Self {
- Self {
- name: name.into(),
+ ) -> Result<Self> {
+ let name: String = name.into();
+ let approx_percentile = ApproxPercentileCont::new(
+ vec![expr.clone(), lit(0.5_f64)],
+ name.clone(),
+ data_type.clone(),
+ )?;
+ Ok(Self {
+ name,
expr,
data_type,
- }
+ approx_percentile,
+ })
}
}
@@ -58,11 +67,11 @@ impl AggregateExpr for ApproxMedian {
}
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- unimplemented!()
+ self.approx_percentile.create_accumulator()
}
fn state_fields(&self) -> Result<Vec<Field>> {
- unimplemented!()
+ self.approx_percentile.state_fields()
}
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index 8d76e35e4..7bfa9e0a1 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -244,11 +244,11 @@ pub fn create_aggregate_expr(
));
}
(AggregateFunction::ApproxMedian, false) => {
- Arc::new(expressions::ApproxMedian::new(
+ Arc::new(expressions::ApproxMedian::try_new(
coerced_phy_exprs[0].clone(),
name,
return_type,
- ))
+ )?)
}
(AggregateFunction::ApproxMedian, true) => {
return Err(DataFusionError::NotImplemented(
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 6c3e80e32..8c615a24a 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -43,7 +43,6 @@ use datafusion_expr::{
};
use hashbrown::HashMap;
use std::collections::HashSet;
-use std::iter;
use std::str::FromStr;
use std::sync::Arc;
use std::{convert::TryInto, vec};
@@ -2183,20 +2182,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
_ => self.sql_fn_arg_to_logical_expr(a, schema, &mut HashMap::new()),
})
.collect::<Result<Vec<Expr>>>()?,
- AggregateFunction::ApproxMedian => function
- .args
- .into_iter()
- .map(|a| self.sql_fn_arg_to_logical_expr(a, schema, &mut HashMap::new()))
- .chain(iter::once(Ok(lit(0.5_f64))))
- .collect::<Result<Vec<Expr>>>()?,
_ => self.function_args_to_expr(function.args, schema)?,
};
- let fun = match fun {
- AggregateFunction::ApproxMedian => AggregateFunction::ApproxPercentileCont,
- _ => fun,
- };
-
Ok((fun, args))
}
@@ -3567,8 +3555,8 @@ mod tests {
#[test]
fn select_approx_median() {
let sql = "SELECT approx_median(age) FROM person";
- let expected = "Projection: #APPROXPERCENTILECONT(person.age,Float64(0.5))\
- \n Aggregate: groupBy=[[]], aggr=[[APPROXPERCENTILECONT(#person.age, Float64(0.5))]]\
+ let expected = "Projection: #APPROXMEDIAN(person.age)\
+ \n Aggregate: groupBy=[[]], aggr=[[APPROXMEDIAN(#person.age)]]\
\n TableScan: person";
quick_test(sql, expected);
}
@@ -4360,8 +4348,8 @@ mod tests {
let sql =
"SELECT order_id, APPROX_MEDIAN(qty) OVER(PARTITION BY order_id) from orders";
let expected = "\
- Projection: #orders.order_id, #APPROXPERCENTILECONT(orders.qty,Float64(0.5)) PARTITION BY [#orders.order_id]\
- \n WindowAggr: windowExpr=[[APPROXPERCENTILECONT(#orders.qty, Float64(0.5)) PARTITION BY [#orders.order_id]]]\
+ Projection: #orders.order_id, #APPROXMEDIAN(orders.qty) PARTITION BY [#orders.order_id]\
+ \n WindowAggr: windowExpr=[[APPROXMEDIAN(#orders.qty) PARTITION BY [#orders.order_id]]]\
\n TableScan: orders";
quick_test(sql, expected);
}