This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e4f7b9811f feat: support unparsing LogicalPlan::Window nodes (#10767)
e4f7b9811f is described below

commit e4f7b9811f245f0ccf8d0289f7d5edfe1499947a
Author: Devin D'Angelo <[email protected]>
AuthorDate: Mon Jun 3 15:00:32 2024 -0400

    feat: support unparsing LogicalPlan::Window nodes (#10767)
    
    * unparse window plans
    
    * new tests + fixes
    
    * fmt
---
 datafusion/sql/src/unparser/expr.rs       | 32 +++++++++-----
 datafusion/sql/src/unparser/plan.rs       | 71 +++++++++++++++++++++----------
 datafusion/sql/src/unparser/utils.rs      | 68 +++++++++++++++++++++++------
 datafusion/sql/tests/cases/plan_to_sql.rs |  8 +++-
 4 files changed, 132 insertions(+), 47 deletions(-)

diff --git a/datafusion/sql/src/unparser/expr.rs 
b/datafusion/sql/src/unparser/expr.rs
index df390ce6ea..1ba6638e73 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -236,8 +236,8 @@ impl Unparser<'_> {
                     .map(|expr| expr_to_unparsed(expr)?.into_order_by_expr())
                     .collect::<Result<Vec<_>>>()?;
 
-                let start_bound = 
self.convert_bound(&window_frame.start_bound);
-                let end_bound = self.convert_bound(&window_frame.end_bound);
+                let start_bound = 
self.convert_bound(&window_frame.start_bound)?;
+                let end_bound = self.convert_bound(&window_frame.end_bound)?;
                 let over = Some(ast::WindowType::WindowSpec(ast::WindowSpec {
                     window_name: None,
                     partition_by: partition_by
@@ -513,20 +513,30 @@ impl Unparser<'_> {
     fn convert_bound(
         &self,
         bound: &datafusion_expr::window_frame::WindowFrameBound,
-    ) -> ast::WindowFrameBound {
+    ) -> Result<ast::WindowFrameBound> {
         match bound {
             datafusion_expr::window_frame::WindowFrameBound::Preceding(val) => 
{
-                ast::WindowFrameBound::Preceding(
-                    self.scalar_to_sql(val).map(Box::new).ok(),
-                )
+                Ok(ast::WindowFrameBound::Preceding({
+                    let val = self.scalar_to_sql(val)?;
+                    if let ast::Expr::Value(ast::Value::Null) = &val {
+                        None
+                    } else {
+                        Some(Box::new(val))
+                    }
+                }))
             }
             datafusion_expr::window_frame::WindowFrameBound::Following(val) => 
{
-                ast::WindowFrameBound::Following(
-                    self.scalar_to_sql(val).map(Box::new).ok(),
-                )
+                Ok(ast::WindowFrameBound::Following({
+                    let val = self.scalar_to_sql(val)?;
+                    if let ast::Expr::Value(ast::Value::Null) = &val {
+                        None
+                    } else {
+                        Some(Box::new(val))
+                    }
+                }))
             }
             datafusion_expr::window_frame::WindowFrameBound::CurrentRow => {
-                ast::WindowFrameBound::CurrentRow
+                Ok(ast::WindowFrameBound::CurrentRow)
             }
         }
     }
@@ -1148,7 +1158,7 @@ mod tests {
                     window_frame: WindowFrame::new(None),
                     null_treatment: None,
                 }),
-                r#"ROW_NUMBER(col) OVER (ROWS BETWEEN NULL PRECEDING AND NULL 
FOLLOWING)"#,
+                r#"ROW_NUMBER(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 
UNBOUNDED FOLLOWING)"#,
             ),
             (
                 Expr::WindowFunction(WindowFunction {
diff --git a/datafusion/sql/src/unparser/plan.rs 
b/datafusion/sql/src/unparser/plan.rs
index e7e4d7700a..183bb1f7fb 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -28,7 +28,7 @@ use super::{
         BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder,
         SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
     },
-    utils::find_agg_node_within_select,
+    utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
     Unparser,
 };
 
@@ -162,23 +162,42 @@ impl Unparser<'_> {
                 // A second projection implies a derived tablefactor
                 if !select.already_projected() {
                     // Special handling when projecting an agregation plan
-                    if let Some(agg) = find_agg_node_within_select(plan, true) 
{
-                        let items = p
-                            .expr
-                            .iter()
-                            .map(|proj_expr| {
-                                let unproj = unproject_agg_exprs(proj_expr, 
agg)?;
-                                self.select_item_to_sql(&unproj)
-                            })
-                            .collect::<Result<Vec<_>>>()?;
-
-                        select.projection(items);
-                        select.group_by(ast::GroupByExpr::Expressions(
-                            agg.group_expr
-                                .iter()
-                                .map(|expr| self.expr_to_sql(expr))
-                                .collect::<Result<Vec<_>>>()?,
-                        ));
+                    if let Some(aggvariant) =
+                        find_agg_node_within_select(plan, None, true)
+                    {
+                        match aggvariant {
+                            AggVariant::Aggregate(agg) => {
+                                let items = p
+                                    .expr
+                                    .iter()
+                                    .map(|proj_expr| {
+                                        let unproj = 
unproject_agg_exprs(proj_expr, agg)?;
+                                        self.select_item_to_sql(&unproj)
+                                    })
+                                    .collect::<Result<Vec<_>>>()?;
+
+                                select.projection(items);
+                                select.group_by(ast::GroupByExpr::Expressions(
+                                    agg.group_expr
+                                        .iter()
+                                        .map(|expr| self.expr_to_sql(expr))
+                                        .collect::<Result<Vec<_>>>()?,
+                                ));
+                            }
+                            AggVariant::Window(window) => {
+                                let items = p
+                                    .expr
+                                    .iter()
+                                    .map(|proj_expr| {
+                                        let unproj =
+                                            unproject_window_exprs(proj_expr, 
&window)?;
+                                        self.select_item_to_sql(&unproj)
+                                    })
+                                    .collect::<Result<Vec<_>>>()?;
+
+                                select.projection(items);
+                            }
+                        }
                     } else {
                         let items = p
                             .expr
@@ -210,8 +229,8 @@ impl Unparser<'_> {
                 }
             }
             LogicalPlan::Filter(filter) => {
-                if let Some(agg) =
-                    find_agg_node_within_select(plan, 
select.already_projected())
+                if let Some(AggVariant::Aggregate(agg)) =
+                    find_agg_node_within_select(plan, None, 
select.already_projected())
                 {
                     let unprojected = unproject_agg_exprs(&filter.predicate, 
agg)?;
                     let filter_expr = self.expr_to_sql(&unprojected)?;
@@ -265,7 +284,7 @@ impl Unparser<'_> {
                 )
             }
             LogicalPlan::Aggregate(agg) => {
-                // Aggregate nodes are handled simulatenously with Projection 
nodes
+                // Aggregate nodes are handled simultaneously with Projection 
nodes
                 self.select_to_sql_recursively(
                     agg.input.as_ref(),
                     query,
@@ -441,8 +460,14 @@ impl Unparser<'_> {
 
                 Ok(())
             }
-            LogicalPlan::Window(_window) => {
-                not_impl_err!("Unsupported operator: {plan:?}")
+            LogicalPlan::Window(window) => {
+                // Window nodes are handled simultaneously with Projection 
nodes
+                self.select_to_sql_recursively(
+                    window.input.as_ref(),
+                    query,
+                    select,
+                    relation,
+                )
             }
             LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: 
{plan:?}"),
             _ => not_impl_err!("Unsupported operator: {plan:?}"),
diff --git a/datafusion/sql/src/unparser/utils.rs 
b/datafusion/sql/src/unparser/utils.rs
index c1b02c330f..326cd15ba1 100644
--- a/datafusion/sql/src/unparser/utils.rs
+++ b/datafusion/sql/src/unparser/utils.rs
@@ -20,16 +20,24 @@ use datafusion_common::{
     tree_node::{Transformed, TreeNode},
     Result,
 };
-use datafusion_expr::{Aggregate, Expr, LogicalPlan};
+use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window};
 
-/// Recursively searches children of [LogicalPlan] to find an Aggregate node 
if one exists
+/// One of the possible aggregation plans which can be found within a single 
select query.
+pub(crate) enum AggVariant<'a> {
+    Aggregate(&'a Aggregate),
+    Window(Vec<&'a Window>),
+}
+
+/// Recursively searches children of [LogicalPlan] to find an Aggregate or 
window node if one exists
 /// prior to encountering a Join, TableScan, or a nested subquery (derived 
table factor).
-/// If an Aggregate node is not found prior to this or at all before reaching 
the end
-/// of the tree, None is returned.
-pub(crate) fn find_agg_node_within_select(
-    plan: &LogicalPlan,
+/// If an Aggregate or window node is not found prior to this or at all before 
reaching the end
+/// of the tree, None is returned. It is assumed that a Window and Aggregate 
node cannot both
+/// be found in a single select query.
+pub(crate) fn find_agg_node_within_select<'a>(
+    plan: &'a LogicalPlan,
+    mut prev_windows: Option<AggVariant<'a>>,
     already_projected: bool,
-) -> Option<&Aggregate> {
+) -> Option<AggVariant<'a>> {
     // Note that none of the nodes that have a corresponding agg node can have 
more
     // than 1 input node. E.g. Projection / Filter always have 1 input node.
     let input = plan.inputs();
@@ -38,18 +46,29 @@ pub(crate) fn find_agg_node_within_select(
     } else {
         input.first()?
     };
+    // Agg nodes explicitly return immediately with a single node
+    // Window nodes accumulate in a vec until encountering a TableScan or 2nd 
projection
     if let LogicalPlan::Aggregate(agg) = input {
-        Some(agg)
+        Some(AggVariant::Aggregate(agg))
+    } else if let LogicalPlan::Window(window) = input {
+        prev_windows = match &mut prev_windows {
+            Some(AggVariant::Window(windows)) => {
+                windows.push(window);
+                prev_windows
+            }
+            _ => Some(AggVariant::Window(vec![window])),
+        };
+        find_agg_node_within_select(input, prev_windows, already_projected)
     } else if let LogicalPlan::TableScan(_) = input {
-        None
+        prev_windows
     } else if let LogicalPlan::Projection(_) = input {
         if already_projected {
-            None
+            prev_windows
         } else {
-            find_agg_node_within_select(input, true)
+            find_agg_node_within_select(input, prev_windows, true)
         }
     } else {
-        find_agg_node_within_select(input, already_projected)
+        find_agg_node_within_select(input, prev_windows, already_projected)
     }
 }
 
@@ -82,3 +101,28 @@ pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: 
&Aggregate) -> Result<Expr>
         })
         .map(|e| e.data)
 }
