devinjdangelo commented on code in PR #9623:
URL: https://github.com/apache/arrow-datafusion/pull/9623#discussion_r1527513778
##########
datafusion/sql/src/unparser/expr.rs:
##########
@@ -151,6 +151,36 @@ impl Unparser<'_> {
order_by: vec![],
}))
}
+ Expr::ScalarSubquery(subq) => {
Review Comment:
I have found that all of the information required to preserve a subquery is
encoded in the subquery expression itself. This means we can ignore Subquery
LogicalPlan nodes entirely without losing anything — at least, this holds for
every example I have come up with so far :smile:
##########
datafusion/sql/src/unparser/utils.rs:
##########
@@ -0,0 +1,84 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::{
+ internal_err,
+ tree_node::{Transformed, TreeNode},
+ Result,
+};
+use datafusion_expr::{Aggregate, Expr, LogicalPlan};
+
+/// Recursively searches the children of [LogicalPlan] to find an Aggregate node, if one exists,
+/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
+/// If an Aggregate node is not found before one of these boundaries, or at all before
+/// reaching the end of the tree, None is returned.
+///
+/// Only the input chain below `plan` is inspected; `plan` itself is never returned as the
+/// Aggregate.
+///
+/// * `plan` - the node whose single input (and that input's descendants) is searched.
+/// * `already_projected` - whether a Projection has already been passed within the current
+///   SELECT. Encountering another Projection after that is treated as entering a nested
+///   SELECT, so the search stops. NOTE(review): presumably the caller passes `true` when
+///   `plan` is itself a Projection — confirm at the call site.
+pub(crate) fn find_agg_node_within_select(
+ plan: &LogicalPlan,
+ already_projected: bool,
+) -> Option<&Aggregate> {
+ // Note that none of the nodes that have a corresponding agg node can have more
+ // than 1 input node. E.g. Projection / Filter always have 1 input node.
+ let input = plan.inputs();
+ // A node with several inputs (e.g. a Join) ends the search; a node with no
+ // inputs (a leaf) also yields None, via the `?` on `first()`.
+ let input = if input.len() > 1 {
+ return None;
+ } else {
+ input.first()?
+ };
+ if let LogicalPlan::Aggregate(agg) = input {
+ Some(agg)
+ } else if let LogicalPlan::TableScan(_) = input {
+ // Reached the scan at the bottom of this SELECT without finding an Aggregate.
+ None
+ } else if let LogicalPlan::Projection(_) = input {
+ // A second Projection marks the boundary of a nested subquery.
+ if already_projected {
+ None
+ } else {
+ find_agg_node_within_select(input, true)
+ }
+ } else {
+ // Any other single-input node (e.g. Filter, Sort, Limit) is transparent here.
+ find_agg_node_within_select(input, already_projected)
+ }
+}
+
+/// Recursively identify all Column expressions and transform them into the
appropriate
+/// aggregate expression contained in agg.
+///
+/// For example, if expr contains the column expr "COUNT(*)" it will be
transformed
+/// into an actual aggregate expression COUNT(*) as identified in the
aggregate node.
+pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) ->
Result<Expr> {
+ expr.clone()
+ .transform(&|sub_expr| {
Review Comment:
Here we are borrowing the tree-rewriting machinery of logical planning /
optimization (`TreeNode::transform`) to unwind logical planning. :laughing:
##########
datafusion/sql/src/unparser/expr.rs:
##########
@@ -169,7 +199,7 @@ impl Unparser<'_> {
pub(super) fn new_ident(&self, str: String) -> ast::Ident {
ast::Ident {
value: str,
- quote_style: self.dialect.identifier_quote_style(),
+ quote_style:
Some(self.dialect.identifier_quote_style().unwrap_or('"')),
Review Comment:
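Identifiers are now always quoted, using the dialect's identifier quote style
and falling back to `"` when the dialect does not specify one. This makes the
unparser work for column names containing spaces or other unusual characters
without requiring any special handling.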
##########
datafusion/sql/tests/sql_integration.rs:
##########
@@ -4530,79 +4552,72 @@ fn roundtrip_expr() {
}
#[test]
-fn roundtrip_statement() {
- let tests: Vec<(&str, &str)> = vec![
- (
+fn roundtrip_statement() -> Result<()> {
+ let tests: Vec<&str> = vec![
"select ta.j1_id from j1 ta;",
- r#"SELECT ta.j1_id FROM j1 AS ta"#,
- ),
- (
"select ta.j1_id from j1 ta order by ta.j1_id;",
- r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS
LAST"#,
- ),
- (
"select * from j1 ta order by ta.j1_id, ta.j1_string desc;",
- r#"SELECT ta.j1_id, ta.j1_string FROM j1 AS ta ORDER BY ta.j1_id
ASC NULLS LAST, ta.j1_string DESC NULLS FIRST"#,
- ),
- (
"select * from j1 limit 10;",
- r#"SELECT j1.j1_id, j1.j1_string FROM j1 LIMIT 10"#,
- ),
- (
"select ta.j1_id from j1 ta where ta.j1_id > 1;",
- r#"SELECT ta.j1_id FROM j1 AS ta WHERE (ta.j1_id > 1)"#,
- ),
- (
"select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id
= tb.j2_id);",
- r#"SELECT ta.j1_id, tb.j2_string FROM j1 AS ta JOIN j2 AS tb ON
(ta.j1_id = tb.j2_id)"#,
- ),
- (
"select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb
on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);",
- r#"SELECT ta.j1_id, tb.j2_string, tc.j3_string FROM j1 AS ta JOIN
j2 AS tb ON (ta.j1_id = tb.j2_id) JOIN j3 AS tc ON (ta.j1_id = tc.j3_id)"#,
- ),
- (
"select * from (select id, first_name from person)",
- "SELECT person.id, person.first_name FROM (SELECT person.id,
person.first_name FROM person)"
- ),
- (
"select * from (select id, first_name from (select * from
person))",
- "SELECT person.id, person.first_name FROM (SELECT person.id,
person.first_name FROM (SELECT person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀 FROM
person))"
- ),
- (
"select id, count(*) as cnt from (select id from person) group by
id",
- "SELECT person.id, COUNT(*) AS cnt FROM (SELECT person.id FROM
person) GROUP BY person.id"
- ),
- (
+ "select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from
(select (id-1) as id from person) group by id",
+ r#"select "First Name" from person_quoted_cols"#,
+ r#"select id, count("First Name") as cnt from (select id, "First
Name" from person_quoted_cols) group by id"#,
"select id, count(*) as cnt from (select p1.id as id from person
p1 inner join person p2 on p1.id=p2.id) group by id",
- "SELECT p1.id, COUNT(*) AS cnt FROM (SELECT p1.id FROM person AS
p1 JOIN person AS p2 ON (p1.id = p2.id)) GROUP BY p1.id"
- ),
- (
"select id, count(*), first_name from person group by first_name,
id",
- "SELECT person.id, COUNT(*), person.first_name FROM person GROUP
BY person.first_name, person.id"
- ),
- ];
+ "select id, sum(age), first_name from person group by first_name,
id",
+ "select id, count(*), first_name
+ from person
+ where id!=3 and first_name=='test'
+ group by first_name, id
+ having count(*)>5 and count(*)<10
+ order by count(*)",
+ r#"select id, count("First Name") as count_first_name, "Last Name"
+ from person_quoted_cols
+ where id!=3 and "First Name"=='test'
+ group by "Last Name", id
+ having count_first_name>5 and count_first_name<10
+ order by count_first_name, "Last Name""#,
+ r#"select p.id, count("First Name") as count_first_name,
+ "Last Name", sum(qp.id/p.id - (select sum(id) from
person_quoted_cols) ) / (select count(*) from person)
+ from (select id, "First Name", "Last Name" from
person_quoted_cols) qp
+ inner join (select * from person) p
+ on p.id = qp.id
+ where p.id!=3 and "First Name"=='test' and qp.id in
+ (select id from (select id, count(*) from person group by id
having count(*) > 0))
+ group by "Last Name", p.id
+ having count_first_name>5 and count_first_name<10
+ order by count_first_name, "Last Name""#,
+ ];
- let roundtrip = |sql: &str| -> Result<String> {
+ for query in tests {
let dialect = GenericDialect {};
- let statement =
Parser::new(&dialect).try_with_sql(sql)?.parse_statement()?;
+ let statement = Parser::new(&dialect)
+ .try_with_sql(query)?
+ .parse_statement()?;
let context = MockContextProvider::default();
let sql_to_rel = SqlToRel::new(&context);
- let plan = sql_to_rel.sql_statement_to_plan(statement)?;
+ let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
- println!("{}", plan.display_indent());
+ let roundtrip_statement = plan_to_sql(&plan)?;
- let ast = plan_to_sql(&plan)?;
+ let actual = format!("{}", &roundtrip_statement);
+ println!("roundtrip sql: {actual}");
+ println!("plan {}", plan.display_indent());
- println!("{ast}");
+ let plan_roundtrip = sql_to_rel
+ .sql_statement_to_plan(roundtrip_statement.clone())
+ .unwrap();
- Ok(format!("{}", ast))
- };
-
- for (query, expected) in tests {
- let actual = roundtrip(query).unwrap();
- assert_eq!(actual, expected);
+ assert_eq!(plan, plan_roundtrip);
Review Comment:
As mentioned in the description, the key condition tested is that the
logical plan derived from the original AST and the logical plan derived from
the recovered (unparsed) AST are identical.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]