This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 71a99b8462 minor: consolidate unparser integration tests (#10736)
71a99b8462 is described below

commit 71a99b84627de49033037021cfeea1f2cd29db84
Author: Devin D'Angelo <[email protected]>
AuthorDate: Sat Jun 1 16:31:46 2024 -0400

    minor: consolidate unparser integration tests (#10736)
    
    * consolidate unparser integration tests
    
    * add license to new files
    
    * suppress dead code warnings
    
    * run as one integration test binary
    
    * add license
---
 datafusion/sql/tests/cases/mod.rs         |  18 ++
 datafusion/sql/tests/cases/plan_to_sql.rs | 290 ++++++++++++++++++
 datafusion/sql/tests/common/mod.rs        | 227 ++++++++++++++
 datafusion/sql/tests/sql_integration.rs   | 477 +-----------------------------
 4 files changed, 543 insertions(+), 469 deletions(-)

diff --git a/datafusion/sql/tests/cases/mod.rs b/datafusion/sql/tests/cases/mod.rs
new file mode 100644
index 0000000000..fc4c59cc88
--- /dev/null
+++ b/datafusion/sql/tests/cases/mod.rs
@@ -0,0 +1,18 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod plan_to_sql;
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs
new file mode 100644
index 0000000000..1bf441351a
--- /dev/null
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -0,0 +1,290 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::vec;
+
+use arrow_schema::*;
+use datafusion_common::{DFSchema, Result, TableReference};
+use datafusion_expr::{col, table_scan};
+use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
+use datafusion_sql::unparser::dialect::{
+    DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect,
+    MySqlDialect as UnparserMySqlDialect,
+};
+use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser};
+
+use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
+use sqlparser::parser::Parser;
+
+use crate::common::MockContextProvider;
+
+#[test]
+fn roundtrip_expr() {
+    let tests: Vec<(TableReference, &str, &str)> = vec![
+        (TableReference::bare("person"), "age > 35", r#"(age > 35)"#),
+        (
+            TableReference::bare("person"),
+            "id = '10'",
+            r#"(id = '10')"#,
+        ),
+        (
+            TableReference::bare("person"),
+            "CAST(id AS VARCHAR)",
+            r#"CAST(id AS VARCHAR)"#,
+        ),
+        (
+            TableReference::bare("person"),
+            "SUM((age * 2))",
+            r#"SUM((age * 2))"#,
+        ),
+    ];
+
+    let roundtrip = |table, sql: &str| -> Result<String> {
+        let dialect = GenericDialect {};
+        let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?;
+
+        let context = MockContextProvider::default();
+        let schema = context.get_table_source(table)?.schema();
+        let df_schema = DFSchema::try_from(schema.as_ref().clone())?;
+        let sql_to_rel = SqlToRel::new(&context);
+        let expr =
+            sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?;
+
+        let ast = expr_to_sql(&expr)?;
+
+        Ok(format!("{}", ast))
+    };
+
+    for (table, query, expected) in tests {
+        let actual = roundtrip(table, query).unwrap();
+        assert_eq!(actual, expected);
+    }
+}
+
+#[test]
+fn roundtrip_statement() -> Result<()> {
+    let tests: Vec<&str> = vec![
+            "select ta.j1_id from j1 ta;",
+            "select ta.j1_id from j1 ta order by ta.j1_id;",
+            "select * from j1 ta order by ta.j1_id, ta.j1_string desc;",
+            "select * from j1 limit 10;",
+            "select ta.j1_id from j1 ta where ta.j1_id > 1;",
+            "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id);",
+            "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);",
+            "select * from (select id, first_name from person)",
+            "select * from (select id, first_name from (select * from person))",
+            "select id, count(*) as cnt from (select id from person) group by id",
+            "select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from (select (id-1) as id from person) group by id",
+            "select CAST(id/2 as VARCHAR) NOT LIKE 'foo*' from person where NOT EXISTS (select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id))",
+            r#"select "First Name" from person_quoted_cols"#,
+            "select DISTINCT id FROM person",
+            "select DISTINCT on (id) id, first_name from person",
+            "select DISTINCT on (id) id, first_name from person order by id",
+            r#"select id, count("First Name") as cnt from (select id, "First Name" from person_quoted_cols) group by id"#,
+            "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id",
+            "select id, count(*), first_name from person group by first_name, id",
+            "select id, sum(age), first_name from person group by first_name, id",
+            "select id, count(*), first_name 
+            from person 
+            where id!=3 and first_name=='test'
+            group by first_name, id 
+            having count(*)>5 and count(*)<10
+            order by count(*)",
+            r#"select id, count("First Name") as count_first_name, "Last Name" 
+            from person_quoted_cols
+            where id!=3 and "First Name"=='test'
+            group by "Last Name", id 
+            having count_first_name>5 and count_first_name<10
+            order by count_first_name, "Last Name""#,
+            r#"select p.id, count("First Name") as count_first_name,
+            "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) 
+            from (select id, "First Name", "Last Name" from person_quoted_cols) qp
+            inner join (select * from person) p
+            on p.id = qp.id
+            where p.id!=3 and "First Name"=='test' and qp.id in 
+            (select id from (select id, count(*) from person group by id having count(*) > 0))
+            group by "Last Name", p.id 
+            having count_first_name>5 and count_first_name<10
+            order by count_first_name, "Last Name""#,
+            r#"SELECT j1_string as string FROM j1
+            UNION ALL
+            SELECT j2_string as string FROM j2"#,
+            r#"SELECT j1_string as string FROM j1
+            UNION ALL
+            SELECT j2_string as string FROM j2
+            ORDER BY string DESC
+            LIMIT 10"#
+        ];
+
+    // For each test sql string, we transform as follows:
+    // sql -> ast::Statement (s1) -> LogicalPlan (p1) -> ast::Statement (s2) -> LogicalPlan (p2)
+    // We test not that s1==s2, but rather p1==p2. This ensures that unparser preserves the logical
+    // query information of the original sql string and disreguards other differences in syntax or
+    // quoting.
+    for query in tests {
+        let dialect = GenericDialect {};
+        let statement = Parser::new(&dialect)
+            .try_with_sql(query)?
+            .parse_statement()?;
+
+        let context = MockContextProvider::default();
+        let sql_to_rel = SqlToRel::new(&context);
+        let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
+
+        let roundtrip_statement = plan_to_sql(&plan)?;
+
+        let actual = format!("{}", &roundtrip_statement);
+        println!("roundtrip sql: {actual}");
+        println!("plan {}", plan.display_indent());
+
+        let plan_roundtrip = sql_to_rel
+            .sql_statement_to_plan(roundtrip_statement.clone())
+            .unwrap();
+
+        assert_eq!(plan, plan_roundtrip);
+    }
+
+    Ok(())
+}
+
+#[test]
+fn roundtrip_crossjoin() -> Result<()> {
+    let query = "select j1.j1_id, j2.j2_string from j1, j2";
+
+    let dialect = GenericDialect {};
+    let statement = Parser::new(&dialect)
+        .try_with_sql(query)?
+        .parse_statement()?;
+
+    let context = MockContextProvider::default();
+    let sql_to_rel = SqlToRel::new(&context);
+    let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
+
+    let roundtrip_statement = plan_to_sql(&plan)?;
+
+    let actual = format!("{}", &roundtrip_statement);
+    println!("roundtrip sql: {actual}");
+    println!("plan {}", plan.display_indent());
+
+    let plan_roundtrip = sql_to_rel
+        .sql_statement_to_plan(roundtrip_statement.clone())
+        .unwrap();
+
+    let expected = "Projection: j1.j1_id, j2.j2_string\
+        \n  Inner Join:  Filter: Boolean(true)\
+        \n    TableScan: j1\
+        \n    TableScan: j2";
+
+    assert_eq!(format!("{plan_roundtrip:?}"), expected);
+
+    Ok(())
+}
+
+#[test]
+fn roundtrip_statement_with_dialect() -> Result<()> {
+    struct TestStatementWithDialect {
+        sql: &'static str,
+        expected: &'static str,
+        parser_dialect: Box<dyn Dialect>,
+        unparser_dialect: Box<dyn UnparserDialect>,
+    }
+    let tests: Vec<TestStatementWithDialect> = vec![
+        TestStatementWithDialect {
+            sql: "select ta.j1_id from j1 ta order by j1_id limit 10;",
+            expected:
+                "SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id` ASC LIMIT 10",
+            parser_dialect: Box::new(MySqlDialect {}),
+            unparser_dialect: Box::new(UnparserMySqlDialect {}),
+        },
+        TestStatementWithDialect {
+            sql: "select ta.j1_id from j1 ta order by j1_id limit 10;",
+            expected: r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST LIMIT 10"#,
+            parser_dialect: Box::new(GenericDialect {}),
+            unparser_dialect: Box::new(UnparserDefaultDialect {}),
+        },
+    ];
+
+    for query in tests {
+        let statement = Parser::new(&*query.parser_dialect)
+            .try_with_sql(query.sql)?
+            .parse_statement()?;
+
+        let context = MockContextProvider::default();
+        let sql_to_rel = SqlToRel::new(&context);
+        let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
+
+        let unparser = Unparser::new(&*query.unparser_dialect);
+        let roundtrip_statement = unparser.plan_to_sql(&plan)?;
+
+        let actual = format!("{}", &roundtrip_statement);
+        println!("roundtrip sql: {actual}");
+        println!("plan {}", plan.display_indent());
+
+        assert_eq!(query.expected, actual);
+    }
+
+    Ok(())
+}
+
+#[test]
+fn test_unnest_logical_plan() -> Result<()> {
+    let query = "select unnest(struct_col), unnest(array_col), struct_col, array_col from unnest_table";
+
+    let dialect = GenericDialect {};
+    let statement = Parser::new(&dialect)
+        .try_with_sql(query)?
+        .parse_statement()?;
+
+    let context = MockContextProvider::default();
+    let sql_to_rel = SqlToRel::new(&context);
+    let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
+
+    let expected = "Projection: unnest(unnest_table.struct_col).field1, unnest(unnest_table.struct_col).field2, unnest(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col\
+        \n  Unnest: lists[unnest(unnest_table.array_col)] structs[unnest(unnest_table.struct_col)]\
+        \n    Projection: unnest_table.struct_col AS unnest(unnest_table.struct_col), unnest_table.array_col AS unnest(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col\
+        \n      TableScan: unnest_table";
+
+    assert_eq!(format!("{plan:?}"), expected);
+
+    Ok(())
+}
+
+#[test]
+fn test_table_references_in_plan_to_sql() {
+    fn test(table_name: &str, expected_sql: &str) {
+        let schema = Schema::new(vec![
+            Field::new("id", DataType::Utf8, false),
+            Field::new("value", DataType::Utf8, false),
+        ]);
+        let plan = table_scan(Some(table_name), &schema, None)
+            .unwrap()
+            .project(vec![col("id"), col("value")])
+            .unwrap()
+            .build()
+            .unwrap();
+        let sql = plan_to_sql(&plan).unwrap();
+
+        assert_eq!(format!("{}", sql), expected_sql)
+    }
+
+    test("catalog.schema.table", "SELECT catalog.\"schema\".\"table\".id, catalog.\"schema\".\"table\".\"value\" FROM catalog.\"schema\".\"table\"");
+    test("schema.table", "SELECT \"schema\".\"table\".id, \"schema\".\"table\".\"value\" FROM \"schema\".\"table\"");
+    test(
+        "table",
+        "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"",
+    );
+}
diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs
new file mode 100644
index 0000000000..79de4bc826
--- /dev/null
+++ b/datafusion/sql/tests/common/mod.rs
@@ -0,0 +1,227 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[cfg(test)]
+use std::collections::HashMap;
+use std::{sync::Arc, vec};
+
+use arrow_schema::*;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::{plan_err, Result, TableReference};
+use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
+use datafusion_sql::planner::ContextProvider;
+
+#[derive(Default)]
+pub(crate) struct MockContextProvider {
+    options: ConfigOptions,
+    udfs: HashMap<String, Arc<ScalarUDF>>,
+    udafs: HashMap<String, Arc<AggregateUDF>>,
+}
+
+impl MockContextProvider {
+    // Surpressing dead code warning, as this is used in integration test crates
+    #[allow(dead_code)]
+    pub(crate) fn options_mut(&mut self) -> &mut ConfigOptions {
+        &mut self.options
+    }
+
+    #[allow(dead_code)]
+    pub(crate) fn with_udf(mut self, udf: ScalarUDF) -> Self {
+        self.udfs.insert(udf.name().to_string(), Arc::new(udf));
+        self
+    }
+}
+
+impl ContextProvider for MockContextProvider {
+    fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
+        let schema = match name.table() {
+            "test" => Ok(Schema::new(vec![
+                Field::new("t_date32", DataType::Date32, false),
+                Field::new("t_date64", DataType::Date64, false),
+            ])),
+            "j1" => Ok(Schema::new(vec![
+                Field::new("j1_id", DataType::Int32, false),
+                Field::new("j1_string", DataType::Utf8, false),
+            ])),
+            "j2" => Ok(Schema::new(vec![
+                Field::new("j2_id", DataType::Int32, false),
+                Field::new("j2_string", DataType::Utf8, false),
+            ])),
+            "j3" => Ok(Schema::new(vec![
+                Field::new("j3_id", DataType::Int32, false),
+                Field::new("j3_string", DataType::Utf8, false),
+            ])),
+            "test_decimal" => Ok(Schema::new(vec![
+                Field::new("id", DataType::Int32, false),
+                Field::new("price", DataType::Decimal128(10, 2), false),
+            ])),
+            "person" => Ok(Schema::new(vec![
+                Field::new("id", DataType::UInt32, false),
+                Field::new("first_name", DataType::Utf8, false),
+                Field::new("last_name", DataType::Utf8, false),
+                Field::new("age", DataType::Int32, false),
+                Field::new("state", DataType::Utf8, false),
+                Field::new("salary", DataType::Float64, false),
+                Field::new(
+                    "birth_date",
+                    DataType::Timestamp(TimeUnit::Nanosecond, None),
+                    false,
+                ),
+                Field::new("😀", DataType::Int32, false),
+            ])),
+            "person_quoted_cols" => Ok(Schema::new(vec![
+                Field::new("id", DataType::UInt32, false),
+                Field::new("First Name", DataType::Utf8, false),
+                Field::new("Last Name", DataType::Utf8, false),
+                Field::new("Age", DataType::Int32, false),
+                Field::new("State", DataType::Utf8, false),
+                Field::new("Salary", DataType::Float64, false),
+                Field::new(
+                    "Birth Date",
+                    DataType::Timestamp(TimeUnit::Nanosecond, None),
+                    false,
+                ),
+                Field::new("😀", DataType::Int32, false),
+            ])),
+            "orders" => Ok(Schema::new(vec![
+                Field::new("order_id", DataType::UInt32, false),
+                Field::new("customer_id", DataType::UInt32, false),
+                Field::new("o_item_id", DataType::Utf8, false),
+                Field::new("qty", DataType::Int32, false),
+                Field::new("price", DataType::Float64, false),
+                Field::new("delivered", DataType::Boolean, false),
+            ])),
+            "array" => Ok(Schema::new(vec![
+                Field::new(
+                    "left",
+                    DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
+                    false,
+                ),
+                Field::new(
+                    "right",
+                    DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
+                    false,
+                ),
+            ])),
+            "lineitem" => Ok(Schema::new(vec![
+                Field::new("l_item_id", DataType::UInt32, false),
+                Field::new("l_description", DataType::Utf8, false),
+                Field::new("price", DataType::Float64, false),
+            ])),
+            "aggregate_test_100" => Ok(Schema::new(vec![
+                Field::new("c1", DataType::Utf8, false),
+                Field::new("c2", DataType::UInt32, false),
+                Field::new("c3", DataType::Int8, false),
+                Field::new("c4", DataType::Int16, false),
+                Field::new("c5", DataType::Int32, false),
+                Field::new("c6", DataType::Int64, false),
+                Field::new("c7", DataType::UInt8, false),
+                Field::new("c8", DataType::UInt16, false),
+                Field::new("c9", DataType::UInt32, false),
+                Field::new("c10", DataType::UInt64, false),
+                Field::new("c11", DataType::Float32, false),
+                Field::new("c12", DataType::Float64, false),
+                Field::new("c13", DataType::Utf8, false),
+            ])),
+            "UPPERCASE_test" => Ok(Schema::new(vec![
+                Field::new("Id", DataType::UInt32, false),
+                Field::new("lower", DataType::UInt32, false),
+            ])),
+            "unnest_table" => Ok(Schema::new(vec![
+                Field::new(
+                    "array_col",
+                    DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
+                    false,
+                ),
+                Field::new(
+                    "struct_col",
+                    DataType::Struct(Fields::from(vec![
+                        Field::new("field1", DataType::Int64, true),
+                        Field::new("field2", DataType::Utf8, true),
+                    ])),
+                    false,
+                ),
+            ])),
+            _ => plan_err!("No table named: {} found", name.table()),
+        };
+
+        match schema {
+            Ok(t) => Ok(Arc::new(EmptyTable::new(Arc::new(t)))),
+            Err(e) => Err(e),
+        }
+    }
+
+    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
+        self.udfs.get(name).cloned()
+    }
+
+    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
+        self.udafs.get(name).cloned()
+    }
+
+    fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
+        unimplemented!()
+    }
+
+    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
+        None
+    }
+
+    fn options(&self) -> &ConfigOptions {
+        &self.options
+    }
+
+    fn create_cte_work_table(
+        &self,
+        _name: &str,
+        schema: SchemaRef,
+    ) -> Result<Arc<dyn TableSource>> {
+        Ok(Arc::new(EmptyTable::new(schema)))
+    }
+
+    fn udf_names(&self) -> Vec<String> {
+        self.udfs.keys().cloned().collect()
+    }
+
+    fn udaf_names(&self) -> Vec<String> {
+        self.udafs.keys().cloned().collect()
+    }
+
+    fn udwf_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+}
+
+struct EmptyTable {
+    table_schema: SchemaRef,
+}
+
+impl EmptyTable {
+    fn new(table_schema: SchemaRef) -> Self {
+        Self { table_schema }
+    }
+}
+
+impl TableSource for EmptyTable {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.table_schema.clone()
+    }
+}
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index a7224805f3..1f064ea0f5 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -18,35 +18,29 @@
 use std::any::Any;
 #[cfg(test)]
 use std::collections::HashMap;
