This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 24a08465e1 Introduce expr builder for aggregate function (#10560)
24a08465e1 is described below

commit 24a08465e12bc07275cafe5310a7ac44898e39de
Author: Jay Zhan <[email protected]>
AuthorDate: Sun Jun 9 13:49:17 2024 +0800

    Introduce expr builder for aggregate function (#10560)
    
    * expr builder
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * build
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * upd user-guide
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix builder
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * Consolidate example in udaf_expr.rs, simplify filter API
    
    * Add doc strings and examples
    
    * Add tests and checks
    
    * Improve documentation more
    
    * fixup
    
    * rm space
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion-examples/examples/expr_api.rs           |  44 ++++-
 datafusion/core/tests/expr_api/mod.rs              | 190 ++++++++++++++++++++-
 datafusion/expr/src/expr.rs                        |  15 +-
 datafusion/expr/src/lib.rs                         |   2 +-
 datafusion/expr/src/udaf.rs                        | 181 +++++++++++++++++++-
 datafusion/functions-aggregate/src/first_last.rs   |  25 ++-
 datafusion/functions-aggregate/src/macros.rs       |  32 +---
 .../optimizer/src/replace_distinct_aggregate.rs    |  23 +--
 .../proto/tests/cases/roundtrip_logical_plan.rs    |   7 +-
 docs/source/user-guide/expressions.md              |  10 ++
 10 files changed, 467 insertions(+), 62 deletions(-)

diff --git a/datafusion-examples/examples/expr_api.rs 
b/datafusion-examples/examples/expr_api.rs
index 0082ed6eb9..591f6ac3de 100644
--- a/datafusion-examples/examples/expr_api.rs
+++ b/datafusion-examples/examples/expr_api.rs
@@ -24,6 +24,7 @@ use arrow::record_batch::RecordBatch;
 use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
 use datafusion::common::DFSchema;
 use datafusion::error::Result;
+use datafusion::functions_aggregate::first_last::first_value_udaf;
 use datafusion::optimizer::simplify_expressions::ExprSimplifier;
 use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries};
 use datafusion::prelude::*;
@@ -32,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps;
 use datafusion_expr::expr::BinaryExpr;
 use datafusion_expr::interval_arithmetic::Interval;
 use datafusion_expr::simplify::SimplifyContext;
-use datafusion_expr::{ColumnarValue, ExprSchemable, Operator};
+use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator};
 
 /// This example demonstrates the DataFusion [`Expr`] API.
 ///
@@ -44,11 +45,12 @@ use datafusion_expr::{ColumnarValue, ExprSchemable, 
Operator};
 /// also comes with APIs for evaluation, simplification, and analysis.
 ///
 /// The code in this example shows how to:
-/// 1. Create [`Exprs`] using different APIs: [`main`]`
-/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`]
-/// 3. Simplify expressions: [`simplify_demo`]
-/// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`]
-/// 5. Get the types of the expressions: [`expression_type_demo`]
+/// 1. Create [`Expr`]s using different APIs: [`main`]
+/// 2. Use the fluent API to easily create complex [`Expr`]s: [`expr_fn_demo`]
+/// 3. Evaluate [`Expr`]s against data: [`evaluate_demo`]
+/// 4. Simplify expressions: [`simplify_demo`]
+/// 5. Analyze predicates for boundary ranges: [`range_analysis_demo`]
+/// 6. Get the types of the expressions: [`expression_type_demo`]
 #[tokio::main]
 async fn main() -> Result<()> {
     // The easiest way to do create expressions is to use the
@@ -63,6 +65,9 @@ async fn main() -> Result<()> {
     ));
     assert_eq!(expr, expr2);
 
+    // See how to build aggregate functions with the expr_fn API
+    expr_fn_demo()?;
+
     // See how to evaluate expressions
     evaluate_demo()?;
 
@@ -78,6 +83,33 @@ async fn main() -> Result<()> {
     Ok(())
 }
 
