This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 230c68c02b Add `simplify`  method to aggregate function (#10354)
230c68c02b is described below

commit 230c68c02bf0c3d5b7d50d24145eb50604420d4f
Author: Marko Milenković <[email protected]>
AuthorDate: Mon May 13 12:53:56 2024 +0100

    Add `simplify`  method to aggregate function (#10354)
    
    * add simplify method for aggregate function
    
    * simplify returns closure
---
 .../examples/simplify_udaf_expression.rs           | 180 +++++++++++++++++++++
 datafusion/expr/src/function.rs                    |  13 ++
 datafusion/expr/src/udaf.rs                        |  33 +++-
 .../src/simplify_expressions/expr_simplifier.rs    | 105 +++++++++++-
 4 files changed, 328 insertions(+), 3 deletions(-)

diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs
new file mode 100644
index 0000000000..92deb20272
--- /dev/null
+++ b/datafusion-examples/examples/simplify_udaf_expression.rs
@@ -0,0 +1,180 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_schema::{Field, Schema};
+use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
+use datafusion_expr::function::AggregateFunctionSimplification;
+use datafusion_expr::simplify::SimplifyInfo;
+
+use std::{any::Any, sync::Arc};
+
+use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
+use datafusion::error::Result;
+use datafusion::{assert_batches_eq, prelude::*};
+use datafusion_common::cast::as_float64_array;
+use datafusion_expr::{
+    expr::{AggregateFunction, AggregateFunctionDefinition},
+    function::AccumulatorArgs,
+    Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
+};
+
+// This example shows how to use the `AggregateUDFImpl::simplify` API to simplify/replace a user
+// defined aggregate function with a different expression, defined in the `simplify` method.
+
+#[derive(Debug, Clone)]
+struct BetterAvgUdaf {
+    signature: Signature,
+}
+
+impl BetterAvgUdaf {
+    /// Create a new instance of the BetterAvgUdaf struct, accepting a single Float64 argument
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for BetterAvgUdaf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "better_avg"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        unimplemented!("should not be invoked")
+    }
+
+    fn state_fields(
+        &self,
+        _name: &str,
+        _value_type: DataType,
+        _ordering_fields: Vec<Field>,
+    ) -> Result<Vec<Field>> {
+        unimplemented!("should not be invoked")
+    }
+
+    fn groups_accumulator_supported(&self) -> bool {
+        true
+    }
+
+    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+        unimplemented!("should not get here");
+    }
+    // Override `simplify` to return a closure that replaces the user defined
+    // function call with a different expression
+    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
+        // As an example of this functionality, the UDAF call is replaced with the
+        // built-in `Avg` aggregate, keeping args, distinct, filter, order_by and null_treatment
+        let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction,
+                        _: &dyn SimplifyInfo| {
+            Ok(Expr::AggregateFunction(AggregateFunction {
+                func_def: AggregateFunctionDefinition::BuiltIn(
+                    // yes it is the same Avg, `BetterAvgUdaf` was just a
+                    // marketing pitch :)
+                    datafusion_expr::aggregate_function::AggregateFunction::Avg,
+                ),
+                args: aggregate_function.args,
+                distinct: aggregate_function.distinct,
+                filter: aggregate_function.filter,
+                order_by: aggregate_function.order_by,
+                null_treatment: aggregate_function.null_treatment,
+            }))
+        };
+
+        Some(Box::new(simplify))
+    }
+}
+
+// Create a local session context with an in-memory table `t`
+fn create_context() -> Result<SessionContext> {
+    use datafusion::datasource::MemTable;
+    // define a schema with two non-nullable Float32 columns `a` and `b`
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Float32, false),
+        Field::new("b", DataType::Float32, false),
+    ]));
+
+    // define data in two partitions; `Arc::clone` only bumps the refcount
+    let batch1 = RecordBatch::try_new(
+        Arc::clone(&schema),
+        vec![
+            Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
+            Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
+        ],
+    )?;
+    let batch2 = RecordBatch::try_new(
+        Arc::clone(&schema),
+        vec![
+            Arc::new(Float32Array::from(vec![16.0])),
+            Arc::new(Float32Array::from(vec![2.0])),
+        ],
+    )?;
+
+    let ctx = SessionContext::new();
+
+    // Declare a table in memory. In the Spark API, this corresponds to createDataFrame(...).
+    let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
+    ctx.register_table("t", Arc::new(provider))?;
+    Ok(ctx)
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context()?;
+
+    let better_avg_udaf = AggregateUDF::from(BetterAvgUdaf::new());
+    ctx.register_udaf(better_avg_udaf.clone());
+
+    // SQL path: `better_avg` is replaced by the built-in Avg during simplification
+    let sql_df = ctx.sql("SELECT better_avg(a) FROM t group by b").await?;
+    let sql_batches = sql_df.collect().await?;
+
+    let expected = [
+        "+-----------------+",
+        "| better_avg(t.a) |",
+        "+-----------------+",
+        "| 7.5             |",
+        "+-----------------+",
+    ];
+    assert_batches_eq!(expected, &sql_batches);
+
+    // DataFrame path: call the UDAF directly through the DataFrame API
+    let table_df = ctx.table("t").await?;
+    let aggregate_expr = better_avg_udaf.call(vec![col("a")]);
+    let agg_df = table_df.aggregate(vec![], vec![aggregate_expr])?;
+
+    let agg_batches = agg_df.collect().await?;
+    let averages = as_float64_array(agg_batches[0].column(0))?;
+    let average = averages.value(0);
+
+    assert!((average - 7.5).abs() < f64::EPSILON);
+    println!("The average of [2,4,8,16] is {}", average);
+
+    Ok(())
+}
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 7a92a50ae1..4e4d77924a 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -97,3 +97,16 @@ pub type PartitionEvaluatorFactory =
 /// its state, given its return datatype.
 pub type StateTypeFunction =
     Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