-use std::{sync::Arc, vec};
+use std::vec;
 
 use arrow_schema::TimeUnit::Nanosecond;
 use arrow_schema::*;
-use datafusion_common::config::ConfigOptions;
+use common::MockContextProvider;
 use datafusion_common::{
-    assert_contains, plan_err, DFSchema, DataFusionError, ParamValues, Result,
-    ScalarValue, TableReference,
+    assert_contains, DataFusionError, ParamValues, Result, ScalarValue,
 };
-use datafusion_expr::{col, table_scan};
 use datafusion_expr::{
     logical_plan::{LogicalPlan, Prepare},
-    AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, TableSource,
-    Volatility, WindowUDF,
+    ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
 };
 use datafusion_functions::{string, unicode};
-use datafusion_sql::unparser::dialect::{
-    DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect,
-    MySqlDialect as UnparserMySqlDialect,
-};
-use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser};
 use datafusion_sql::{
     parser::DFParser,
-    planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel},
+    planner::{ParserOptions, SqlToRel},
 };
 
 use rstest::rstest;
 use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
-use sqlparser::parser::Parser;
+
+mod cases;
+mod common;
 
 #[test]
 fn test_schema_support() {
@@ -2797,184 +2791,6 @@ fn prepare_stmt_replace_params_quick_test(
     plan
 }
 
-#[derive(Default)]
-struct MockContextProvider {
-    options: ConfigOptions,
-    udfs: HashMap<String, Arc<ScalarUDF>>,
-    udafs: HashMap<String, Arc<AggregateUDF>>,
-}
-
-impl MockContextProvider {
-    fn options_mut(&mut self) -> &mut ConfigOptions {
-        &mut self.options
-    }
-
-    fn with_udf(mut self, udf: ScalarUDF) -> Self {
-        self.udfs.insert(udf.name().to_string(), Arc::new(udf));
-        self
-    }
-}
-
-impl ContextProvider for MockContextProvider {
-    fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
-        let schema = match name.table() {
-            "test" => Ok(Schema::new(vec![
-                Field::new("t_date32", DataType::Date32, false),
-                Field::new("t_date64", DataType::Date64, false),
-            ])),
-            "j1" => Ok(Schema::new(vec![
-                Field::new("j1_id", DataType::Int32, false),
-                Field::new("j1_string", DataType::Utf8, false),
-            ])),
-            "j2" => Ok(Schema::new(vec![
-                Field::new("j2_id", DataType::Int32, false),
-                Field::new("j2_string", DataType::Utf8, false),
-            ])),
-            "j3" => Ok(Schema::new(vec![
-                Field::new("j3_id", DataType::Int32, false),
-                Field::new("j3_string", DataType::Utf8, false),
-            ])),
-            "test_decimal" => Ok(Schema::new(vec![
-                Field::new("id", DataType::Int32, false),
-                Field::new("price", DataType::Decimal128(10, 2), false),
-            ])),
-            "person" => Ok(Schema::new(vec![
-                Field::new("id", DataType::UInt32, false),
-                Field::new("first_name", DataType::Utf8, false),
-                Field::new("last_name", DataType::Utf8, false),
-                Field::new("age", DataType::Int32, false),
-                Field::new("state", DataType::Utf8, false),
-                Field::new("salary", DataType::Float64, false),
-                Field::new(
-                    "birth_date",
-                    DataType::Timestamp(TimeUnit::Nanosecond, None),
-                    false,
-                ),
-                Field::new("😀", DataType::Int32, false),
-            ])),
-            "person_quoted_cols" => Ok(Schema::new(vec![
-                Field::new("id", DataType::UInt32, false),
-                Field::new("First Name", DataType::Utf8, false),
-                Field::new("Last Name", DataType::Utf8, false),
-                Field::new("Age", DataType::Int32, false),
-                Field::new("State", DataType::Utf8, false),
-                Field::new("Salary", DataType::Float64, false),
-                Field::new(
-                    "Birth Date",
-                    DataType::Timestamp(TimeUnit::Nanosecond, None),
-                    false,
-                ),
-                Field::new("😀", DataType::Int32, false),
-            ])),
-            "orders" => Ok(Schema::new(vec![
-                Field::new("order_id", DataType::UInt32, false),
-                Field::new("customer_id", DataType::UInt32, false),
-                Field::new("o_item_id", DataType::Utf8, false),
-                Field::new("qty", DataType::Int32, false),
-                Field::new("price", DataType::Float64, false),
-                Field::new("delivered", DataType::Boolean, false),
-            ])),
-            "array" => Ok(Schema::new(vec![
-                Field::new(
-                    "left",
-                    DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
-                    false,
-                ),
-                Field::new(
-                    "right",
-                    DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
-                    false,
-                ),
-            ])),
-            "lineitem" => Ok(Schema::new(vec![
-                Field::new("l_item_id", DataType::UInt32, false),
-                Field::new("l_description", DataType::Utf8, false),
-                Field::new("price", DataType::Float64, false),
-            ])),
-            "aggregate_test_100" => Ok(Schema::new(vec![
-                Field::new("c1", DataType::Utf8, false),
-                Field::new("c2", DataType::UInt32, false),
-                Field::new("c3", DataType::Int8, false),
-                Field::new("c4", DataType::Int16, false),
-                Field::new("c5", DataType::Int32, false),
-                Field::new("c6", DataType::Int64, false),
-                Field::new("c7", DataType::UInt8, false),
-                Field::new("c8", DataType::UInt16, false),
-                Field::new("c9", DataType::UInt32, false),
-                Field::new("c10", DataType::UInt64, false),
-                Field::new("c11", DataType::Float32, false),
-                Field::new("c12", DataType::Float64, false),
-                Field::new("c13", DataType::Utf8, false),
-            ])),
-            "UPPERCASE_test" => Ok(Schema::new(vec![
-                Field::new("Id", DataType::UInt32, false),
-                Field::new("lower", DataType::UInt32, false),
-            ])),
-            "unnest_table" => Ok(Schema::new(vec![
-                Field::new(
-                    "array_col",
-                    DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
-                    false,
-                ),
-                Field::new(
-                    "struct_col",
-                    DataType::Struct(Fields::from(vec![
-                        Field::new("field1", DataType::Int64, true),
-                        Field::new("field2", DataType::Utf8, true),
-                    ])),
-                    false,
-                ),
-            ])),
-            _ => plan_err!("No table named: {} found", name.table()),
-        };
-
-        match schema {
-            Ok(t) => Ok(Arc::new(EmptyTable::new(Arc::new(t)))),
-            Err(e) => Err(e),
-        }
-    }
-
-    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
-        self.udfs.get(name).cloned()
-    }
-
-    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
-        self.udafs.get(name).cloned()
-    }
-
-    fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
-        unimplemented!()
-    }
-
-    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
-        None
-    }
-
-    fn options(&self) -> &ConfigOptions {
-        &self.options
-    }
-
-    fn create_cte_work_table(
-        &self,
-        _name: &str,
-        schema: SchemaRef,
-    ) -> Result<Arc<dyn TableSource>> {
-        Ok(Arc::new(EmptyTable::new(schema)))
-    }
-
-    fn udf_names(&self) -> Vec<String> {
-        self.udfs.keys().cloned().collect()
-    }
-
-    fn udaf_names(&self) -> Vec<String> {
-        self.udafs.keys().cloned().collect()
-    }
-
-    fn udwf_names(&self) -> Vec<String> {
-        Vec::new()
-    }
-}
-
 #[test]
 fn select_partially_qualified_column() {
     let sql = r#"SELECT person.first_name FROM public.person"#;
@@ -4552,283 +4368,6 @@ fn assert_field_not_found(err: DataFusionError, name: 
&str) {
     }
 }
 