+/// DataFusion's `expr_fn` API makes it easy to create [`Expr`]s for the
+/// full range of expression types such as aggregates and window functions.
+fn expr_fn_demo() -> Result<()> {
+    // Let's say you want to call the "first_value" aggregate function
+    let first_value = first_value_udaf();
+
+    // For example, to create the expression `FIRST_VALUE(price)`
+    // These expressions can be passed to `DataFrame::aggregate` and other
+    // APIs that take aggregate expressions.
+    let agg = first_value.call(vec![col("price")]);
+    assert_eq!(agg.to_string(), "first_value(price)");
+
+    // The AggregateExt trait adds builder methods for more complex aggregates,
+    // such as `FIRST_VALUE(price) FILTER (WHERE quantity > 100)` with `ORDER BY ts DESC NULLS LAST`
+    let agg = first_value
+        .call(vec![col("price")])
+        .order_by(vec![col("ts").sort(false, false)])
+        .filter(col("quantity").gt(lit(100)))
+        .build()?; // build the aggregate
+    assert_eq!(
+        agg.to_string(),
+        "first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts 
DESC NULLS LAST]"
+    );
+
+    Ok(())
+}
+
 /// DataFusion can also evaluate arbitrary expressions on Arrow arrays.
 fn evaluate_demo() -> Result<()> {
     // For example, let's say you have some integers in an array
diff --git a/datafusion/core/tests/expr_api/mod.rs 
b/datafusion/core/tests/expr_api/mod.rs
index 1db5aa9f23..7085333bee 100644
--- a/datafusion/core/tests/expr_api/mod.rs
+++ b/datafusion/core/tests/expr_api/mod.rs
@@ -15,14 +15,18 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::util::pretty::pretty_format_columns;
+use arrow::util::pretty::{pretty_format_batches, pretty_format_columns};
 use arrow_array::builder::{ListBuilder, StringBuilder};
-use arrow_array::{ArrayRef, RecordBatch, StringArray, StructArray};
+use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray};
 use arrow_schema::{DataType, Field};
 use datafusion::prelude::*;
-use datafusion_common::{DFSchema, ScalarValue};
+use datafusion_common::{assert_contains, DFSchema, ScalarValue};
+use datafusion_expr::AggregateExt;
 use datafusion_functions::core::expr_ext::FieldAccessor;
+use datafusion_functions_aggregate::first_last::first_value_udaf;
+use datafusion_functions_aggregate::sum::sum_udaf;
 use datafusion_functions_array::expr_ext::{IndexAccessor, SliceAccessor};
+use sqlparser::ast::NullTreatment;
 /// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan
 use std::sync::{Arc, OnceLock};
 
@@ -162,6 +166,183 @@ fn test_list_range() {
     );
 }
 
