This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new cdbd964346 support window function sql2expr (#10243)
cdbd964346 is described below

commit cdbd96434676c8c34e742a6cfea6bc7499e97cde
Author: Junhao Liu <[email protected]>
AuthorDate: Fri Apr 26 09:18:16 2024 -0600

    support window function sql2expr (#10243)
---
 datafusion/expr/src/built_in_window_function.rs |   2 +-
 datafusion/expr/src/expr.rs                     |  10 ++
 datafusion/sql/src/unparser/expr.rs             | 175 +++++++++++++++++++-----
 3 files changed, 151 insertions(+), 36 deletions(-)

diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs
index 1001bbb015..18a888ae8b 100644
--- a/datafusion/expr/src/built_in_window_function.rs
+++ b/datafusion/expr/src/built_in_window_function.rs
@@ -71,7 +71,7 @@ pub enum BuiltInWindowFunction {
 }
 
 impl BuiltInWindowFunction {
-    fn name(&self) -> &str {
+    pub fn name(&self) -> &str {
         use BuiltInWindowFunction::*;
         match self {
             RowNumber => "ROW_NUMBER",
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 0d8e8d816b..e310eaa7e4 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -669,6 +669,16 @@ impl WindowFunctionDefinition {
             WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(),
         }
     }
+
+    /// Function's name for display
+    pub fn name(&self) -> &str {
+        match self {
+            WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.name(),
+            WindowFunctionDefinition::WindowUDF(fun) => fun.name(),
+            WindowFunctionDefinition::AggregateFunction(fun) => fun.name(),
+            WindowFunctionDefinition::AggregateUDF(fun) => fun.name(),
+        }
+    }
 }
 
 impl fmt::Display for WindowFunctionDefinition {
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index d091fbe14d..7194b0a7d8 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -21,10 +21,7 @@ use datafusion_common::{
     internal_datafusion_err, not_impl_err, plan_err, Column, Result, ScalarValue,
 };
 use datafusion_expr::{
-    expr::{
-        AggregateFunctionDefinition, Alias, Exists, InList, ScalarFunction, Sort,
-        WindowFunction,
-    },
+    expr::{Alias, Exists, InList, ScalarFunction, Sort, WindowFunction},
     Between, BinaryExpr, Case, Cast, Expr, Like, Operator,
 };
 use sqlparser::ast::{
@@ -170,14 +167,56 @@ impl Unparser<'_> {
             Expr::Literal(value) => Ok(self.scalar_to_sql(value)?),
             Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql(expr),
             Expr::WindowFunction(WindowFunction {
-                fun: _,
-                args: _,
-                partition_by: _,
+                fun,
+                args,
+                partition_by,
                 order_by: _,
-                window_frame: _,
+                window_frame,
                 null_treatment: _,
             }) => {
-                not_impl_err!("Unsupported expression: {expr:?}")
+                let func_name = fun.name();
+
+                let args = self.function_args_to_sql(args)?;
+
+                let units = match window_frame.units {
+                    datafusion_expr::window_frame::WindowFrameUnits::Rows => {
+                        ast::WindowFrameUnits::Rows
+                    }
+                    datafusion_expr::window_frame::WindowFrameUnits::Range => {
+                        ast::WindowFrameUnits::Range
+                    }
+                    datafusion_expr::window_frame::WindowFrameUnits::Groups => {
+                        ast::WindowFrameUnits::Groups
+                    }
+                };
+                let start_bound = self.convert_bound(&window_frame.start_bound);
+                let end_bound = self.convert_bound(&window_frame.end_bound);
+                let over = Some(ast::WindowType::WindowSpec(ast::WindowSpec {
+                    window_name: None,
+                    partition_by: partition_by
+                        .iter()
+                        .map(|e| self.expr_to_sql(e))
+                        .collect::<Result<Vec<_>>>()?,
+                    order_by: vec![],
+                    window_frame: Some(ast::WindowFrame {
+                        units,
+                        start_bound,
+                        end_bound: Option::from(end_bound),
+                    }),
+                }));
+                Ok(ast::Expr::Function(Function {
+                    name: ast::ObjectName(vec![Ident {
+                        value: func_name.to_string(),
+                        quote_style: None,
+                    }]),
+                    args,
+                    filter: None,
+                    null_treatment: None,
+                    over,
+                    distinct: false,
+                    special: false,
+                    order_by: vec![],
+                }))
             }
             Expr::SimilarTo(Like {
                 negated,
@@ -199,37 +238,20 @@ impl Unparser<'_> {
                 escape_char: *escape_char,
             }),
             Expr::AggregateFunction(agg) => {
-                let func_name = if let AggregateFunctionDefinition::BuiltIn(built_in) =
-                    &agg.func_def
-                {
-                    built_in.name()
-                } else {
-                    return not_impl_err!(
-                        "Only built in agg functions are supported, got {agg:?}"
-                    );
-                };
-
-                let args = agg
-                    .args
-                    .iter()
-                    .map(|e| {
-                        if matches!(e, Expr::Wildcard { qualifier: None }) {
-                            Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard))
-                        } else {
-                            self.expr_to_sql(e).map(|e| {
-                                FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))
-                            })
-                        }
-                    })
-                    .collect::<Result<Vec<_>>>()?;
+                let func_name = agg.func_def.name();
 
+                let args = self.function_args_to_sql(&agg.args)?;
+                let filter = match &agg.filter {
+                    Some(filter) => Some(Box::new(self.expr_to_sql(filter)?)),
+                    None => None,
+                };
                 Ok(ast::Expr::Function(Function {
                     name: ast::ObjectName(vec![Ident {
                         value: func_name.to_string(),
                         quote_style: None,
                     }]),
                     args,
-                    filter: None,
+                    filter,
                     null_treatment: None,
                     over: None,
                     distinct: agg.distinct,
@@ -355,6 +377,40 @@ impl Unparser<'_> {
         Ok(ast::Expr::Identifier(self.new_ident(col.name.to_string())))
     }
 
+    fn convert_bound(
+        &self,
+        bound: &datafusion_expr::window_frame::WindowFrameBound,
+    ) -> ast::WindowFrameBound {
+        match bound {
+            datafusion_expr::window_frame::WindowFrameBound::Preceding(val) => {
+                ast::WindowFrameBound::Preceding(
+                    self.scalar_to_sql(val).map(Box::new).ok(),
+                )
+            }
+            datafusion_expr::window_frame::WindowFrameBound::Following(val) => {
+                ast::WindowFrameBound::Following(
+                    self.scalar_to_sql(val).map(Box::new).ok(),
+                )
+            }
+            datafusion_expr::window_frame::WindowFrameBound::CurrentRow => {
+                ast::WindowFrameBound::CurrentRow
+            }
+        }
+    }
+
+    fn function_args_to_sql(&self, args: &[Expr]) -> Result<Vec<ast::FunctionArg>> {
+        args.iter()
+            .map(|e| {
+                if matches!(e, Expr::Wildcard { qualifier: None }) {
+                    Ok(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard))
+                } else {
+                    self.expr_to_sql(e)
+                        .map(|e| ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)))
+                }
+            })
+            .collect::<Result<Vec<_>>>()
+    }
+
     pub(super) fn new_ident(&self, str: String) -> ast::Ident {
         ast::Ident {
             value: str,
@@ -735,8 +791,10 @@ mod tests {
     use arrow::datatypes::{Field, Schema};
     use datafusion_common::TableReference;
     use datafusion_expr::{
-        case, col, exists, expr::AggregateFunction, lit, not, not_exists, table_scan,
-        ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
+        case, col, exists,
+        expr::{AggregateFunction, AggregateFunctionDefinition},
+        lit, not, not_exists, table_scan, wildcard, ColumnarValue, ScalarUDF,
+        ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition,
     };
 
     use crate::unparser::dialect::CustomDialect;
@@ -901,6 +959,53 @@ mod tests {
                 }),
                 "COUNT(DISTINCT *)",
             ),
+            (
+                Expr::AggregateFunction(AggregateFunction {
+                    func_def: AggregateFunctionDefinition::BuiltIn(
+                        datafusion_expr::AggregateFunction::Count,
+                    ),
+                    args: vec![Expr::Wildcard { qualifier: None }],
+                    distinct: false,
+                    filter: Some(Box::new(lit(true))),
+                    order_by: None,
+                    null_treatment: None,
+                }),
+                "COUNT(*) FILTER (WHERE true)",
+            ),
+            (
+                Expr::WindowFunction(WindowFunction {
+                    fun: WindowFunctionDefinition::BuiltInWindowFunction(
+                        datafusion_expr::BuiltInWindowFunction::RowNumber,
+                    ),
+                    args: vec![col("col")],
+                    partition_by: vec![],
+                    order_by: vec![],
+                    window_frame: WindowFrame::new(None),
+                    null_treatment: None,
+                }),
+                r#"ROW_NUMBER("col") OVER (ROWS BETWEEN NULL PRECEDING AND NULL FOLLOWING)"#,
+            ),
+            (
+                Expr::WindowFunction(WindowFunction {
+                    fun: WindowFunctionDefinition::AggregateFunction(
+                        datafusion_expr::AggregateFunction::Count,
+                    ),
+                    args: vec![wildcard()],
+                    partition_by: vec![],
+                    order_by: vec![],
+                    window_frame: WindowFrame::new_bounds(
+                        datafusion_expr::WindowFrameUnits::Range,
+                        datafusion_expr::WindowFrameBound::Preceding(
+                            ScalarValue::UInt32(Some(6)),
+                        ),
+                        datafusion_expr::WindowFrameBound::Following(
+                            ScalarValue::UInt32(Some(2)),
+                        ),
+                    ),
+                    null_treatment: None,
+                }),
+                r#"COUNT(*) OVER (RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#,
+            ),
             (col("a").is_not_null(), r#""a" IS NOT NULL"#),
             (
                 (col("a") + col("b")).gt(lit(4)).is_true(),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to