+
+/// Simplifier closure returned by [crate::udaf::AggregateUDFImpl::simplify].
+/// The closure takes two arguments:
+/// * `aggregate_function`: the [crate::expr::AggregateFunction] for which simplify has been invoked
+/// * `info`: the [crate::simplify::SimplifyInfo] supplied by the caller
+///
+/// The closure returns the simplified [Expr], or an error.
+pub type AggregateFunctionSimplification = Box<
+    dyn Fn(
+        crate::expr::AggregateFunction,
+        &dyn crate::simplify::SimplifyInfo,
+    ) -> Result<Expr>,
+>;
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index e5a47ddcd8..95121d78e7 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -17,7 +17,7 @@
 
 //! [`AggregateUDF`]: User Defined Aggregate Functions
 
-use crate::function::AccumulatorArgs;
+use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
 use crate::groups_accumulator::GroupsAccumulator;
 use crate::utils::format_state_name;
 use crate::{Accumulator, Expr};
@@ -199,6 +199,12 @@ impl AggregateUDF {
     pub fn coerce_types(&self, _args: &[DataType]) -> Result<Vec<DataType>> {
         not_impl_err!("coerce_types not implemented for {:?} yet", self.name())
     }
+    /// Returns the simplification closure of the wrapped implementation, if any.
+    ///
+    /// See [`AggregateUDFImpl::simplify`] for more details.
+    pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
+        self.inner.simplify()
+    }
 }
 
 impl<F> From<F> for AggregateUDF
@@ -358,6 +364,31 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
     fn aliases(&self) -> &[String] {
         &[]
     }
+
+    /// Optionally apply per-UDAF simplification / rewrite rules.
+    ///
+    /// This can be used to apply function specific simplification rules during
+    /// optimization (e.g. replacing a UDAF call with a built-in aggregate
+    /// function). The default implementation does nothing.
+    ///
+    /// Note that DataFusion handles simplifying the function's arguments,
+    /// including "constant folding" of argument expressions (such as
+    /// `1 + 2 --> 3`). Thus, there is no need to implement such
+    /// optimizations manually for specific UDAFs.
+    ///
+    /// # Returns
+    ///
+    /// [None] if simplify is not defined, or a closure with two arguments
+    /// that performs the rewrite:
+    /// * `aggregate_function`: the [crate::expr::AggregateFunction] for which simplify has been invoked
+    /// * `info`: the [crate::simplify::SimplifyInfo] supplied by the caller
+    ///
+    /// The closure returns the simplified [Expr], or an error. The returned
+    /// expression is marked as transformed; prefer [None] over a closure that echoes its input.
+    ///
+    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
+        None
+    }
 }
 
 /// AggregateUDF that adds an alias to the underlying function. It is better to
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 5122de4f09..55052542a8 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -32,7 +32,7 @@ use datafusion_common::{
     tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
 };
 use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, 