+#[tokio::test]
+async fn test_aggregate_error() {
+    let err = first_value_udaf()
+        .call(vec![col("props")])
+        // `col("id")` is not wrapped in `Expr::Sort`, so `build` must fail
+        .order_by(vec![col("id")])
+        .build()
+        .unwrap_err()
+        .to_string();
+    assert_contains!(
+        err,
+        "Error during planning: ORDER BY expressions must be Expr::Sort"
+    );
+}
+
+#[tokio::test]
+async fn test_aggregate_ext_order_by() {
+    let agg = first_value_udaf().call(vec![col("props")]);
+
+    // ORDER BY id ASC NULLS FIRST: the row with id "1" comes first
+    let agg_asc = agg
+        .clone()
+        .order_by(vec![col("id").sort(true, true)])
+        .build()
+        .unwrap()
+        .alias("asc");
+
+    // ORDER BY id DESC NULLS FIRST: the row with id "3" comes first
+    let agg_desc = agg
+        .order_by(vec![col("id").sort(false, true)])
+        .build()
+        .unwrap()
+        .alias("desc");
+
+    evaluate_agg_test(
+        agg_asc,
+        vec![
+            "+-----------------+",
+            "| asc             |",
+            "+-----------------+",
+            "| {a: 2021-02-01} |",
+            "+-----------------+",
+        ],
+    )
+    .await;
+
+    evaluate_agg_test(
+        agg_desc,
+        vec![
+            "+-----------------+",
+            "| desc            |",
+            "+-----------------+",
+            "| {a: 2021-02-03} |",
+            "+-----------------+",
+        ],
+    )
+    .await;
+}
+
+#[tokio::test]
+async fn test_aggregate_ext_filter() {
+    let agg = first_value_udaf()
+        .call(vec![col("i")])
+        .order_by(vec![col("i").sort(true, true)]) // ASC NULLS FIRST: the NULL would sort first
+        .filter(col("i").is_not_null()) // drop the NULL row, leaving 10 and 5
+        .build()
+        .unwrap()
+        .alias("val");
+
+    #[rustfmt::skip]
+    evaluate_agg_test(
+        agg,
+        vec![
+            "+-----+",
+            "| val |",
+            "+-----+",
+            "| 5   |",
+            "+-----+",
+        ],
+    )
+        .await;
+}
+
+#[tokio::test]
+async fn test_aggregate_ext_distinct() {
+    let agg = sum_udaf()
+        .call(vec![lit(5)])
+        // the batch has 3 rows: SUM(5) would be 15, SUM(DISTINCT 5) is 5
+        .distinct()
+        .build()
+        .unwrap()
+        .alias("distinct");
+
+    evaluate_agg_test(
+        agg,
+        vec![
+            "+----------+",
+            "| distinct |",
+            "+----------+",
+            "| 5        |",
+            "+----------+",
+        ],
+    )
+    .await;
+}
+
+#[tokio::test]
+async fn test_aggregate_ext_null_treatment() {
+    let agg = first_value_udaf()
+        .call(vec![col("i")])
+        .order_by(vec![col("i").sort(true, true)]); // ASC NULLS FIRST: i is [10, NULL, 5], so NULL sorts first
+
+    let agg_respect = agg
+        .clone()
+        .null_treatment(NullTreatment::RespectNulls) // keep the leading NULL
+        .build()
+        .unwrap()
+        .alias("respect");
+
+    let agg_ignore = agg
+        .null_treatment(NullTreatment::IgnoreNulls) // skip the NULL, so 5 is first
+        .build()
+        .unwrap()
+        .alias("ignore");
+
+    evaluate_agg_test(
+        agg_respect,
+        vec![
+            "+---------+",
+            "| respect |",
+            "+---------+",
+            "|         |",
+            "+---------+",
+        ],
+    )
+    .await;
+
+    evaluate_agg_test(
+        agg_ignore,
+        vec![
+            "+--------+",
+            "| ignore |",
+            "+--------+",
+            "| 5      |",
+            "+--------+",
+        ],
+    )
+    .await;
+}
+
+/// Evaluates the specified expr as an aggregate (with no GROUP BY) over
+/// [`test_batch`] and asserts the pretty-printed result equals `expected_lines`.
+async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) {
+    let batch = test_batch();
+
+    let ctx = SessionContext::new();
+    let group_expr = vec![]; // no GROUP BY columns
+    let agg_expr = vec![expr];
+    let result = ctx
+        .read_batch(batch)
+        .unwrap()
+        .aggregate(group_expr, agg_expr)
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+
+    let result = pretty_format_batches(&result).unwrap().to_string();
+    let actual_lines = result.lines().collect::<Vec<_>>();
+
+    assert_eq!(
+        expected_lines, actual_lines,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected_lines, actual_lines
+    );
+}
+
 /// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided
 /// `RecordBatch` and compares the result to the expected result.
 fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) {
@@ -189,6 +370,8 @@ fn test_batch() -> RecordBatch {
     TEST_BATCH
         .get_or_init(|| {
             let string_array: ArrayRef = Arc::new(StringArray::from(vec!["1", 
"2", "3"]));
+            let int_array: ArrayRef =
+                Arc::new(Int64Array::from_iter(vec![Some(10), None, Some(5)]));
 
             // { a: "2021-02-01" } { a: "2021-02-02" } { a: "2021-02-03" }
             let struct_array: ArrayRef = Arc::from(StructArray::from(vec![(
@@ -209,6 +392,7 @@ fn test_batch() -> RecordBatch {
 
             RecordBatch::try_from_iter(vec![
                 ("id", string_array),
+                ("i", int_array),
                 ("props", struct_array),
                 ("list", list_array),
             ])
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index fe58b2f90a..98ab8ec251 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -255,19 +255,23 @@ pub enum Expr {
     /// can be used. The first form consists of a series of boolean "when" 
expressions with
     /// corresponding "then" expressions, and an optional "else" expression.
     ///
+    /// ```text
     /// CASE WHEN condition THEN result
     ///      [WHEN ...]
     ///      [ELSE result]
     /// END
+    /// ```
     ///
     /// The second form uses a base expression and then a series of "when" 
clauses that match on a
     /// literal value.
     ///
+    /// ```text
     /// CASE expression
     ///     WHEN value THEN result
     ///     [WHEN ...]
     ///     [ELSE result]
     /// END
+    /// ```
     Case(Case),
     /// Casts the expression to a given type and will return a runtime error 
if the expression cannot be cast.
     /// This expression is guaranteed to have a fixed type.
@@ -279,7 +283,12 @@ pub enum Expr {
     Sort(Sort),
     /// Represents the call of a scalar function with a set of arguments.
     ScalarFunction(ScalarFunction),
-    /// Represents the call of an aggregate built-in function with arguments.
+    /// Calls an aggregate function with arguments, and optional
+    /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`.
+    ///
+    /// See also [`AggregateExt`] to set these fields.
+    ///
+    /// [`AggregateExt`]: crate::udaf::AggregateExt
     AggregateFunction(AggregateFunction),
     /// Represents the call of a window function with arguments.
     WindowFunction(WindowFunction),
@@ -623,6 +632,10 @@ impl AggregateFunctionDefinition {
 }
 
 /// Aggregate function
+///
+/// See also [`AggregateExt`] to set these fields on `Expr`
+///
+/// [`AggregateExt`]: crate::udaf::AggregateExt
 #[derive(Clone, PartialEq, Eq, Hash, Debug)]
 pub struct AggregateFunction {
     /// Name of the function
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 379e00fa92..89ee94f9f8 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -82,7 +82,7 @@ pub use signature::{
     ArrayFunctionSignature, Signature, TypeSignature, Volatility, 
TIMEZONE_WILDCARD,
 };
 pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
-pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
+pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF};
 pub use udf::{ScalarUDF, ScalarUDFImpl};
 pub use udwf::{WindowUDF, WindowUDFImpl};
 pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index d778203207..a248518c2d 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -17,6 +17,7 @@
 
 //! [`AggregateUDF`]: User Defined Aggregate Functions
 
+use crate::expr::AggregateFunction;
 use crate::function::{
     AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
 };
@@ -26,7 +27,8 @@ use crate::utils::AggregateOrderSensitivity;
 use crate::{Accumulator, Expr};
 use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
 use arrow::datatypes::{DataType, Field};
-use datafusion_common::{exec_err, not_impl_err, Result};
+use datafusion_common::{exec_err, not_impl_err, plan_err, Result};
+use sqlparser::ast::NullTreatment;
 use std::any::Any;
 use std::fmt::{self, Debug, Formatter};
 use std::sync::Arc;
@@ -139,8 +141,7 @@ impl AggregateUDF {
     /// This utility allows using the UDAF without requiring access to
     /// the registry, such as with the DataFrame API.
     pub fn call(&self, args: Vec<Expr>) -> Expr {
-        // TODO: Support dictinct, filter, order by and null_treatment
-        Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf(
+        Expr::AggregateFunction(AggregateFunction::new_udf(
             Arc::new(self.clone()),
             args,
             false,
@@ -606,3 +607,177 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
         (self.accumulator)(acc_args)
     }
 }
+
+/// Extensions for configuring [`Expr::AggregateFunction`]
+///
+/// Adds methods to [`Expr`] that make it easy to set optional aggregate 
options
+/// such as `ORDER BY`, `FILTER`, `DISTINCT` and null treatment. Each method returns an [`AggregateBuilder`]; call `build()` to get the final `Expr`.
+///
+/// # Example
+/// ```no_run
+/// # use datafusion_common::Result;
+/// # use datafusion_expr::{AggregateUDF, col, Expr, lit};
+/// # use sqlparser::ast::NullTreatment;
+/// # fn count(arg: Expr) -> Expr { todo!{} }
+/// # fn first_value(arg: Expr) -> Expr { todo!{} }
+/// # fn main() -> Result<()> {
+/// use datafusion_expr::AggregateExt;
+///
+/// // Create COUNT(x) FILTER (WHERE y > 5)
+/// let agg = count(col("x"))
+///    .filter(col("y").gt(lit(5)))
+///    .build()?;
+/// // Create FIRST_VALUE(x ORDER BY y IGNORE NULLS)
+/// let sort_expr = col("y").sort(true, true);
+/// let agg = first_value(col("x"))
+///   .order_by(vec![sort_expr])
+///   .null_treatment(NullTreatment::IgnoreNulls)
+///   .build()?;
+/// # Ok(())
+/// # }
+/// ```
+pub trait AggregateExt {
+    /// Add `ORDER BY <order_by>`
+    ///
+    /// Note: `order_by` must be [`Expr::Sort`]
+    fn order_by(self, order_by: Vec<Expr>) -> AggregateBuilder;
+    /// Add `FILTER (WHERE <filter>)`
+    fn filter(self, filter: Expr) -> AggregateBuilder;
+    /// Add `DISTINCT`
+    fn distinct(self) -> AggregateBuilder;
+    /// Add `RESPECT NULLS` or `IGNORE NULLS`
+    fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder;
+}
+
+/// Builder returned by the [`AggregateExt`] methods; call [`Self::build`] to get the `Expr`.
+///
+/// See [`AggregateExt`] for usage and examples
+#[derive(Debug, Clone)]
+pub struct AggregateBuilder {
+    udaf: Option<AggregateFunction>, // `None` when the source `Expr` was not an aggregate
+    order_by: Option<Vec<Expr>>,
+    filter: Option<Expr>,
+    distinct: bool,
+    null_treatment: Option<NullTreatment>,
+}
+
+impl AggregateBuilder {
+    /// Create a new `AggregateBuilder` wrapping `udaf` (`None` for a non-aggregate `Expr`),
+    /// with no ORDER BY, FILTER, DISTINCT or null treatment set; see [`AggregateExt`]
+    fn new(udaf: Option<AggregateFunction>) -> Self {
+        Self {
+            udaf,
+            order_by: None,
+            filter: None,
+            distinct: false,
+            null_treatment: None,
+        }
+    }
+
+    /// Updates and returns the in-progress [`Expr::AggregateFunction`]. All four options are written back, so any not set on this builder become `None`/`false`.
+    ///
+    /// # Errors
+    ///
+    /// Returns a plan error if this builder ([`AggregateExt`]) was used with an
+    /// `Expr` variant other than [`Expr::AggregateFunction`], or if any `order_by` entry is not an [`Expr::Sort`]
+    pub fn build(self) -> Result<Expr> {
+        let Self {
+            udaf,
+            order_by,
+            filter,
+            distinct,
+            null_treatment,
+        } = self;
+
+        let Some(mut udaf) = udaf else { // the builder was created from a non-aggregate `Expr`
+            return plan_err!(
+                "AggregateExt can only be used with Expr::AggregateFunction"
+            );
+        };
+
+        if let Some(order_by) = &order_by { // every ORDER BY entry must be a sort expression
+            for expr in order_by.iter() {
+                if !matches!(expr, Expr::Sort(_)) {
+                    return plan_err!(
+                        "ORDER BY expressions must be Expr::Sort, found 
{expr:?}"
+                    );
+                }
+            }
+        }
+
+        udaf.order_by = order_by; // overwrites unconditionally (see doc above)
+        udaf.filter = filter.map(Box::new);
+        udaf.distinct = distinct;
+        udaf.null_treatment = null_treatment;
+        Ok(Expr::AggregateFunction(udaf))
+    }
+
+    /// Add `ORDER BY <order_by>`
+    ///
+    /// Note: `order_by` must be [`Expr::Sort`]
+    pub fn order_by(mut self, order_by: Vec<Expr>) -> AggregateBuilder {
+        self.order_by = Some(order_by);
+        self
+    }
+
+    /// Add `FILTER (WHERE <filter>)`
+    pub fn filter(mut self, filter: Expr) -> AggregateBuilder {
+        self.filter = Some(filter);
+        self
+    }
+
+    /// Add `DISTINCT`
+    pub fn distinct(mut self) -> AggregateBuilder {
+        self.distinct = true;
+        self
+    }
+
+    /// Add `RESPECT NULLS` or `IGNORE NULLS`
+    pub fn null_treatment(mut self, null_treatment: NullTreatment) -> 
AggregateBuilder {
+        self.null_treatment = Some(null_treatment);
+        self
+    }
+}
+
+impl AggregateExt for Expr { // each method starts an AggregateBuilder from `Expr::AggregateFunction`
+    fn order_by(self, order_by: Vec<Expr>) -> AggregateBuilder {
+        match self {
+            Expr::AggregateFunction(udaf) => {
+                let mut builder = AggregateBuilder::new(Some(udaf));
+                builder.order_by = Some(order_by);
+                builder
+            }
+            _ => AggregateBuilder::new(None), // not an aggregate: `build()` returns a plan error
+        }
+    }
+    fn filter(self, filter: Expr) -> AggregateBuilder {
+        match self {
+            Expr::AggregateFunction(udaf) => {
+                let mut builder = AggregateBuilder::new(Some(udaf));
+                builder.filter = Some(filter);
+                builder
+            }
+            _ => AggregateBuilder::new(None), // not an aggregate: `build()` returns a plan error
+        }
+    }
+    fn distinct(self) -> AggregateBuilder {
+        match self {
+            Expr::AggregateFunction(udaf) => {
+                let mut builder = AggregateBuilder::new(Some(udaf));
+                builder.distinct = true;
+                builder
+            }
+            _ => AggregateBuilder::new(None), // not an aggregate: `build()` returns a plan error
+        }
+    }
+    fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder 
{
+        match self {
+            Expr::AggregateFunction(udaf) => {
+                let mut builder = AggregateBuilder::new(Some(udaf));
+                builder.null_treatment = Some(null_treatment);
+                builder
+            }
+            _ => AggregateBuilder::new(None), // not an aggregate: `build()` returns a plan error
+        }
+    }
+}
diff --git a/datafusion/functions-aggregate/src/first_last.rs 
b/datafusion/functions-aggregate/src/first_last.rs
index 435d277473..dd38e34872 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -31,20 +31,29 @@ use datafusion_common::{
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
 use datafusion_expr::{
-    Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Signature, 
TypeSignature,
-    Volatility,
+    Accumulator, AggregateExt, AggregateUDFImpl, ArrayFunctionSignature, Expr, 
Signature,
+    TypeSignature, Volatility,
 };
 use datafusion_physical_expr_common::aggregate::utils::get_sort_options;
 use datafusion_physical_expr_common::sort_expr::{
     limited_convert_logical_sort_exprs_to_physical, LexOrdering, 
PhysicalSortExpr,
 };
 
-make_udaf_expr_and_func!(
-    FirstValue,
-    first_value,
-    "Returns the first value in a group of values.",
-    first_value_udaf
-);
+create_func!(FirstValue, first_value_udaf);
+
+/// Returns the first value in a group of values, ordered by `order_by` if given. Panics if any `order_by` entry is not an `Expr::Sort`.
+pub fn first_value(expression: Expr, order_by: Option<Vec<Expr>>) -> Expr {
+    if let Some(order_by) = order_by {
+        first_value_udaf()
+            .call(vec![expression])
+            .order_by(order_by)
+            .build()
+            // `call` yields `Expr::AggregateFunction`, but `build` still errs (and this panics) on a non-`Expr::Sort` entry
+            .unwrap()
+    } else {
+        first_value_udaf().call(vec![expression])
+    }
+}
 
 pub struct FirstValue {
     signature: Signature,
diff --git a/datafusion/functions-aggregate/src/macros.rs 
b/datafusion/functions-aggregate/src/macros.rs
index 6c3348d6c1..75bb9dc547 100644
--- a/datafusion/functions-aggregate/src/macros.rs
+++ b/datafusion/functions-aggregate/src/macros.rs
@@ -48,24 +48,7 @@ macro_rules! make_udaf_expr_and_func {
                 None,
             ))
         }
-        create_func!($UDAF, $AGGREGATE_UDF_FN);
-    };
-    ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $distinct:ident, $DOC:expr, 
$AGGREGATE_UDF_FN:ident) => {
-        // "fluent expr_fn" style function
-        #[doc = $DOC]
-        pub fn $EXPR_FN(
-            $($arg: datafusion_expr::Expr,)*
-            distinct: bool,
-        ) -> datafusion_expr::Expr {
-            
datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
-                $AGGREGATE_UDF_FN(),
-                vec![$($arg),*],
-                distinct,
-                None,
-                None,
-                None
-            ))
-        }
+
         create_func!($UDAF, $AGGREGATE_UDF_FN);
     };
     ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
@@ -73,20 +56,17 @@ macro_rules! make_udaf_expr_and_func {
         #[doc = $DOC]
         pub fn $EXPR_FN(
             args: Vec<datafusion_expr::Expr>,
-            distinct: bool,
-            filter: Option<Box<datafusion_expr::Expr>>,
-            order_by: Option<Vec<datafusion_expr::Expr>>,
-            null_treatment: Option<sqlparser::ast::NullTreatment>
         ) -> datafusion_expr::Expr {
             
datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
                 $AGGREGATE_UDF_FN(),
                 args,
-                distinct,
-                filter,
-                order_by,
-                null_treatment,
+                false,
+                None,
+                None,
+                None,
             ))
         }
+
         create_func!($UDAF, $AGGREGATE_UDF_FN);
     };
 }
diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs 
b/datafusion/optimizer/src/replace_distinct_aggregate.rs
index 752e2b2007..b32a886353 100644
--- a/datafusion/optimizer/src/replace_distinct_aggregate.rs
+++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs
@@ -21,10 +21,9 @@ use crate::{OptimizerConfig, OptimizerRule};
 
 use datafusion_common::tree_node::Transformed;
 use datafusion_common::{internal_err, Column, Result};
-use datafusion_expr::expr::AggregateFunction;
 use datafusion_expr::expr_rewriter::normalize_cols;
 use datafusion_expr::utils::expand_wildcard;
-use datafusion_expr::{col, LogicalPlanBuilder};
+use datafusion_expr::{col, AggregateExt, LogicalPlanBuilder};
 use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
 
 /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]]
@@ -95,17 +94,19 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
                 let expr_cnt = on_expr.len();
 
                 // Construct the aggregation expression to be used to fetch 
the selected expressions.
-                let first_value_udaf =
+                let first_value_udaf: 
std::sync::Arc<datafusion_expr::AggregateUDF> =
                     config.function_registry().unwrap().udaf("first_value")?;
                 let aggr_expr = select_expr.into_iter().map(|e| {
-                    Expr::AggregateFunction(AggregateFunction::new_udf(
-                        first_value_udaf.clone(),
-                        vec![e],
-                        false,
-                        None,
-                        sort_expr.clone(),
-                        None,
-                    ))
+                    if let Some(order_by) = &sort_expr {
+                        first_value_udaf
+                            .call(vec![e])
+                            .order_by(order_by.clone())
+                            .build()
+                            // guaranteed to be `Expr::AggregateFunction`
+                            .unwrap()
+                    } else {
+                        first_value_udaf.call(vec![e])
+                    }
                 });
 
                 let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index d0e0803372..a6889633d2 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -26,6 +26,8 @@ use arrow::datatypes::{
     DataType, Field, Fields, Int32Type, IntervalDayTimeType, 
IntervalMonthDayNanoType,
     IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
 };
+use prost::Message;
+
 use datafusion::datasource::provider::TableProviderFactory;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::context::SessionState;
@@ -64,8 +66,6 @@ use datafusion_proto::logical_plan::{
 };
 use datafusion_proto::protobuf;
 
-use prost::Message;
-
 #[cfg(feature = "json")]
 fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) {
     let string = serde_json::to_string(proto).unwrap();
@@ -647,7 +647,8 @@ async fn roundtrip_expr_api() -> Result<()> {
             lit(1),
         ),
         array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), 
lit(4)),
-        first_value(vec![lit(1)], false, None, None, None),
+        first_value(lit(1), None),
+        first_value(lit(1), Some(vec![lit(2).sort(true, true)])),
         covar_samp(lit(1.5), lit(2.2)),
         covar_pop(lit(1.5), lit(2.2)),
         sum(lit(1)),
diff --git a/docs/source/user-guide/expressions.md 
b/docs/source/user-guide/expressions.md
index a5fc134916..cae9627210 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -304,6 +304,16 @@ select log(-1), log(0), sqrt(-1);
 | rollup(exprs)                                                     | Creates 
a grouping set for rollup sets.                                                 
|
 | sum(expr)                                                         | 
Сalculates the sum of `expr`.                                                   
        |
 
+## Aggregate Function Builder
+
+You can also use the `AggregateExt` trait to more easily build aggregate `Expr`s with optional `ORDER BY`, `FILTER`, `DISTINCT` and null treatment.
+
+See `datafusion-examples/examples/expr_api.rs` for example usage.
+
+| Syntax                                                                    | Equivalent to                       |
+| ------------------------------------------------------------------------- | ----------------------------------- |
+| first_value_udaf().call(vec![expr]).order_by(vec![expr]).build().unwrap() | first_value(expr, Some(vec![expr])) |
+
 ## Subquery Expressions
 
 | Syntax          | Description                                                
                                   |


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to