This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 3d00760853 feature: Add a WindowUDFImpl::simplify() API (#9906)
3d00760853 is described below
commit 3d007608535cb138ae4473ce6305bd4ec8481627
Author: junxiangMu <[email protected]>
AuthorDate: Wed May 29 20:18:02 2024 -0400
feature: Add a WindowUDFImpl::simplify() API (#9906)
* feature: Add a WindowUDFImpl::simplify() API
Signed-off-by: guojidan <[email protected]>
* fix doc
Signed-off-by: guojidan <[email protected]>
* fix fmt
Signed-off-by: guojidan <[email protected]>
---------
Signed-off-by: guojidan <[email protected]>
---
.../examples/simplify_udwf_expression.rs | 142 +++++++++++++++++++++
datafusion/expr/src/function.rs | 13 ++
datafusion/expr/src/udwf.rs | 34 ++++-
.../src/simplify_expressions/expr_simplifier.rs | 103 ++++++++++++++-
4 files changed, 288 insertions(+), 4 deletions(-)
diff --git a/datafusion-examples/examples/simplify_udwf_expression.rs
b/datafusion-examples/examples/simplify_udwf_expression.rs
new file mode 100644
index 0000000000..2824d03761
--- /dev/null
+++ b/datafusion-examples/examples/simplify_udwf_expression.rs
@@ -0,0 +1,142 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+
+use arrow_schema::DataType;
+use datafusion::execution::context::SessionContext;
+use datafusion::{error::Result, execution::options::CsvReadOptions};
+use datafusion_expr::function::WindowFunctionSimplification;
+use datafusion_expr::{
+ expr::WindowFunction, simplify::SimplifyInfo, AggregateFunction, Expr,
+ PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
+};
+
+/// A user-defined window function (UDWF) that demonstrates the
+/// `WindowUDFImpl::simplify()` API: instead of being evaluated directly, every
+/// call is rewritten by the optimizer into the built-in `AVG` window aggregate.
+#[derive(Debug, Clone)]
+struct SimplifySmoothItUdf {
+    signature: Signature,
+}
+
+impl SimplifySmoothItUdf {
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(
+                // this function always takes exactly one argument of type f64
+                vec![DataType::Float64],
+                // this function is deterministic and always returns the same
+                // result for the same input
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+impl WindowUDFImpl for SimplifySmoothItUdf {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "simplify_smooth_it"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ Ok(DataType::Float64)
+ }
+
+ fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+ todo!()
+ }
+
+ /// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`.
+ fn simplify(&self) -> Option<WindowFunctionSimplification> {
+ // Ok(ExprSimplifyResult::Simplified(Expr::WindowFunction(
+ // WindowFunction {
+ // fun:
datafusion_expr::WindowFunctionDefinition::AggregateFunction(
+ // AggregateFunction::Avg,
+ // ),
+ // args,
+ // partition_by: partition_by.to_vec(),
+ // order_by: order_by.to_vec(),
+ // window_frame: window_frame.clone(),
+ // null_treatment: *null_treatment,
+ // },
+ // )))
+ let simplify = |window_function: datafusion_expr::expr::WindowFunction,
+ _: &dyn SimplifyInfo| {
+ Ok(Expr::WindowFunction(WindowFunction {
+ fun:
datafusion_expr::WindowFunctionDefinition::AggregateFunction(
+ AggregateFunction::Avg,
+ ),
+ args: window_function.args,
+ partition_by: window_function.partition_by,
+ order_by: window_function.order_by,
+ window_frame: window_function.window_frame,
+ null_treatment: window_function.null_treatment,
+ }))
+ };
+
+ Some(Box::new(simplify))
+ }
+}
+
+// create local execution context with `cars.csv` registered as a table named
`cars`
+async fn create_context() -> Result<SessionContext> {
+ // declare a new context. In spark API, this corresponds to a new spark
SQL session
+ let ctx = SessionContext::new();
+
+ // declare a table in memory. In spark API, this corresponds to
createDataFrame(...).
+ println!("pwd: {}", std::env::current_dir().unwrap().display());
+ let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string();
+ let read_options = CsvReadOptions::default().has_header(true);
+
+ ctx.register_csv("cars", &csv_path, read_options).await?;
+ Ok(ctx)
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context().await?;
+
+    // `register_udwf` takes the UDWF by value, so no clone is needed
+    let simplify_smooth_it = WindowUDF::from(SimplifySmoothItUdf::new());
+    ctx.register_udwf(simplify_smooth_it);
+
+    // Use SQL to show the input data
+    let df = ctx.sql("SELECT * from cars").await?;
+    // print the results
+    df.show().await?;
+
+    // Use SQL to run the new window function; `simplify` rewrites it to AVG
+    let df = ctx
+        .sql(
+            "SELECT \
+               car, \
+               speed, \
+               simplify_smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\
+               time \
+             from cars \
+             ORDER BY \
+               car",
+        )
+        .await?;
+    // print the results
+    df.show().await?;
+
+    Ok(())
+}
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index eb748ed271..7f49b03bb2 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -134,3 +134,16 @@ pub type AggregateFunctionSimplification = Box<
&dyn crate::simplify::SimplifyInfo,
) -> Result<Expr>,
>;
+
+/// [crate::udwf::WindowUDFImpl::simplify] simplifier closure
+///
+/// A closure with two arguments:
+/// * `window_function`: [crate::expr::WindowFunction] for which simplification
+///   has been invoked
+/// * `info`: [crate::simplify::SimplifyInfo]
+///
+/// The closure returns the simplified [Expr] (which replaces the original
+/// window function expression) or an error.
+pub type WindowFunctionSimplification = Box<
+    dyn Fn(
+        crate::expr::WindowFunction,
+        &dyn crate::simplify::SimplifyInfo,
+    ) -> Result<Expr>,
+>;
diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs
index 5a8373509a..ce28b444ad 100644
--- a/datafusion/expr/src/udwf.rs
+++ b/datafusion/expr/src/udwf.rs
@@ -18,8 +18,8 @@
//! [`WindowUDF`]: User Defined Window Functions
use crate::{
- Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction,
Signature,
- WindowFrame,
+ function::WindowFunctionSimplification, Expr, PartitionEvaluator,
+ PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame,
};
use arrow::datatypes::DataType;
use datafusion_common::Result;
@@ -170,6 +170,13 @@ impl WindowUDF {
self.inner.return_type(args)
}
+    /// Returns this function's custom simplification (rewrite) rule, if any.
+    ///
+    /// Delegates to the wrapped implementation; see
+    /// [`WindowUDFImpl::simplify`] for more details.
+    pub fn simplify(&self) -> Option<WindowFunctionSimplification> {
+        self.inner.simplify()
+    }
+
/// Return a `PartitionEvaluator` for evaluating this window function
pub fn partition_evaluator_factory(&self) -> Result<Box<dyn
PartitionEvaluator>> {
self.inner.partition_evaluator()
@@ -266,6 +273,29 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
fn aliases(&self) -> &[String] {
&[]
}
+
+    /// Optionally apply per-UDWF simplification / rewrite rules.
+    ///
+    /// This can be used to apply function specific simplification rules during
+    /// optimization. The default implementation does nothing (returns `None`).
+    ///
+    /// Note that DataFusion handles simplifying arguments and "constant
+    /// folding" (replacing a function call with constant arguments such as
+    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
+    /// optimizations manually for specific UDFs.
+    ///
+    /// Example: [`simplify_udwf_expression.rs`](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simplify_udwf_expression.rs)
+    ///
+    /// # Returns
+    /// [None] if simplify is not defined, or a closure with two arguments:
+    /// * `window_function`: [crate::expr::WindowFunction] for which
+    ///   simplification has been invoked
+    /// * `info`: [crate::simplify::SimplifyInfo]
+    fn simplify(&self) -> Option<WindowFunctionSimplification> {
+        None
+    }
}
/// WindowUDF that adds an alias to the underlying function. It is better to
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 25504e5c78..c87654292a 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -32,10 +32,13 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
};
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result,
ScalarValue};
-use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery};
+use datafusion_expr::expr::{
+ AggregateFunctionDefinition, InList, InSubquery, WindowFunction,
+};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator,
Volatility,
+ WindowFunctionDefinition,
};
use datafusion_expr::{expr::ScalarFunction,
interval_arithmetic::NullableInterval};
use datafusion_physical_expr::{create_physical_expr,
execution_props::ExecutionProps};
@@ -1391,6 +1394,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for
Simplifier<'a, S> {
(_, expr) => Transformed::no(expr),
},
+ Expr::WindowFunction(WindowFunction {
+ fun: WindowFunctionDefinition::WindowUDF(ref udwf),
+ ..
+ }) => match (udwf.simplify(), expr) {
+ (Some(simplify_function), Expr::WindowFunction(wf)) => {
+ Transformed::yes(simplify_function(wf, info)?)
+ }
+ (_, expr) => Transformed::no(expr),
+ },
+
//
// Rules for Between
//
@@ -1758,7 +1771,10 @@ fn inlist_except(mut l1: InList, l2: InList) ->
Result<Expr> {
mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
use datafusion_expr::{
- function::{AccumulatorArgs, AggregateFunctionSimplification},
+ function::{
+ AccumulatorArgs, AggregateFunctionSimplification,
+ WindowFunctionSimplification,
+ },
interval_arithmetic::Interval,
*,
};
@@ -3800,4 +3816,87 @@ mod tests {
}
}
}
+
+    #[test]
+    fn test_simplify_udwf() {
+        // Builds an otherwise empty window function expression around a mock UDWF
+        let window_expr = |mock: SimplifyMockUdwf| {
+            let definition = WindowFunctionDefinition::WindowUDF(
+                WindowUDF::new_from_impl(mock).into(),
+            );
+            Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new(
+                definition,
+                vec![],
+                vec![],
+                vec![],
+                WindowFrame::new(None),
+                None,
+            ))
+        };
+
+        // a UDWF with a simplifier is replaced by the expression it returns
+        let simplifiable = window_expr(SimplifyMockUdwf::new_with_simplify());
+        assert_eq!(simplify(simplifiable), col("result_column"));
+
+        // a UDWF without a simplifier is left untouched
+        let untouched = window_expr(SimplifyMockUdwf::new_without_simplify());
+        let expected = untouched.clone();
+        assert_eq!(simplify(untouched), expected);
+    }
+
+    /// Mock UDWF for the UDWF simplification tests. The `simplify` flag
+    /// decides whether its `WindowUDFImpl::simplify` returns a rewrite closure.
+    #[derive(Debug, Clone)]
+    struct SimplifyMockUdwf {
+        simplify: bool,
+    }
+
+    impl SimplifyMockUdwf {
+        /// Mock whose `simplify` produces a replacement expression
+        fn new_with_simplify() -> Self {
+            let simplify = true;
+            Self { simplify }
+        }
+
+        /// Mock whose `simplify` returns `None`, leaving the call unchanged
+        fn new_without_simplify() -> Self {
+            let simplify = false;
+            Self { simplify }
+        }
+    }
+
+    impl WindowUDFImpl for SimplifyMockUdwf {
+        fn as_any(&self) -> &dyn std::any::Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "mock_simplify"
+        }
+
+        fn signature(&self) -> &Signature {
+            unimplemented!("not needed for tests")
+        }
+
+        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+            unimplemented!("not needed for tests")
+        }
+
+        /// Rewrites the call to `col("result_column")` when `self.simplify`
+        /// is set; otherwise reports that no simplification is defined.
+        fn simplify(&self) -> Option<WindowFunctionSimplification> {
+            if self.simplify {
+                Some(Box::new(|_, _| Ok(col("result_column"))))
+            } else {
+                None
+            }
+        }
+
+        fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+            unimplemented!("not needed for tests")
+        }
+    }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]