ScalarValue};
-use datafusion_expr::expr::{InList, InSubquery};
+use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery};
 use datafusion_expr::simplify::ExprSimplifyResult;
 use datafusion_expr::{
     and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, 
Volatility,
@@ -1382,6 +1382,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 }
             }
 
+            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
+                func_def: AggregateFunctionDefinition::UDF(ref udaf),
+                ..
+            }) => match (udaf.simplify(), expr) {
+                (Some(simplify_function), Expr::AggregateFunction(af)) => {
+                    Transformed::yes(simplify_function(af, info)?)
+                }
+                (_, expr) => Transformed::no(expr),
+            },
+
             //
             // Rules for Between
             //
@@ -1748,7 +1758,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
 #[cfg(test)]
 mod tests {
     use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
-    use datafusion_expr::{interval_arithmetic::Interval, *};
+    use datafusion_expr::{
+        function::AggregateFunctionSimplification, 
interval_arithmetic::Interval, *,
+    };
     use std::{
         collections::HashMap,
         ops::{BitAnd, BitOr, BitXor},
@@ -3698,4 +3710,93 @@ mod tests {
         assert_eq!(expr, expected);
         assert_eq!(num_iter, 2);
     }
+    #[test]
+    fn test_simplify_udaf() {
+        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());
+        let aggregate_function_expr =
+            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
+                udaf.into(),
+                vec![], // args
+                false,  // distinct
+                None,   // filter
+                None,   // order_by
+                None,   // null_treatment
+            ));
+
+        let expected = col("result_column");
+        assert_eq!(simplify(aggregate_function_expr), expected);
+
+        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify());
+        let aggregate_function_expr =
+            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
+                udaf.into(),
+                vec![], // args
+                false,  // distinct
+                None,   // filter
+                None,   // order_by
+                None,   // null_treatment
+            ));
+
+        let expected = aggregate_function_expr.clone();
+        assert_eq!(simplify(aggregate_function_expr), expected);
+    }
+
+    /// Test-only UDAF whose `simplify` behaviour is switchable, used to
+    /// exercise the UDAF branch of the expression simplifier
+    #[derive(Debug, Clone)]
+    struct SimplifyMockUdaf {
+        simplify: bool,
+    }
+
+    impl SimplifyMockUdaf {
+        /// Mock whose `simplify` returns a rewrite closure
+        fn new_with_simplify() -> Self {
+            SimplifyMockUdaf { simplify: true }
+        }
+        /// Mock whose `simplify` returns `None` (no rewrite)
+        fn new_without_simplify() -> Self {
+            SimplifyMockUdaf { simplify: false }
+        }
+    }
+
+    impl AggregateUDFImpl for SimplifyMockUdaf {
+        fn as_any(&self) -> &dyn std::any::Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "mock_simplify"
+        }
+
+        fn signature(&self) -> &Signature {
+            unimplemented!()
+        }
+
+        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+            unimplemented!("not needed for tests")
+        }
+
+        fn accumulator(
+            &self,
+            _acc_args: function::AccumulatorArgs,
+        ) -> Result<Box<dyn Accumulator>> {
+            unimplemented!("not needed for tests")
+        }
+
+        fn groups_accumulator_supported(&self) -> bool {
+            unimplemented!("not needed for testing")
+        }
+
+        fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+            unimplemented!("not needed for testing")
+        }
+
+        fn simplify(&self) -> Option<AggregateFunctionSimplification> {
+            if self.simplify {
+                Some(Box::new(|_, _| Ok(col("result_column"))))
+            } else {
+                None
+            }
+        }
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to