This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ad273cab8b Improve unparsing for `ORDER BY`, `UNION`, Windows functions with Aggregation (#12946)
ad273cab8b is described below

commit ad273cab8bf300a704baf005df072bb980645e51
Author: Sergei Grebnov <[email protected]>
AuthorDate: Thu Oct 17 10:23:29 2024 -0700

    Improve unparsing for `ORDER BY`, `UNION`, Windows functions with Aggregation (#12946)
    
    * Improve unparsing for ORDER BY with Aggregation functions (#38)
    
    * Improve UNION unparsing (#39)
    
    * Scalar functions in ORDER BY unparsing support (#41)
    
    * Improve unparsing for complex Window functions with Aggregation (#42)
    
    * WindowFunction order_by should respect `supports_nulls_first_in_sort` dialect setting (#43)
    
    * Fix plan_to_sql
    
    * Improve
---
 datafusion/sql/src/unparser/expr.rs       | 10 ++---
 datafusion/sql/src/unparser/plan.rs       | 42 +++++++++++++------
 datafusion/sql/src/unparser/utils.rs      | 69 ++++++++++++++++++++++++-------
 datafusion/sql/tests/cases/plan_to_sql.rs | 63 +++++++++++++++++++++++++++-
 4 files changed, 148 insertions(+), 36 deletions(-)

diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index 1be5aa68bf..8864c97bb1 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -76,11 +76,6 @@ pub fn expr_to_sql(expr: &Expr) -> Result<ast::Expr> {
     unparser.expr_to_sql(expr)
 }
 
-pub fn sort_to_sql(sort: &Sort) -> Result<ast::OrderByExpr> {
-    let unparser = Unparser::default();
-    unparser.sort_to_sql(sort)
-}
-
 const LOWEST: &BinaryOperator = &BinaryOperator::Or;
 // Closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs
 // (https://www.postgresql.org/docs/7.2/sql-precedence.html)
@@ -229,9 +224,10 @@ impl Unparser<'_> {
                         ast::WindowFrameUnits::Groups
                     }
                 };
-                let order_by: Vec<ast::OrderByExpr> = order_by
+
+                let order_by = order_by
                     .iter()
-                    .map(sort_to_sql)
+                    .map(|sort_expr| self.sort_to_sql(sort_expr))
                     .collect::<Result<Vec<_>>>()?;
 
                 let start_bound = self.convert_bound(&window_frame.start_bound)?;
diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs
index 9b4818b98c..c22400f1fa 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -27,7 +27,7 @@ use super::{
     },
     utils::{
         find_agg_node_within_select, find_window_nodes_within_select,
-        unproject_window_exprs,
+        unproject_sort_expr, unproject_window_exprs,
     },
     Unparser,
 };
@@ -352,19 +352,30 @@ impl Unparser<'_> {
                 if select.already_projected() {
                     return self.derive(plan, relation);
                 }
-                if let Some(query_ref) = query {
-                    if let Some(fetch) = sort.fetch {
-                        query_ref.limit(Some(ast::Expr::Value(ast::Value::Number(
-                            fetch.to_string(),
-                            false,
-                        ))));
-                    }
-                    query_ref.order_by(self.sorts_to_sql(sort.expr.clone())?);
-                } else {
+                let Some(query_ref) = query else {
                     return internal_err!(
                         "Sort operator only valid in a statement context."
                     );
-                }
+                };
+
+                if let Some(fetch) = sort.fetch {
+                    query_ref.limit(Some(ast::Expr::Value(ast::Value::Number(
+                        fetch.to_string(),
+                        false,
+                    ))));
+                };
+
+                let agg = find_agg_node_within_select(plan, select.already_projected());
+                // unproject sort expressions
+                let sort_exprs: Vec<SortExpr> = sort
+                    .expr
+                    .iter()
+                    .map(|sort_expr| {
+                        unproject_sort_expr(sort_expr, agg, sort.input.as_ref())
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                query_ref.order_by(self.sorts_to_sql(&sort_exprs)?);
 
                 self.select_to_sql_recursively(
                     sort.input.as_ref(),
@@ -402,7 +413,7 @@ impl Unparser<'_> {
                             .collect::<Result<Vec<_>>>()?;
                         if let Some(sort_expr) = &on.sort_expr {
                             if let Some(query_ref) = query {
-                                query_ref.order_by(self.sorts_to_sql(sort_expr.clone())?);
+                                query_ref.order_by(self.sorts_to_sql(sort_expr)?);
                             } else {
                                 return internal_err!(
                                     "Sort operator only valid in a statement context."
@@ -546,6 +557,11 @@ impl Unparser<'_> {
                     );
                 }
 
+                // Covers cases where the UNION is a subquery and the projection is at the top level
+                if select.already_projected() {
+                    return self.derive(plan, relation);
+                }
+
                 let input_exprs: Vec<SetExpr> = union
                     .inputs
                     .iter()
@@ -691,7 +707,7 @@ impl Unparser<'_> {
         }
     }
 
-    fn sorts_to_sql(&self, sort_exprs: Vec<SortExpr>) -> Result<Vec<ast::OrderByExpr>> {
+    fn sorts_to_sql(&self, sort_exprs: &[SortExpr]) -> Result<Vec<ast::OrderByExpr>> {
         sort_exprs
             .iter()
             .map(|sort_expr| self.sort_to_sql(sort_expr))
diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs
index e8c4eca569..5e3a3aa600 100644
--- a/datafusion/sql/src/unparser/utils.rs
+++ b/datafusion/sql/src/unparser/utils.rs
@@ -20,10 +20,11 @@ use std::cmp::Ordering;
 use datafusion_common::{
     internal_err,
     tree_node::{Transformed, TreeNode},
-    Column, DataFusionError, Result, ScalarValue,
+    Column, Result, ScalarValue,
 };
 use datafusion_expr::{
-    utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window,
+    utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, SortExpr,
+    Window,
 };
 use sqlparser::ast;
 
@@ -118,21 +119,11 @@ pub(crate) fn unproject_agg_exprs(
             if let Expr::Column(c) = sub_expr {
                 if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
                     Ok(Transformed::yes(unprojected_expr.clone()))
-                } else if let Some(mut unprojected_expr) =
+                } else if let Some(unprojected_expr) =
                     windows.and_then(|w| find_window_expr(w, &c.name).cloned())
                 {
-                    if let Expr::WindowFunction(func) = &mut unprojected_expr {
-                        // Window function can contain an aggregation column, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
-                        func.args.iter_mut().try_for_each(|arg| {
-                            if let Expr::Column(c) = arg {
-                                if let Some(expr) = find_agg_expr(agg, c)? {
-                                    *arg = expr.clone();
-                                }
-                            }
-                            Ok::<(), DataFusionError>(())
-                        })?;
-                    }
-                    Ok(Transformed::yes(unprojected_expr))
+                    // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
+                    return Ok(Transformed::yes(unproject_agg_exprs(&unprojected_expr, agg, None)?));
                 } else {
                     internal_err!(
                         "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
@@ -200,6 +191,54 @@ fn find_window_expr<'a>(
         .find(|expr| expr.schema_name().to_string() == column_name)
 }
 
+/// Transforms a Column expression into the actual expression from aggregation or projection if found.
+/// This is required because if an ORDER BY expression is present in an Aggregate or Select, it is replaced
+/// with a Column expression (e.g., "sum(catalog_returns.cr_net_loss)"). We need to transform it back to
+/// the actual expression, such as sum("catalog_returns"."cr_net_loss").
+pub(crate) fn unproject_sort_expr(
+    sort_expr: &SortExpr,
+    agg: Option<&Aggregate>,
+    input: &LogicalPlan,
+) -> Result<SortExpr> {
+    let mut sort_expr = sort_expr.clone();
+
+    // Remove alias if present, because ORDER BY cannot use aliases
+    if let Expr::Alias(alias) = &sort_expr.expr {
+        sort_expr.expr = *alias.expr.clone();
+    }
+
+    let Expr::Column(ref col_ref) = sort_expr.expr else {
+        return Ok(sort_expr);
+    };
+
+    if col_ref.relation.is_some() {
+        return Ok(sort_expr);
+    };
+
+    // In case of aggregation there could be columns containing aggregation functions we need to unproject
+    if let Some(agg) = agg {
+        if agg.schema.is_column_from_schema(col_ref) {
+            let new_expr = unproject_agg_exprs(&sort_expr.expr, agg, None)?;
+            sort_expr.expr = new_expr;
+            return Ok(sort_expr);
+        }
+    }
+
+    // If SELECT and ORDER BY contain the same expression with a scalar function, the ORDER BY expression will
+    // be replaced by a Column expression (e.g., "substr(customer.c_last_name, Int64(0), Int64(5))"), and we need
+    // to transform it back to the actual expression.
+    if let LogicalPlan::Projection(Projection { expr, schema, .. }) = input {
+        if let Ok(idx) = schema.index_of_column(col_ref) {
+            if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) {
+                sort_expr.expr = Expr::ScalarFunction(scalar_fn.clone());
+            }
+        }
+        return Ok(sort_expr);
+    }
+
+    Ok(sort_expr)
+}
+
 /// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
 pub(crate) fn date_part_to_sql(
     unparser: &Unparser,
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs
index e4e5d6a929..74abdf075f 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -22,6 +22,9 @@ use arrow_schema::*;
 use datafusion_common::{DFSchema, Result, TableReference};
 use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_udaf};
 use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder};
+use datafusion_functions::unicode;
+use datafusion_functions_aggregate::grouping::grouping_udaf;
+use datafusion_functions_window::rank::rank_udwf;
 use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
 use datafusion_sql::unparser::dialect::{
     DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect,
@@ -139,6 +142,13 @@ fn roundtrip_statement() -> Result<()> {
             SELECT j2_string as string FROM j2
             ORDER BY string DESC
             LIMIT 10"#,
+            r#"SELECT col1, id FROM (
+                SELECT j1_string AS col1, j1_id AS id FROM j1
+                UNION ALL
+                SELECT j2_string AS col1, j2_id AS id FROM j2
+                UNION ALL
+                SELECT j3_string AS col1, j3_id AS id FROM j3
+            ) AS subquery GROUP BY col1, id ORDER BY col1 ASC, id ASC"#,
             "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
             last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
             first_name from person",
@@ -657,7 +667,12 @@ where
         .unwrap();
 
     let context = MockContextProvider {
-        state: MockSessionState::default(),
+        state: MockSessionState::default()
+            .with_aggregate_function(sum_udaf())
+            .with_aggregate_function(max_udaf())
+            .with_aggregate_function(grouping_udaf())
+            .with_window_function(rank_udwf())
+            .with_scalar_function(Arc::new(unicode::substr().as_ref().clone())),
     };
     let sql_to_rel = SqlToRel::new(&context);
     let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
@@ -969,3 +984,49 @@ fn test_with_offset0() {
 fn test_with_offset95() {
     sql_round_trip(MySqlDialect {}, "select 1 offset 95", "SELECT 1 OFFSET 95");
 }
+
+#[test]
+fn test_order_by_to_sql() {
+    // order by aggregation function
+    sql_round_trip(
+        GenericDialect {},
+        r#"SELECT id, first_name, SUM(id) FROM person GROUP BY id, first_name ORDER BY SUM(id) ASC, first_name DESC, id, first_name LIMIT 10"#,
+        r#"SELECT person.id, person.first_name, sum(person.id) FROM person GROUP BY person.id, person.first_name ORDER BY sum(person.id) ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#,
+    );
+
+    // order by aggregation function alias
+    sql_round_trip(
+        GenericDialect {},
+        r#"SELECT id, first_name, SUM(id) as total_sum FROM person GROUP BY id, first_name ORDER BY total_sum ASC, first_name DESC, id, first_name LIMIT 10"#,
+        r#"SELECT person.id, person.first_name, sum(person.id) AS total_sum FROM person GROUP BY person.id, person.first_name ORDER BY total_sum ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#,
+    );
+
+    // order by scalar function from projection
+    sql_round_trip(
+        GenericDialect {},
+        r#"SELECT id, first_name, substr(first_name,0,5) FROM person ORDER BY id, substr(first_name,0,5)"#,
+        r#"SELECT person.id, person.first_name, substr(person.first_name, 0, 5) FROM person ORDER BY person.id ASC NULLS LAST, substr(person.first_name, 0, 5) ASC NULLS LAST"#,
+    );
+}
+
+#[test]
+fn test_aggregation_to_sql() {
+    sql_round_trip(
+        GenericDialect {},
+        r#"SELECT id, first_name,
+        SUM(id) AS total_sum,
+        SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
+        MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total,
+        rank() OVER (PARTITION BY grouping(id) + grouping(age), CASE WHEN grouping(age) = 0 THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_1,
+        rank() OVER (PARTITION BY grouping(age) + grouping(id), CASE WHEN (CAST(grouping(age) AS BIGINT) = 0) THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_2
+        FROM person
+        GROUP BY id, first_name;"#,
+        r#"SELECT person.id, person.first_name,
+sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN '5' PRECEDING AND '2' FOLLOWING) AS moving_sum,
+max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total,
+rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1,
+rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2
+FROM person
+GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(),
+    );
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to