+
+/// Recursively identify all Column expressions and transform them into the 
appropriate
+/// window expression contained in one of the `windows` nodes.
+///
+/// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" 
it will be transformed
+/// into an actual window expression as identified in the window node.
+pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> 
Result<Expr> {
+    expr.clone()
+        .transform(|sub_expr| {
+            if let Expr::Column(c) = sub_expr {
+                if let Some(unproj) = windows
+                    .iter()
+                    .flat_map(|w| w.window_expr.iter())
+                    .find(|window_expr| window_expr.display_name().unwrap() == 
c.name)
+                {
+                    Ok(Transformed::yes(unproj.clone()))
+                } else {
+                    Ok(Transformed::no(Expr::Column(c)))
+                }
+            } else {
+                Ok(Transformed::no(sub_expr))
+            }
+        })
+        .map(|e| e.data)
+}
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs 
b/datafusion/sql/tests/cases/plan_to_sql.rs
index 1bf441351a..4a430bdc80 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -127,7 +127,13 @@ fn roundtrip_statement() -> Result<()> {
             UNION ALL
             SELECT j2_string as string FROM j2
             ORDER BY string DESC
-            LIMIT 10"#
+            LIMIT 10"#,
+            "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN 
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 
+            last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN 
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 
+            first_name from person",
+            r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED 
PRECEDING AND UNBOUNDED FOLLOWING), 
+            sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED 
PRECEDING AND UNBOUNDED FOLLOWING) from person"#,
+            "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 
PRECEDING AND 2 FOLLOWING) from person",            
         ];
 
     // For each test sql string, we transform as follows:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to