This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 24a08465e1 Introduce expr builder for aggregate function (#10560)
24a08465e1 is described below
commit 24a08465e12bc07275cafe5310a7ac44898e39de
Author: Jay Zhan <[email protected]>
AuthorDate: Sun Jun 9 13:49:17 2024 +0800
Introduce expr builder for aggregate function (#10560)
* expr builder
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* build
Signed-off-by: jayzhan211 <[email protected]>
* upd user-guide
Signed-off-by: jayzhan211 <[email protected]>
* fix builder
Signed-off-by: jayzhan211 <[email protected]>
* Consolidate example in udaf_expr.rs, simplify filter API
* Add doc strings and examples
* Add tests and checks
* Improve documentation more
* fixup
* rm spce
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion-examples/examples/expr_api.rs | 44 ++++-
datafusion/core/tests/expr_api/mod.rs | 190 ++++++++++++++++++++-
datafusion/expr/src/expr.rs | 15 +-
datafusion/expr/src/lib.rs | 2 +-
datafusion/expr/src/udaf.rs | 181 +++++++++++++++++++-
datafusion/functions-aggregate/src/first_last.rs | 25 ++-
datafusion/functions-aggregate/src/macros.rs | 32 +---
.../optimizer/src/replace_distinct_aggregate.rs | 23 +--
.../proto/tests/cases/roundtrip_logical_plan.rs | 7 +-
docs/source/user-guide/expressions.md | 10 ++
10 files changed, 467 insertions(+), 62 deletions(-)
diff --git a/datafusion-examples/examples/expr_api.rs
b/datafusion-examples/examples/expr_api.rs
index 0082ed6eb9..591f6ac3de 100644
--- a/datafusion-examples/examples/expr_api.rs
+++ b/datafusion-examples/examples/expr_api.rs
@@ -24,6 +24,7 @@ use arrow::record_batch::RecordBatch;
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::common::DFSchema;
use datafusion::error::Result;
+use datafusion::functions_aggregate::first_last::first_value_udaf;
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries};
use datafusion::prelude::*;
@@ -32,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::simplify::SimplifyContext;
-use datafusion_expr::{ColumnarValue, ExprSchemable, Operator};
+use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator};
/// This example demonstrates the DataFusion [`Expr`] API.
///
@@ -44,11 +45,12 @@ use datafusion_expr::{ColumnarValue, ExprSchemable,
Operator};
/// also comes with APIs for evaluation, simplification, and analysis.
///
/// The code in this example shows how to:
-/// 1. Create [`Exprs`] using different APIs: [`main`]`
-/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`]
-/// 3. Simplify expressions: [`simplify_demo`]
-/// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`]
-/// 5. Get the types of the expressions: [`expression_type_demo`]
+/// 1. Create [`Expr`]s using different APIs: [`main`]
+/// 2. Use the fluent API to easily create complex [`Expr`]s: [`expr_fn_demo`]
+/// 3. Evaluate [`Expr`]s against data: [`evaluate_demo`]
+/// 4. Simplify expressions: [`simplify_demo`]
+/// 5. Analyze predicates for boundary ranges: [`range_analysis_demo`]
+/// 6. Get the types of the expressions: [`expression_type_demo`]
#[tokio::main]
async fn main() -> Result<()> {
// The easiest way to do create expressions is to use the
@@ -63,6 +65,9 @@ async fn main() -> Result<()> {
));
assert_eq!(expr, expr2);
+ // See how to build aggregate functions with the expr_fn API
+ expr_fn_demo()?;
+
// See how to evaluate expressions
evaluate_demo()?;
@@ -78,6 +83,33 @@ async fn main() -> Result<()> {
Ok(())
}
+/// DataFusion's `expr_fn` API makes it easy to create [`Expr`]s for the
+/// full range of expression types such as aggregates and window functions.
+fn expr_fn_demo() -> Result<()> {
+ // Let's say you want to call the "first_value" aggregate function
+ let first_value = first_value_udaf();
+
+ // For example, to create the expression `FIRST_VALUE(price)`
+ // These expressions can be passed to `DataFrame::aggregate` and other
+ // APIs that take aggregate expressions.
+ let agg = first_value.call(vec![col("price")]);
+ assert_eq!(agg.to_string(), "first_value(price)");
+
+ // You can use the AggregateExt trait to create more complex aggregates
+ // such as `FIRST_VALUE(price ORDER BY ts DESC NULLS LAST) FILTER (WHERE quantity > 100)`
+ let agg = first_value
+ .call(vec![col("price")])
+ .order_by(vec![col("ts").sort(false, false)])
+ .filter(col("quantity").gt(lit(100)))
+ .build()?; // build the aggregate
+ assert_eq!(
+ agg.to_string(),
+ "first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts
DESC NULLS LAST]"
+ );
+
+ Ok(())
+}
+
/// DataFusion can also evaluate arbitrary expressions on Arrow arrays.
fn evaluate_demo() -> Result<()> {
// For example, let's say you have some integers in an array
diff --git a/datafusion/core/tests/expr_api/mod.rs
b/datafusion/core/tests/expr_api/mod.rs
index 1db5aa9f23..7085333bee 100644
--- a/datafusion/core/tests/expr_api/mod.rs
+++ b/datafusion/core/tests/expr_api/mod.rs
@@ -15,14 +15,18 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::util::pretty::pretty_format_columns;
+use arrow::util::pretty::{pretty_format_batches, pretty_format_columns};
use arrow_array::builder::{ListBuilder, StringBuilder};
-use arrow_array::{ArrayRef, RecordBatch, StringArray, StructArray};
+use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray};
use arrow_schema::{DataType, Field};
use datafusion::prelude::*;
-use datafusion_common::{DFSchema, ScalarValue};
+use datafusion_common::{assert_contains, DFSchema, ScalarValue};
+use datafusion_expr::AggregateExt;
use datafusion_functions::core::expr_ext::FieldAccessor;
+use datafusion_functions_aggregate::first_last::first_value_udaf;
+use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_functions_array::expr_ext::{IndexAccessor, SliceAccessor};
+use sqlparser::ast::NullTreatment;
/// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan
use std::sync::{Arc, OnceLock};
@@ -162,6 +166,183 @@ fn test_list_range() {
);
}
+#[tokio::test]
+async fn test_aggregate_error() {
+ let err = first_value_udaf()
+ .call(vec![col("props")])
+ // not a sort column
+ .order_by(vec![col("id")])
+ .build()
+ .unwrap_err()
+ .to_string();
+ assert_contains!(
+ err,
+ "Error during planning: ORDER BY expressions must be Expr::Sort"
+ );
+}
+
+#[tokio::test]
+async fn test_aggregate_ext_order_by() {
+ let agg = first_value_udaf().call(vec![col("props")]);
+
+ // ORDER BY id ASC
+ let agg_asc = agg
+ .clone()
+ .order_by(vec![col("id").sort(true, true)])
+ .build()
+ .unwrap()
+ .alias("asc");
+
+ // ORDER BY id DESC
+ let agg_desc = agg
+ .order_by(vec![col("id").sort(false, true)])
+ .build()
+ .unwrap()
+ .alias("desc");
+
+ evaluate_agg_test(
+ agg_asc,
+ vec![
+ "+-----------------+",
+ "| asc |",
+ "+-----------------+",
+ "| {a: 2021-02-01} |",
+ "+-----------------+",
+ ],
+ )
+ .await;
+
+ evaluate_agg_test(
+ agg_desc,
+ vec![
+ "+-----------------+",
+ "| desc |",
+ "+-----------------+",
+ "| {a: 2021-02-03} |",
+ "+-----------------+",
+ ],
+ )
+ .await;
+}
+
+#[tokio::test]
+async fn test_aggregate_ext_filter() {
+ let agg = first_value_udaf()
+ .call(vec![col("i")])
+ .order_by(vec![col("i").sort(true, true)])
+ .filter(col("i").is_not_null())
+ .build()
+ .unwrap()
+ .alias("val");
+
+ #[rustfmt::skip]
+ evaluate_agg_test(
+ agg,
+ vec![
+ "+-----+",
+ "| val |",
+ "+-----+",
+ "| 5 |",
+ "+-----+",
+ ],
+ )
+ .await;
+}
+
+#[tokio::test]
+async fn test_aggregate_ext_distinct() {
+ let agg = sum_udaf()
+ .call(vec![lit(5)])
+ // distinct sum should be 5, not 15
+ .distinct()
+ .build()
+ .unwrap()
+ .alias("distinct");
+
+ evaluate_agg_test(
+ agg,
+ vec![
+ "+----------+",
+ "| distinct |",
+ "+----------+",
+ "| 5 |",
+ "+----------+",
+ ],
+ )
+ .await;
+}
+
+#[tokio::test]
+async fn test_aggregate_ext_null_treatment() {
+ let agg = first_value_udaf()
+ .call(vec![col("i")])
+ .order_by(vec![col("i").sort(true, true)]);
+
+ let agg_respect = agg
+ .clone()
+ .null_treatment(NullTreatment::RespectNulls)
+ .build()
+ .unwrap()
+ .alias("respect");
+
+ let agg_ignore = agg
+ .null_treatment(NullTreatment::IgnoreNulls)
+ .build()
+ .unwrap()
+ .alias("ignore");
+
+ evaluate_agg_test(
+ agg_respect,
+ vec![
+ "+---------+",
+ "| respect |",
+ "+---------+",
+ "| |",
+ "+---------+",
+ ],
+ )
+ .await;
+
+ evaluate_agg_test(
+ agg_ignore,
+ vec![
+ "+--------+",
+ "| ignore |",
+ "+--------+",
+ "| 5 |",
+ "+--------+",
+ ],
+ )
+ .await;
+}
+
+/// Evaluates the specified expr as an aggregate and compares the result to the
+/// expected result.
+async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) {
+ let batch = test_batch();
+
+ let ctx = SessionContext::new();
+ let group_expr = vec![];
+ let agg_expr = vec![expr];
+ let result = ctx
+ .read_batch(batch)
+ .unwrap()
+ .aggregate(group_expr, agg_expr)
+ .unwrap()
+ .collect()
+ .await
+ .unwrap();
+
+ let result = pretty_format_batches(&result).unwrap().to_string();
+ let actual_lines = result.lines().collect::<Vec<_>>();
+
+ assert_eq!(
+ expected_lines, actual_lines,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected_lines, actual_lines
+ );
+}
+
/// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided
/// `RecordBatch` and compares the result to the expected result.
fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) {
@@ -189,6 +370,8 @@ fn test_batch() -> RecordBatch {
TEST_BATCH
.get_or_init(|| {
let string_array: ArrayRef = Arc::new(StringArray::from(vec!["1",
"2", "3"]));
+ let int_array: ArrayRef =
+ Arc::new(Int64Array::from_iter(vec![Some(10), None, Some(5)]));
// { a: "2021-02-01" } { a: "2021-02-02" } { a: "2021-02-03" }
let struct_array: ArrayRef = Arc::from(StructArray::from(vec![(
@@ -209,6 +392,7 @@ fn test_batch() -> RecordBatch {
RecordBatch::try_from_iter(vec![
("id", string_array),
+ ("i", int_array),
("props", struct_array),
("list", list_array),
])
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index fe58b2f90a..98ab8ec251 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -255,19 +255,23 @@ pub enum Expr {
/// can be used. The first form consists of a series of boolean "when"
expressions with
/// corresponding "then" expressions, and an optional "else" expression.
///
+ /// ```text
/// CASE WHEN condition THEN result
/// [WHEN ...]
/// [ELSE result]
/// END
+ /// ```
///
/// The second form uses a base expression and then a series of "when"
clauses that match on a
/// literal value.
///
+ /// ```text
/// CASE expression
/// WHEN value THEN result
/// [WHEN ...]
/// [ELSE result]
/// END
+ /// ```
Case(Case),
/// Casts the expression to a given type and will return a runtime error
if the expression cannot be cast.
/// This expression is guaranteed to have a fixed type.
@@ -279,7 +283,12 @@ pub enum Expr {
Sort(Sort),
/// Represents the call of a scalar function with a set of arguments.
ScalarFunction(ScalarFunction),
- /// Represents the call of an aggregate built-in function with arguments.
+ /// Calls an aggregate function with arguments, and optional
+ /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`.
+ ///
+ /// See also [`AggregateExt`] to set these fields.
+ ///
+ /// [`AggregateExt`]: crate::udaf::AggregateExt
AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
WindowFunction(WindowFunction),
@@ -623,6 +632,10 @@ impl AggregateFunctionDefinition {
}
/// Aggregate function
+///
+/// See also [`AggregateExt`] to set these fields on `Expr`
+///
+/// [`AggregateExt`]: crate::udaf::AggregateExt
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct AggregateFunction {
/// Name of the function
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 379e00fa92..89ee94f9f8 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -82,7 +82,7 @@ pub use signature::{
ArrayFunctionSignature, Signature, TypeSignature, Volatility,
TIMEZONE_WILDCARD,
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
-pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
+pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF};
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::{WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index d778203207..a248518c2d 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -17,6 +17,7 @@
//! [`AggregateUDF`]: User Defined Aggregate Functions
+use crate::expr::AggregateFunction;
use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
};
@@ -26,7 +27,8 @@ use crate::utils::AggregateOrderSensitivity;
use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
-use datafusion_common::{exec_err, not_impl_err, Result};
+use datafusion_common::{exec_err, not_impl_err, plan_err, Result};
+use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
@@ -139,8 +141,7 @@ impl AggregateUDF {
/// This utility allows using the UDAF without requiring access to
/// the registry, such as with the DataFrame API.
pub fn call(&self, args: Vec<Expr>) -> Expr {
- // TODO: Support dictinct, filter, order by and null_treatment
- Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf(
+ Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(self.clone()),
args,
false,
@@ -606,3 +607,177 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
(self.accumulator)(acc_args)
}
}
+
+/// Extensions for configuring [`Expr::AggregateFunction`]
+///
+/// Adds methods to [`Expr`] that make it easy to set optional aggregate
options
+/// such as `ORDER BY`, `FILTER` and `DISTINCT`
+///
+/// # Example
+/// ```no_run
+/// # use datafusion_common::Result;
+/// # use datafusion_expr::{AggregateUDF, col, Expr, lit};
+/// # use sqlparser::ast::NullTreatment;
+/// # fn count(arg: Expr) -> Expr { todo!{} }
+/// # fn first_value(arg: Expr) -> Expr { todo!{} }
+/// # fn main() -> Result<()> {
+/// use datafusion_expr::AggregateExt;
+///
+/// // Create COUNT(x FILTER y > 5)
+/// let agg = count(col("x"))
+/// .filter(col("y").gt(lit(5)))
+/// .build()?;
+/// // Create FIRST_VALUE(x ORDER BY y IGNORE NULLS)
+/// let sort_expr = col("y").sort(true, true);
+/// let agg = first_value(col("x"))
+/// .order_by(vec![sort_expr])
+/// .null_treatment(NullTreatment::IgnoreNulls)
+/// .build()?;
+/// # Ok(())
+/// # }
+/// ```
+pub trait AggregateExt {
+ /// Add `ORDER BY <order_by>`
+ ///
+ /// Note: `order_by` must be [`Expr::Sort`]
+ fn order_by(self, order_by: Vec<Expr>) -> AggregateBuilder;
+ /// Add `FILTER <filter>`
+ fn filter(self, filter: Expr) -> AggregateBuilder;
+ /// Add `DISTINCT`
+ fn distinct(self) -> AggregateBuilder;
+ /// Add `RESPECT NULLS` or `IGNORE NULLS`
+ fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder;
+}
+
+/// Implementation of [`AggregateExt`].
+///
+/// See [`AggregateExt`] for usage and examples
+#[derive(Debug, Clone)]
+pub struct AggregateBuilder {
+ udaf: Option<AggregateFunction>,
+ order_by: Option<Vec<Expr>>,
+ filter: Option<Expr>,
+ distinct: bool,
+ null_treatment: Option<NullTreatment>,
+}
+
+impl AggregateBuilder {
+ /// Create a new `AggregateBuilder`, see [`AggregateExt`]
+
+ fn new(udaf: Option<AggregateFunction>) -> Self {
+ Self {
+ udaf,
+ order_by: None,
+ filter: None,
+ distinct: false,
+ null_treatment: None,
+ }
+ }
+
+ /// Updates and returns the in progress [`Expr::AggregateFunction`]
+ ///
+ /// # Errors:
+ ///
+ /// Returns an error if [`AggregateExt`] was used on an `Expr` variant other
+ /// than [`Expr::AggregateFunction`], or if an `order_by` entry is not [`Expr::Sort`]
+ pub fn build(self) -> Result<Expr> {
+ let Self {
+ udaf,
+ order_by,
+ filter,
+ distinct,
+ null_treatment,
+ } = self;
+
+ let Some(mut udaf) = udaf else {
+ return plan_err!(
+ "AggregateExt can only be used with Expr::AggregateFunction"
+ );
+ };
+
+ if let Some(order_by) = &order_by {
+ for expr in order_by.iter() {
+ if !matches!(expr, Expr::Sort(_)) {
+ return plan_err!(
+ "ORDER BY expressions must be Expr::Sort, found
{expr:?}"
+ );
+ }
+ }
+ }
+
+ udaf.order_by = order_by;
+ udaf.filter = filter.map(Box::new);
+ udaf.distinct = distinct;
+ udaf.null_treatment = null_treatment;
+ Ok(Expr::AggregateFunction(udaf))
+ }
+
+ /// Add `ORDER BY <order_by>`
+ ///
+ /// Note: `order_by` must be [`Expr::Sort`]
+ pub fn order_by(mut self, order_by: Vec<Expr>) -> AggregateBuilder {
+ self.order_by = Some(order_by);
+ self
+ }
+
+ /// Add `FILTER <filter>`
+ pub fn filter(mut self, filter: Expr) -> AggregateBuilder {
+ self.filter = Some(filter);
+ self
+ }
+
+ /// Add `DISTINCT`
+ pub fn distinct(mut self) -> AggregateBuilder {
+ self.distinct = true;
+ self
+ }
+
+ /// Add `RESPECT NULLS` or `IGNORE NULLS`
+ pub fn null_treatment(mut self, null_treatment: NullTreatment) ->
AggregateBuilder {
+ self.null_treatment = Some(null_treatment);
+ self
+ }
+}
+
+impl AggregateExt for Expr {
+ fn order_by(self, order_by: Vec<Expr>) -> AggregateBuilder {
+ match self {
+ Expr::AggregateFunction(udaf) => {
+ let mut builder = AggregateBuilder::new(Some(udaf));
+ builder.order_by = Some(order_by);
+ builder
+ }
+ _ => AggregateBuilder::new(None),
+ }
+ }
+ fn filter(self, filter: Expr) -> AggregateBuilder {
+ match self {
+ Expr::AggregateFunction(udaf) => {
+ let mut builder = AggregateBuilder::new(Some(udaf));
+ builder.filter = Some(filter);
+ builder
+ }
+ _ => AggregateBuilder::new(None),
+ }
+ }
+ fn distinct(self) -> AggregateBuilder {
+ match self {
+ Expr::AggregateFunction(udaf) => {
+ let mut builder = AggregateBuilder::new(Some(udaf));
+ builder.distinct = true;
+ builder
+ }
+ _ => AggregateBuilder::new(None),
+ }
+ }
+ fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder
{
+ match self {
+ Expr::AggregateFunction(udaf) => {
+ let mut builder = AggregateBuilder::new(Some(udaf));
+ builder.null_treatment = Some(null_treatment);
+ builder
+ }
+ _ => AggregateBuilder::new(None),
+ }
+ }
+}
diff --git a/datafusion/functions-aggregate/src/first_last.rs
b/datafusion/functions-aggregate/src/first_last.rs
index 435d277473..dd38e34872 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -31,20 +31,29 @@ use datafusion_common::{
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
use datafusion_expr::{
- Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Signature,
TypeSignature,
- Volatility,
+ Accumulator, AggregateExt, AggregateUDFImpl, ArrayFunctionSignature, Expr,
Signature,
+ TypeSignature, Volatility,
};
use datafusion_physical_expr_common::aggregate::utils::get_sort_options;
use datafusion_physical_expr_common::sort_expr::{
limited_convert_logical_sort_exprs_to_physical, LexOrdering,
PhysicalSortExpr,
};
-make_udaf_expr_and_func!(
- FirstValue,
- first_value,
- "Returns the first value in a group of values.",
- first_value_udaf
-);
+create_func!(FirstValue, first_value_udaf);
+
+/// Returns the first value in a group of values.
+pub fn first_value(expression: Expr, order_by: Option<Vec<Expr>>) -> Expr {
+ if let Some(order_by) = order_by {
+ first_value_udaf()
+ .call(vec![expression])
+ .order_by(order_by)
+ .build()
+ // guaranteed to be `Expr::AggregateFunction`
+ .unwrap()
+ } else {
+ first_value_udaf().call(vec![expression])
+ }
+}
pub struct FirstValue {
signature: Signature,
diff --git a/datafusion/functions-aggregate/src/macros.rs
b/datafusion/functions-aggregate/src/macros.rs
index 6c3348d6c1..75bb9dc547 100644
--- a/datafusion/functions-aggregate/src/macros.rs
+++ b/datafusion/functions-aggregate/src/macros.rs
@@ -48,24 +48,7 @@ macro_rules! make_udaf_expr_and_func {
None,
))
}
- create_func!($UDAF, $AGGREGATE_UDF_FN);
- };
- ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $distinct:ident, $DOC:expr,
$AGGREGATE_UDF_FN:ident) => {
- // "fluent expr_fn" style function
- #[doc = $DOC]
- pub fn $EXPR_FN(
- $($arg: datafusion_expr::Expr,)*
- distinct: bool,
- ) -> datafusion_expr::Expr {
-
datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
- $AGGREGATE_UDF_FN(),
- vec![$($arg),*],
- distinct,
- None,
- None,
- None
- ))
- }
+
create_func!($UDAF, $AGGREGATE_UDF_FN);
};
($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
@@ -73,20 +56,17 @@ macro_rules! make_udaf_expr_and_func {
#[doc = $DOC]
pub fn $EXPR_FN(
args: Vec<datafusion_expr::Expr>,
- distinct: bool,
- filter: Option<Box<datafusion_expr::Expr>>,
- order_by: Option<Vec<datafusion_expr::Expr>>,
- null_treatment: Option<sqlparser::ast::NullTreatment>
) -> datafusion_expr::Expr {
datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
$AGGREGATE_UDF_FN(),
args,
- distinct,
- filter,
- order_by,
- null_treatment,
+ false,
+ None,
+ None,
+ None,
))
}
+
create_func!($UDAF, $AGGREGATE_UDF_FN);
};
}
diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs
b/datafusion/optimizer/src/replace_distinct_aggregate.rs
index 752e2b2007..b32a886353 100644
--- a/datafusion/optimizer/src/replace_distinct_aggregate.rs
+++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs
@@ -21,10 +21,9 @@ use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, Column, Result};
-use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::expr_rewriter::normalize_cols;
use datafusion_expr::utils::expand_wildcard;
-use datafusion_expr::{col, LogicalPlanBuilder};
+use datafusion_expr::{col, AggregateExt, LogicalPlanBuilder};
use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
/// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]]
@@ -95,17 +94,19 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
let expr_cnt = on_expr.len();
// Construct the aggregation expression to be used to fetch
the selected expressions.
- let first_value_udaf =
+ let first_value_udaf:
std::sync::Arc<datafusion_expr::AggregateUDF> =
config.function_registry().unwrap().udaf("first_value")?;
let aggr_expr = select_expr.into_iter().map(|e| {
- Expr::AggregateFunction(AggregateFunction::new_udf(
- first_value_udaf.clone(),
- vec![e],
- false,
- None,
- sort_expr.clone(),
- None,
- ))
+ if let Some(order_by) = &sort_expr {
+ first_value_udaf
+ .call(vec![e])
+ .order_by(order_by.clone())
+ .build()
+ // guaranteed to be `Expr::AggregateFunction`
+ .unwrap()
+ } else {
+ first_value_udaf.call(vec![e])
+ }
});
let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index d0e0803372..a6889633d2 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -26,6 +26,8 @@ use arrow::datatypes::{
DataType, Field, Fields, Int32Type, IntervalDayTimeType,
IntervalMonthDayNanoType,
IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
};
+use prost::Message;
+
use datafusion::datasource::provider::TableProviderFactory;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionState;
@@ -64,8 +66,6 @@ use datafusion_proto::logical_plan::{
};
use datafusion_proto::protobuf;
-use prost::Message;
-
#[cfg(feature = "json")]
fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) {
let string = serde_json::to_string(proto).unwrap();
@@ -647,7 +647,8 @@ async fn roundtrip_expr_api() -> Result<()> {
lit(1),
),
array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2),
lit(4)),
- first_value(vec![lit(1)], false, None, None, None),
+ first_value(lit(1), None),
+ first_value(lit(1), Some(vec![lit(2).sort(true, true)])),
covar_samp(lit(1.5), lit(2.2)),
covar_pop(lit(1.5), lit(2.2)),
sum(lit(1)),
diff --git a/docs/source/user-guide/expressions.md
b/docs/source/user-guide/expressions.md
index a5fc134916..cae9627210 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -304,6 +304,16 @@ select log(-1), log(0), sqrt(-1);
| rollup(exprs) | Creates
a grouping set for rollup sets.
|
| sum(expr) |
Сalculates the sum of `expr`.
|
+## Aggregate Function Builder
+
+You can also use the `AggregateExt` trait to more easily build aggregate
function `Expr`s that carry optional clauses such as `ORDER BY`, `FILTER` and `DISTINCT`.
+
+See `datafusion-examples/examples/expr_api.rs` for example usage.
+
+| Syntax |
Equivalent to |
+| ------------------------------------------------------------------------- |
----------------------------------- |
+| first_value_udaf().call(vec![expr]).order_by(vec![expr]).build().unwrap() |
first_value(expr, Some(vec![expr])) |
+
## Subquery Expressions
| Syntax | Description
|
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]