-struct EmptyTable {
-    table_schema: SchemaRef,
-}
-
-impl EmptyTable {
-    fn new(table_schema: SchemaRef) -> Self {
-        Self { table_schema }
-    }
-}
-
-impl TableSource for EmptyTable {
-    fn as_any(&self) -> &dyn std::any::Any {
-        self
-    }
-
-    fn schema(&self) -> SchemaRef {
-        self.table_schema.clone()
-    }
-}
-
-#[test]
-fn roundtrip_expr() {
-    let tests: Vec<(TableReference, &str, &str)> = vec![
-        (TableReference::bare("person"), "age > 35", r#"(age > 35)"#),
-        (
-            TableReference::bare("person"),
-            "id = '10'",
-            r#"(id = '10')"#,
-        ),
-        (
-            TableReference::bare("person"),
-            "CAST(id AS VARCHAR)",
-            r#"CAST(id AS VARCHAR)"#,
-        ),
-        (
-            TableReference::bare("person"),
-            "SUM((age * 2))",
-            r#"SUM((age * 2))"#,
-        ),
-    ];
-
-    let roundtrip = |table, sql: &str| -> Result<String> {
-        let dialect = GenericDialect {};
-        let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?;
-
-        let context = MockContextProvider::default();
-        let schema = context.get_table_source(table)?.schema();
-        let df_schema = DFSchema::try_from(schema.as_ref().clone())?;
-        let sql_to_rel = SqlToRel::new(&context);
-        let expr =
-            sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut 
PlannerContext::new())?;
-
-        let ast = expr_to_sql(&expr)?;
-
-        Ok(format!("{}", ast))
-    };
-
-    for (table, query, expected) in tests {
-        let actual = roundtrip(table, query).unwrap();
-        assert_eq!(actual, expected);
-    }
-}
-
-#[test]
-fn roundtrip_statement() -> Result<()> {
-    let tests: Vec<&str> = vec![
-            "select ta.j1_id from j1 ta;",
-            "select ta.j1_id from j1 ta order by ta.j1_id;",
-            "select * from j1 ta order by ta.j1_id, ta.j1_string desc;",
-            "select * from j1 limit 10;",
-            "select ta.j1_id from j1 ta where ta.j1_id > 1;",
-            "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id 
= tb.j2_id);",
-            "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb 
on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);",
-            "select * from (select id, first_name from person)",
-            "select * from (select id, first_name from (select * from 
person))",
-            "select id, count(*) as cnt from (select id from person) group by 
id",
-            "select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from 
(select (id-1) as id from person) group by id",
-            "select CAST(id/2 as VARCHAR) NOT LIKE 'foo*' from person where 
NOT EXISTS (select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = 
tb.j2_id))",
-            r#"select "First Name" from person_quoted_cols"#,
-            "select DISTINCT id FROM person",
-            "select DISTINCT on (id) id, first_name from person",
-            "select DISTINCT on (id) id, first_name from person order by id",
-            r#"select id, count("First Name") as cnt from (select id, "First 
Name" from person_quoted_cols) group by id"#,
-            "select id, count(*) as cnt from (select p1.id as id from person 
p1 inner join person p2 on p1.id=p2.id) group by id",
-            "select id, count(*), first_name from person group by first_name, 
id",
-            "select id, sum(age), first_name from person group by first_name, 
id",
-            "select id, count(*), first_name 
-            from person 
-            where id!=3 and first_name=='test'
-            group by first_name, id 
-            having count(*)>5 and count(*)<10
-            order by count(*)",
-            r#"select id, count("First Name") as count_first_name, "Last Name" 
-            from person_quoted_cols
-            where id!=3 and "First Name"=='test'
-            group by "Last Name", id 
-            having count_first_name>5 and count_first_name<10
-            order by count_first_name, "Last Name""#,
-            r#"select p.id, count("First Name") as count_first_name,
-            "Last Name", sum(qp.id/p.id - (select sum(id) from 
person_quoted_cols) ) / (select count(*) from person) 
-            from (select id, "First Name", "Last Name" from 
person_quoted_cols) qp
-            inner join (select * from person) p
-            on p.id = qp.id
-            where p.id!=3 and "First Name"=='test' and qp.id in 
-            (select id from (select id, count(*) from person group by id 
having count(*) > 0))
-            group by "Last Name", p.id 
-            having count_first_name>5 and count_first_name<10
-            order by count_first_name, "Last Name""#,
-            r#"SELECT j1_string as string FROM j1
-            UNION ALL
-            SELECT j2_string as string FROM j2"#,
-            r#"SELECT j1_string as string FROM j1
-            UNION ALL
-            SELECT j2_string as string FROM j2
-            ORDER BY string DESC
-            LIMIT 10"#
-        ];
-
-    // For each test sql string, we transform as follows:
-    // sql -> ast::Statement (s1) -> LogicalPlan (p1) -> ast::Statement (s2) 
-> LogicalPlan (p2)
-    // We test not that s1==s2, but rather p1==p2. This ensures that unparser 
preserves the logical
-    // query information of the original sql string and disreguards other 
differences in syntax or
-    // quoting.
-    for query in tests {
-        let dialect = GenericDialect {};
-        let statement = Parser::new(&dialect)
-            .try_with_sql(query)?
-            .parse_statement()?;
-
-        let context = MockContextProvider::default();
-        let sql_to_rel = SqlToRel::new(&context);
-        let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
-
-        let roundtrip_statement = plan_to_sql(&plan)?;
-
-        let actual = format!("{}", &roundtrip_statement);
-        println!("roundtrip sql: {actual}");
-        println!("plan {}", plan.display_indent());
-
-        let plan_roundtrip = sql_to_rel
-            .sql_statement_to_plan(roundtrip_statement.clone())
-            .unwrap();
-
-        assert_eq!(plan, plan_roundtrip);
-    }
-
-    Ok(())
-}
-
-#[test]
-fn roundtrip_crossjoin() -> Result<()> {
-    let query = "select j1.j1_id, j2.j2_string from j1, j2";
-
-    let dialect = GenericDialect {};
-    let statement = Parser::new(&dialect)
-        .try_with_sql(query)?
-        .parse_statement()?;
-
-    let context = MockContextProvider::default();
-    let sql_to_rel = SqlToRel::new(&context);
-    let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
-
-    let roundtrip_statement = plan_to_sql(&plan)?;
-
-    let actual = format!("{}", &roundtrip_statement);
-    println!("roundtrip sql: {actual}");
-    println!("plan {}", plan.display_indent());
-
-    let plan_roundtrip = sql_to_rel
-        .sql_statement_to_plan(roundtrip_statement.clone())
-        .unwrap();
-
-    let expected = "Projection: j1.j1_id, j2.j2_string\
-        \n  Inner Join:  Filter: Boolean(true)\
-        \n    TableScan: j1\
-        \n    TableScan: j2";
-
-    assert_eq!(format!("{plan_roundtrip:?}"), expected);
-
-    Ok(())
-}
-
-#[test]
-fn roundtrip_statement_with_dialect() -> Result<()> {
-    struct TestStatementWithDialect {
-        sql: &'static str,
-        expected: &'static str,
-        parser_dialect: Box<dyn Dialect>,
-        unparser_dialect: Box<dyn UnparserDialect>,
-    }
-    let tests: Vec<TestStatementWithDialect> = vec![
-        TestStatementWithDialect {
-            sql: "select ta.j1_id from j1 ta order by j1_id limit 10;",
-            expected:
-                "SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id` 
ASC LIMIT 10",
-            parser_dialect: Box::new(MySqlDialect {}),
-            unparser_dialect: Box::new(UnparserMySqlDialect {}),
-        },
-        TestStatementWithDialect {
-            sql: "select ta.j1_id from j1 ta order by j1_id limit 10;",
-            expected: r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC 
NULLS LAST LIMIT 10"#,
-            parser_dialect: Box::new(GenericDialect {}),
-            unparser_dialect: Box::new(UnparserDefaultDialect {}),
-        },
-    ];
-
-    for query in tests {
-        let statement = Parser::new(&*query.parser_dialect)
-            .try_with_sql(query.sql)?
-            .parse_statement()?;
-
-        let context = MockContextProvider::default();
-        let sql_to_rel = SqlToRel::new(&context);
-        let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
-
-        let unparser = Unparser::new(&*query.unparser_dialect);
-        let roundtrip_statement = unparser.plan_to_sql(&plan)?;
-
-        let actual = format!("{}", &roundtrip_statement);
-        println!("roundtrip sql: {actual}");
-        println!("plan {}", plan.display_indent());
-
-        assert_eq!(query.expected, actual);
-    }
-
-    Ok(())
-}
-
-#[test]
-fn test_unnest_logical_plan() -> Result<()> {
-    let query = "select unnest(struct_col), unnest(array_col), struct_col, 
array_col from unnest_table";
-
-    let dialect = GenericDialect {};
-    let statement = Parser::new(&dialect)
-        .try_with_sql(query)?
-        .parse_statement()?;
-
-    let context = MockContextProvider::default();
-    let sql_to_rel = SqlToRel::new(&context);
-    let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
-
-    let expected = "Projection: unnest(unnest_table.struct_col).field1, 
unnest(unnest_table.struct_col).field2, unnest(unnest_table.array_col), 
unnest_table.struct_col, unnest_table.array_col\
-        \n  Unnest: lists[unnest(unnest_table.array_col)] 
structs[unnest(unnest_table.struct_col)]\
-        \n    Projection: unnest_table.struct_col AS 
unnest(unnest_table.struct_col), unnest_table.array_col AS 
unnest(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col\
-        \n      TableScan: unnest_table";
-
-    assert_eq!(format!("{plan:?}"), expected);
-
-    Ok(())
-}
-
-#[test]
-fn test_table_references_in_plan_to_sql() {
-    fn test(table_name: &str, expected_sql: &str) {
-        let schema = Schema::new(vec![
-            Field::new("id", DataType::Utf8, false),
-            Field::new("value", DataType::Utf8, false),
-        ]);
-        let plan = table_scan(Some(table_name), &schema, None)
-            .unwrap()
-            .project(vec![col("id"), col("value")])
-            .unwrap()
-            .build()
-            .unwrap();
-        let sql = plan_to_sql(&plan).unwrap();
-
-        assert_eq!(format!("{}", sql), expected_sql)
-    }
-
-    test("catalog.schema.table", "SELECT catalog.\"schema\".\"table\".id, 
catalog.\"schema\".\"table\".\"value\" FROM catalog.\"schema\".\"table\"");
-    test("schema.table", "SELECT \"schema\".\"table\".id, 
\"schema\".\"table\".\"value\" FROM \"schema\".\"table\"");
-    test(
-        "table",
-        "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"",
-    );
-}
-
 #[cfg(test)]
 #[ctor::ctor]
 fn init() {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to