This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 3d00760853 feature: Add a WindowUDFImpl::simplify() API (#9906)
3d00760853 is described below

commit 3d007608535cb138ae4473ce6305bd4ec8481627
Author: junxiangMu <[email protected]>
AuthorDate: Wed May 29 20:18:02 2024 -0400

    feature: Add a WindowUDFImpl::simplify() API (#9906)
    
    * feature: Add a WindowUDFImpl::simplfy() API
    
    Signed-off-by: guojidan <[email protected]>
    
    * fix doc
    
    Signed-off-by: guojidan <[email protected]>
    
    * fix fmt
    
    Signed-off-by: guojidan <[email protected]>
    
    ---------
    
    Signed-off-by: guojidan <[email protected]>
---
 .../examples/simplify_udwf_expression.rs           | 142 +++++++++++++++++++++
 datafusion/expr/src/function.rs                    |  13 ++
 datafusion/expr/src/udwf.rs                        |  34 ++++-
 .../src/simplify_expressions/expr_simplifier.rs    | 103 ++++++++++++++-
 4 files changed, 288 insertions(+), 4 deletions(-)

diff --git a/datafusion-examples/examples/simplify_udwf_expression.rs b/datafusion-examples/examples/simplify_udwf_expression.rs
new file mode 100644
index 0000000000..2824d03761
--- /dev/null
+++ b/datafusion-examples/examples/simplify_udwf_expression.rs
@@ -0,0 +1,142 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+
+use arrow_schema::DataType;
+use datafusion::execution::context::SessionContext;
+use datafusion::{error::Result, execution::options::CsvReadOptions};
+use datafusion_expr::function::WindowFunctionSimplification;
+use datafusion_expr::{
+    expr::WindowFunction, simplify::SimplifyInfo, AggregateFunction, Expr,
+    PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
+};
+
+/// This UDWF will show how to use the WindowUDFImpl::simplify() API
+#[derive(Debug, Clone)]
+struct SimplifySmoothItUdf {
+    signature: Signature,
+}
+
+impl SimplifySmoothItUdf {
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(
+                // this function will always take one arguments of type f64
+                vec![DataType::Float64],
+                // this function is deterministic and will always return the same
+                // result for the same input
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+impl WindowUDFImpl for SimplifySmoothItUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "simplify_smooth_it"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+        todo!()
+    }
+
+    /// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`.
+    fn simplify(&self) -> Option<WindowFunctionSimplification> {
+        // Ok(ExprSimplifyResult::Simplified(Expr::WindowFunction(
+        //     WindowFunction {
+        //         fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction(
+        //             AggregateFunction::Avg,
+        //         ),
+        //         args,
+        //         partition_by: partition_by.to_vec(),
+        //         order_by: order_by.to_vec(),
+        //         window_frame: window_frame.clone(),
+        //         null_treatment: *null_treatment,
+        //     },
+        // )))
+        let simplify = |window_function: datafusion_expr::expr::WindowFunction,
+                        _: &dyn SimplifyInfo| {
+            Ok(Expr::WindowFunction(WindowFunction {
+                fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction(
+                    AggregateFunction::Avg,
+                ),
+                args: window_function.args,
+                partition_by: window_function.partition_by,
+                order_by: window_function.order_by,
+                window_frame: window_function.window_frame,
+                null_treatment: window_function.null_treatment,
+            }))
+        };
+
+        Some(Box::new(simplify))
+    }
+}
+
+// create local execution context with `cars.csv` registered as a table named `cars`
+async fn create_context() -> Result<SessionContext> {
+    // declare a new context. In spark API, this corresponds to a new spark SQL session
+    let ctx = SessionContext::new();
+
+    // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
+    println!("pwd: {}", std::env::current_dir().unwrap().display());
+    let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string();
+    let read_options = CsvReadOptions::default().has_header(true);
+
+    ctx.register_csv("cars", &csv_path, read_options).await?;
+    Ok(ctx)
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context().await?;
+    let simplify_smooth_it = WindowUDF::from(SimplifySmoothItUdf::new());
+    ctx.register_udwf(simplify_smooth_it.clone());
+
+    // Use SQL to run the new window function
+    let df = ctx.sql("SELECT * from cars").await?;
+    // print the results
+    df.show().await?;
+
+    let df = ctx
+        .sql(
+            "SELECT \
+               car, \
+               speed, \
+               simplify_smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\
+               time \
+               from cars \
+             ORDER BY \
+               car",
+        )
+        .await?;
+    // print the results
+    df.show().await?;
+
+    Ok(())
+}
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index eb748ed271..7f49b03bb2 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -134,3 +134,16 @@ pub type AggregateFunctionSimplification = Box<
         &dyn crate::simplify::SimplifyInfo,
     ) -> Result<Expr>,
 >;
+
+/// [crate::udwf::WindowUDFImpl::simplify] simplifier closure
+/// A closure with two arguments:
+/// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked
+/// * 'info': [crate::simplify::SimplifyInfo]
+///
+/// closure returns simplified [Expr] or an error.
+pub type WindowFunctionSimplification = Box<
+    dyn Fn(
+        crate::expr::WindowFunction,
+        &dyn crate::simplify::SimplifyInfo,
+    ) -> Result<Expr>,
+>;
diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs
index 5a8373509a..ce28b444ad 100644
--- a/datafusion/expr/src/udwf.rs
+++ b/datafusion/expr/src/udwf.rs
@@ -18,8 +18,8 @@
 //! [`WindowUDF`]: User Defined Window Functions
 
 use crate::{
-    Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature,
-    WindowFrame,
+    function::WindowFunctionSimplification, Expr, PartitionEvaluator,
+    PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame,
 };
 use arrow::datatypes::DataType;
 use datafusion_common::Result;
@@ -170,6 +170,13 @@ impl WindowUDF {
         self.inner.return_type(args)
     }
 
+    /// Do the function rewrite
+    ///
+    /// See [`WindowUDFImpl::simplify`] for more details.
+    pub fn simplify(&self) -> Option<WindowFunctionSimplification> {
+        self.inner.simplify()
+    }
+
     /// Return a `PartitionEvaluator` for evaluating this window function
     pub fn partition_evaluator_factory(&self) -> Result<Box<dyn PartitionEvaluator>> {
         self.inner.partition_evaluator()
@@ -266,6 +273,29 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
     fn aliases(&self) -> &[String] {
         &[]
     }
+
+    /// Optionally apply per-UDWF simplification / rewrite rules.
+    ///
+    /// This can be used to apply function specific simplification rules during
+    /// optimization. The default implementation does nothing.
+    ///
+    /// Note that DataFusion handles simplifying arguments and  "constant
+    /// folding" (replacing a function call with constant arguments such as
+    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
+    /// optimizations manually for specific UDFs.
+    ///
+    /// Example:
+    /// [`simplify_udwf_expression.rs`]: <https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simplify_udwf_expression.rs>
+    ///
+    /// # Returns
+    /// [None] if simplify is not defined or,
+    ///
+    /// Or, a closure with two arguments:
+    /// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked
+    /// * 'info': [crate::simplify::SimplifyInfo]
+    fn simplify(&self) -> Option<WindowFunctionSimplification> {
+        None
+    }
 }
 
 /// WindowUDF that adds an alias to the underlying function. It is better to
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 25504e5c78..c87654292a 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -32,10 +32,13 @@ use datafusion_common::{
     tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
 };
 use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
-use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery};
+use datafusion_expr::expr::{
+    AggregateFunctionDefinition, InList, InSubquery, WindowFunction,
+};
 use datafusion_expr::simplify::ExprSimplifyResult;
 use datafusion_expr::{
     and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
+    WindowFunctionDefinition,
 };
 use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
 use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
