This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 86030a1ff7 fix: invalid sqls when unparsing derived table with columns 
contains calculations, limit/order/distinct (#11756)
86030a1ff7 is described below

commit 86030a1ff713cc9709a81a1e9df82d4d13b8818d
Author: yfu <[email protected]>
AuthorDate: Fri Aug 9 06:39:19 2024 +1000

    fix: invalid sqls when unparsing derived table with columns contains 
calculations, limit/order/distinct (#11756)
    
    * Fix unparsing of derived tables whose columns include calculations, 
limit/order/distinct (#24)
    
    * compare `format!` output to make sure the two levels of projections match
    
    * add method to find inner projection that could be nested under 
limit/order/distinct
    
    * use format! for matching in unparser sort optimization too
    
    * refactor
    
    * use to_string and also put comments in
    
    * clippy
    
    * fix unparsing of derived tables that contain a cast (#25)
    
    * fix unparsing of derived tables that contain a cast
    
    * remove dbg
---
 datafusion/sql/src/unparser/plan.rs       |  67 ++----------------
 datafusion/sql/src/unparser/rewrite.rs    | 109 ++++++++++++++++++++++++++++--
 datafusion/sql/tests/cases/plan_to_sql.rs |  32 +++++++++
 3 files changed, 139 insertions(+), 69 deletions(-)

diff --git a/datafusion/sql/src/unparser/plan.rs 
b/datafusion/sql/src/unparser/plan.rs
index e08f25d3c2..277efd5fe7 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -30,8 +30,10 @@ use super::{
         BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder,
         SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
     },
-    rewrite::normalize_union_schema,
-    rewrite::rewrite_plan_for_sort_on_non_projected_fields,
+    rewrite::{
+        normalize_union_schema, rewrite_plan_for_sort_on_non_projected_fields,
+        subquery_alias_inner_query_and_columns,
+    },
     utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
     Unparser,
 };
@@ -687,67 +689,6 @@ impl Unparser<'_> {
     }
 }
 
-// This logic is to work out the columns and inner query for SubqueryAlias 
plan for both types of
-// subquery
-// - `(SELECT column_a as a from table) AS A`
-// - `(SELECT column_a from table) AS A (a)`
-//
-// A roundtrip example for table alias with columns
-//
-// query: SELECT id FROM (SELECT j1_id from j1) AS c (id)
-//
-// LogicPlan:
-// Projection: c.id
-//   SubqueryAlias: c
-//     Projection: j1.j1_id AS id
-//       Projection: j1.j1_id
-//         TableScan: j1
-//
-// Before introducing this logic, the unparsed query would be `SELECT c.id 
FROM (SELECT j1.j1_id AS
-// id FROM (SELECT j1.j1_id FROM j1)) AS c`.
-// The query is invalid as `j1.j1_id` is not a valid identifier in the derived 
table
-// `(SELECT j1.j1_id FROM j1)`
-//
-// With this logic, the unparsed query will be:
-// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)`
-//
-// Caveat: this won't handle the case like `select * from (select 1, 2) AS a 
(b, c)`
-// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal 
and
-// Column in the Projections. Once the parser side is fixed, this logic should 
work
-fn subquery_alias_inner_query_and_columns(
-    subquery_alias: &datafusion_expr::SubqueryAlias,
-) -> (&LogicalPlan, Vec<Ident>) {
-    let plan: &LogicalPlan = subquery_alias.input.as_ref();
-
-    let LogicalPlan::Projection(outer_projections) = plan else {
-        return (plan, vec![]);
-    };
-
-    // check if it's projection inside projection
-    let LogicalPlan::Projection(inner_projection) = 
outer_projections.input.as_ref()
-    else {
-        return (plan, vec![]);
-    };
-
-    let mut columns: Vec<Ident> = vec![];
-    // check if the inner projection and outer projection have a matching 
pattern like
-    //     Projection: j1.j1_id AS id
-    //       Projection: j1.j1_id
-    for (i, inner_expr) in inner_projection.expr.iter().enumerate() {
-        let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else {
-            return (plan, vec![]);
-        };
-
-        if outer_alias.expr.as_ref() != inner_expr {
-            return (plan, vec![]);
-        };
-
-        columns.push(outer_alias.name.as_str().into());
-    }
-
-    (outer_projections.input.as_ref(), columns)
-}
-
 impl From<BuilderError> for DataFusionError {
     fn from(e: BuilderError) -> Self {
         DataFusionError::External(Box::new(e))
diff --git a/datafusion/sql/src/unparser/rewrite.rs 
b/datafusion/sql/src/unparser/rewrite.rs
index fba95ad48f..f6725485f9 100644
--- a/datafusion/sql/src/unparser/rewrite.rs
+++ b/datafusion/sql/src/unparser/rewrite.rs
@@ -25,6 +25,7 @@ use datafusion_common::{
     Result,
 };
 use datafusion_expr::{Expr, LogicalPlan, Projection, Sort};
+use sqlparser::ast::Ident;
 
 /// Normalize the schema of a union plan to remove qualifiers from the schema 
fields and sort expressions.
 ///
@@ -137,14 +138,25 @@ pub(super) fn 
rewrite_plan_for_sort_on_non_projected_fields(
     let inner_exprs = inner_p
         .expr
         .iter()
-        .map(|f| {
-            if let Expr::Alias(alias) = f {
+        .enumerate()
+        .map(|(i, f)| match f {
+            Expr::Alias(alias) => {
                 let a = Expr::Column(alias.name.clone().into());
                 map.insert(a.clone(), f.clone());
                 a
-            } else {
+            }
+            Expr::Column(_) => {
+                map.insert(
+                    Expr::Column(inner_p.schema.field(i).name().into()),
+                    f.clone(),
+                );
                 f.clone()
             }
+            _ => {
+                let a = Expr::Column(inner_p.schema.field(i).name().into());
+                map.insert(a.clone(), f.clone());
+                a
+            }
         })
         .collect::<Vec<_>>();
 
@@ -155,9 +167,17 @@ pub(super) fn 
rewrite_plan_for_sort_on_non_projected_fields(
         }
     }
 
-    if collects.iter().collect::<HashSet<_>>()
-        == inner_exprs.iter().collect::<HashSet<_>>()
-    {
+    // Compare outer collects Expr::to_string with inner collected transformed 
values
+    // alias -> alias column
+    // column -> remain
+    // others, extract schema field name
+    let outer_collects = 
collects.iter().map(Expr::to_string).collect::<HashSet<_>>();
+    let inner_collects = inner_exprs
+        .iter()
+        .map(Expr::to_string)
+        .collect::<HashSet<_>>();
+
+    if outer_collects == inner_collects {
         let mut sort = sort.clone();
         let mut inner_p = inner_p.clone();
 
@@ -175,3 +195,80 @@ pub(super) fn 
rewrite_plan_for_sort_on_non_projected_fields(
         None
     }
 }
+
+// This logic is to work out the columns and inner query for SubqueryAlias 
plan for both types of
+// subquery
+// - `(SELECT column_a as a from table) AS A`
+// - `(SELECT column_a from table) AS A (a)`
+//
+// A roundtrip example for table alias with columns
+//
+// query: SELECT id FROM (SELECT j1_id from j1) AS c (id)
+//
+// LogicPlan:
+// Projection: c.id
+//   SubqueryAlias: c
+//     Projection: j1.j1_id AS id
+//       Projection: j1.j1_id
+//         TableScan: j1
+//
+// Before introducing this logic, the unparsed query would be `SELECT c.id 
FROM (SELECT j1.j1_id AS
+// id FROM (SELECT j1.j1_id FROM j1)) AS c`.
+// The query is invalid as `j1.j1_id` is not a valid identifier in the derived 
table
+// `(SELECT j1.j1_id FROM j1)`
+//
+// With this logic, the unparsed query will be:
+// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)`
+//
+// Caveat: this won't handle the case like `select * from (select 1, 2) AS a 
(b, c)`
+// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal 
and
+// Column in the Projections. Once the parser side is fixed, this logic should 
work
+pub(super) fn subquery_alias_inner_query_and_columns(
+    subquery_alias: &datafusion_expr::SubqueryAlias,
+) -> (&LogicalPlan, Vec<Ident>) {
+    let plan: &LogicalPlan = subquery_alias.input.as_ref();
+
+    let LogicalPlan::Projection(outer_projections) = plan else {
+        return (plan, vec![]);
+    };
+
+    // check if it's projection inside projection
+    let Some(inner_projection) = 
find_projection(outer_projections.input.as_ref()) else {
+        return (plan, vec![]);
+    };
+
+    let mut columns: Vec<Ident> = vec![];
+    // check if the inner projection and outer projection have a matching 
pattern like
+    //     Projection: j1.j1_id AS id
+    //       Projection: j1.j1_id
+    for (i, inner_expr) in inner_projection.expr.iter().enumerate() {
+        let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else {
+            return (plan, vec![]);
+        };
+
+        // inner projection schema fields store the projection name which is 
used in outer
+        // projection expr
+        let inner_expr_string = match inner_expr {
+            Expr::Column(_) => inner_expr.to_string(),
+            _ => inner_projection.schema.field(i).name().clone(),
+        };
+
+        if outer_alias.expr.to_string() != inner_expr_string {
+            return (plan, vec![]);
+        };
+
+        columns.push(outer_alias.name.as_str().into());
+    }
+
+    (outer_projections.input.as_ref(), columns)
+}
+
+fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> {
+    match logical_plan {
+        LogicalPlan::Projection(p) => Some(p),
+        LogicalPlan::Limit(p) => find_projection(p.input.as_ref()),
+        LogicalPlan::Distinct(p) => find_projection(p.input().as_ref()),
+        LogicalPlan::Sort(p) => find_projection(p.input.as_ref()),
+        _ => None,
+    }
+}
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs 
b/datafusion/sql/tests/cases/plan_to_sql.rs
index 2ac3034873..9bbdbe8dbf 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -373,6 +373,38 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
             parser_dialect: Box::new(GenericDialect {}),
             unparser_dialect: Box::new(UnparserDefaultDialect {}),
         },
+        // Test query that has calculation in derived table with columns
+        TestStatementWithDialect {
+            sql: "SELECT id FROM (SELECT j1_id + 1 * 3 from j1) AS c (id)",
+            expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + (1 * 3)) FROM 
j1) AS c (id)"#,
+            parser_dialect: Box::new(GenericDialect {}),
+            unparser_dialect: Box::new(UnparserDefaultDialect {}),
+        },
+        // Test query that has limit/distinct/order in derived table with 
columns
+        TestStatementWithDialect {
+            sql: "SELECT id FROM (SELECT distinct (j1_id + 1 * 3) FROM j1 
LIMIT 1) AS c (id)",
+            expected: r#"SELECT c.id FROM (SELECT DISTINCT (j1.j1_id + (1 * 
3)) FROM j1 LIMIT 1) AS c (id)"#,
+            parser_dialect: Box::new(GenericDialect {}),
+            unparser_dialect: Box::new(UnparserDefaultDialect {}),
+        },
+        TestStatementWithDialect {
+            sql: "SELECT id FROM (SELECT j1_id + 1 FROM j1 ORDER BY j1_id DESC 
LIMIT 1) AS c (id)",
+            expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + 1) FROM j1 ORDER 
BY j1.j1_id DESC NULLS FIRST LIMIT 1) AS c (id)"#,
+            parser_dialect: Box::new(GenericDialect {}),
+            unparser_dialect: Box::new(UnparserDefaultDialect {}),
+        },
+        TestStatementWithDialect {
+            sql: "SELECT id FROM (SELECT CAST((CAST(j1_id as BIGINT) + 1) as 
int) * 10 FROM j1 LIMIT 1) AS c (id)",
+            expected: r#"SELECT c.id FROM (SELECT (CAST((CAST(j1.j1_id AS 
BIGINT) + 1) AS INTEGER) * 10) FROM j1 LIMIT 1) AS c (id)"#,
+            parser_dialect: Box::new(GenericDialect {}),
+            unparser_dialect: Box::new(UnparserDefaultDialect {}),
+        },
+        TestStatementWithDialect {
+            sql: "SELECT id FROM (SELECT CAST(j1_id as BIGINT) + 1 FROM j1 
ORDER BY j1_id LIMIT 1) AS c (id)",
+            expected: r#"SELECT c.id FROM (SELECT (CAST(j1.j1_id AS BIGINT) + 
1) FROM j1 ORDER BY j1.j1_id ASC NULLS LAST LIMIT 1) AS c (id)"#,
+            parser_dialect: Box::new(GenericDialect {}),
+            unparser_dialect: Box::new(UnparserDefaultDialect {}),
+        }
     ];
 
     for query in tests {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to