@@ -1391,6 +1394,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 (_, expr) => Transformed::no(expr),
             },
 
+            Expr::WindowFunction(WindowFunction {
+                fun: WindowFunctionDefinition::WindowUDF(ref udwf),
+                ..
+            }) => match (udwf.simplify(), expr) {
+                (Some(simplify_function), Expr::WindowFunction(wf)) => {
+                    Transformed::yes(simplify_function(wf, info)?)
+                }
+                (_, expr) => Transformed::no(expr),
+            },
+
             //
             // Rules for Between
             //
@@ -1758,7 +1771,10 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
 mod tests {
     use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
     use datafusion_expr::{
-        function::{AccumulatorArgs, AggregateFunctionSimplification},
+        function::{
+            AccumulatorArgs, AggregateFunctionSimplification,
+            WindowFunctionSimplification,
+        },
         interval_arithmetic::Interval,
         *,
     };
@@ -3800,4 +3816,87 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn test_simplify_udwf() {
+        let udwf = WindowFunctionDefinition::WindowUDF(
+            WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(),
+        );
+        let window_function_expr =
+            Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new(
+                udwf,
+                vec![],
+                vec![],
+                vec![],
+                WindowFrame::new(None),
+                None,
+            ));
+
+        let expected = col("result_column");
+        assert_eq!(simplify(window_function_expr), expected);
+
+        let udwf = WindowFunctionDefinition::WindowUDF(
+            WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(),
+        );
+        let window_function_expr =
+            Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new(
+                udwf,
+                vec![],
+                vec![],
+                vec![],
+                WindowFrame::new(None),
+                None,
+            ));
+
+        let expected = window_function_expr.clone();
+        assert_eq!(simplify(window_function_expr), expected);
+    }
+
+    /// A Mock UDWF which defines `simplify` to be used in tests
+    /// related to UDWF simplification
+    #[derive(Debug, Clone)]
+    struct SimplifyMockUdwf {
+        simplify: bool,
+    }
+
+    impl SimplifyMockUdwf {
+        /// make simplify method return new expression
+        fn new_with_simplify() -> Self {
+            Self { simplify: true }
+        }
+        /// make simplify method return no change
+        fn new_without_simplify() -> Self {
+            Self { simplify: false }
+        }
+    }
+
+    impl WindowUDFImpl for SimplifyMockUdwf {
+        fn as_any(&self) -> &dyn std::any::Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "mock_simplify"
+        }
+
+        fn signature(&self) -> &Signature {
+            unimplemented!()
+        }
+
+        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+            unimplemented!("not needed for tests")
+        }
+
+        fn simplify(&self) -> Option<WindowFunctionSimplification> {
+            if self.simplify {
+                Some(Box::new(|_, _| Ok(col("result_column"))))
+            } else {
+                None
+            }
+        }
+
+        fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+            unimplemented!("not needed for tests")
+        }
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to