This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 71a99b8462 minor: consolidate unparser integration tests (#10736)
71a99b8462 is described below
commit 71a99b84627de49033037021cfeea1f2cd29db84
Author: Devin D'Angelo <[email protected]>
AuthorDate: Sat Jun 1 16:31:46 2024 -0400
minor: consolidate unparser integration tests (#10736)
* consolidate unparser integration tests
* add license to new files
* suppress dead code warnings
* run as one integration test binary
* add license
---
datafusion/sql/tests/cases/mod.rs | 18 ++
datafusion/sql/tests/cases/plan_to_sql.rs | 290 ++++++++++++++++++
datafusion/sql/tests/common/mod.rs | 227 ++++++++++++++
datafusion/sql/tests/sql_integration.rs | 477 +-----------------------------
4 files changed, 543 insertions(+), 469 deletions(-)
diff --git a/datafusion/sql/tests/cases/mod.rs
b/datafusion/sql/tests/cases/mod.rs
new file mode 100644
index 0000000000..fc4c59cc88
--- /dev/null
+++ b/datafusion/sql/tests/cases/mod.rs
@@ -0,0 +1,18 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod plan_to_sql;
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs
b/datafusion/sql/tests/cases/plan_to_sql.rs
new file mode 100644
index 0000000000..1bf441351a
--- /dev/null
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -0,0 +1,290 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::vec;
+
+use arrow_schema::*;
+use datafusion_common::{DFSchema, Result, TableReference};
+use datafusion_expr::{col, table_scan};
+use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
+use datafusion_sql::unparser::dialect::{
+ DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect,
+ MySqlDialect as UnparserMySqlDialect,
+};
+use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser};
+
+use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
+use sqlparser::parser::Parser;
+
+use crate::common::MockContextProvider;
+
+#[test]
+fn roundtrip_expr() {
+ let tests: Vec<(TableReference, &str, &str)> = vec![
+ (TableReference::bare("person"), "age > 35", r#"(age > 35)"#),
+ (
+ TableReference::bare("person"),
+ "id = '10'",
+ r#"(id = '10')"#,
+ ),
+ (
+ TableReference::bare("person"),
+ "CAST(id AS VARCHAR)",
+ r#"CAST(id AS VARCHAR)"#,
+ ),
+ (
+ TableReference::bare("person"),
+ "SUM((age * 2))",
+ r#"SUM((age * 2))"#,
+ ),
+ ];
+
+ let roundtrip = |table, sql: &str| -> Result<String> {
+ let dialect = GenericDialect {};
+ let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?;
+
+ let context = MockContextProvider::default();
+ let schema = context.get_table_source(table)?.schema();
+ let df_schema = DFSchema::try_from(schema.as_ref().clone())?;
+ let sql_to_rel = SqlToRel::new(&context);
+ let expr =
+ sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut
PlannerContext::new())?;
+
+ let ast = expr_to_sql(&expr)?;
+
+ Ok(format!("{}", ast))
+ };
+
+ for (table, query, expected) in tests {
+ let actual = roundtrip(table, query).unwrap();
+ assert_eq!(actual, expected);
+ }
+}
+
+#[test]
+fn roundtrip_statement() -> Result<()> {
+ let tests: Vec<&str> = vec![
+ "select ta.j1_id from j1 ta;",
+ "select ta.j1_id from j1 ta order by ta.j1_id;",
+ "select * from j1 ta order by ta.j1_id, ta.j1_string desc;",
+ "select * from j1 limit 10;",
+ "select ta.j1_id from j1 ta where ta.j1_id > 1;",
+ "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id
= tb.j2_id);",
+ "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb
on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);",
+ "select * from (select id, first_name from person)",
+ "select * from (select id, first_name from (select * from
person))",
+ "select id, count(*) as cnt from (select id from person) group by
id",
+ "select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from
(select (id-1) as id from person) group by id",
+ "select CAST(id/2 as VARCHAR) NOT LIKE 'foo*' from person where
NOT EXISTS (select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id =
tb.j2_id))",
+ r#"select "First Name" from person_quoted_cols"#,
+ "select DISTINCT id FROM person",
+ "select DISTINCT on (id) id, first_name from person",
+ "select DISTINCT on (id) id, first_name from person order by id",
+ r#"select id, count("First Name") as cnt from (select id, "First
Name" from person_quoted_cols) group by id"#,
+ "select id, count(*) as cnt from (select p1.id as id from person
p1 inner join person p2 on p1.id=p2.id) group by id",
+ "select id, count(*), first_name from person group by first_name,
id",
+ "select id, sum(age), first_name from person group by first_name,
id",
+ "select id, count(*), first_name
+ from person
+ where id!=3 and first_name=='test'
+ group by first_name, id
+ having count(*)>5 and count(*)<10
+ order by count(*)",
+ r#"select id, count("First Name") as count_first_name, "Last Name"
+ from person_quoted_cols
+ where id!=3 and "First Name"=='test'
+ group by "Last Name", id
+ having count_first_name>5 and count_first_name<10
+ order by count_first_name, "Last Name""#,
+ r#"select p.id, count("First Name") as count_first_name,
+ "Last Name", sum(qp.id/p.id - (select sum(id) from
person_quoted_cols) ) / (select count(*) from person)
+ from (select id, "First Name", "Last Name" from
person_quoted_cols) qp
+ inner join (select * from person) p
+ on p.id = qp.id
+ where p.id!=3 and "First Name"=='test' and qp.id in
+ (select id from (select id, count(*) from person group by id
having count(*) > 0))
+ group by "Last Name", p.id
+ having count_first_name>5 and count_first_name<10
+ order by count_first_name, "Last Name""#,
+ r#"SELECT j1_string as string FROM j1
+ UNION ALL
+ SELECT j2_string as string FROM j2"#,
+ r#"SELECT j1_string as string FROM j1
+ UNION ALL
+ SELECT j2_string as string FROM j2
+ ORDER BY string DESC
+ LIMIT 10"#
+ ];
+
+ // For each test sql string, we transform as follows:
+ // sql -> ast::Statement (s1) -> LogicalPlan (p1) -> ast::Statement (s2)
-> LogicalPlan (p2)
+ // We test not that s1==s2, but rather p1==p2. This ensures that unparser
preserves the logical
+ // query information of the original sql string and disregards other
differences in syntax or
+ // quoting.
+ for query in tests {
+ let dialect = GenericDialect {};
+ let statement = Parser::new(&dialect)
+ .try_with_sql(query)?
+ .parse_statement()?;
+
+ let context = MockContextProvider::default();
+ let sql_to_rel = SqlToRel::new(&context);
+ let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
+
+ let roundtrip_statement = plan_to_sql(&plan)?;
+
+ let actual = format!("{}", &roundtrip_statement);
+ println!("roundtrip sql: {actual}");
+ println!("plan {}", plan.display_indent());
+
+ let plan_roundtrip = sql_to_rel
+ .sql_statement_to_plan(roundtrip_statement.clone())
+ .unwrap();
+
+ assert_eq!(plan, plan_roundtrip);
+ }
+
+ Ok(())
+}
+
+#[test]
+fn roundtrip_crossjoin() -> Result<()> {
+ let query = "select j1.j1_id, j2.j2_string from j1, j2";
+
+ let dialect = GenericDialect {};
+ let statement = Parser::new(&dialect)
+ .try_with_sql(query)?
+ .parse_statement()?;
+
+ let context = MockContextProvider::default();
+ let sql_to_rel = SqlToRel::new(&context);
+ let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
+
+ let roundtrip_statement = plan_to_sql(&plan)?;
+
+ let actual = format!("{}", &roundtrip_statement);
+ println!("roundtrip sql: {actual}");
+ println!("plan {}", plan.display_indent());
+
+ let plan_roundtrip = sql_to_rel
+ .sql_statement_to_plan(roundtrip_statement.clone())
+ .unwrap();
+
+ let expected = "Projection: j1.j1_id, j2.j2_string\
+ \n Inner Join: Filter: Boolean(true)\
+ \n TableScan: j1\
+ \n TableScan: j2";
+
+ assert_eq!(format!("{plan_roundtrip:?}"), expected);
+
+ Ok(())
+}
+
+#[test]
+fn roundtrip_statement_with_dialect() -> Result<()> {
+ struct TestStatementWithDialect {
+ sql: &'static str,
+ expected: &'static str,
+ parser_dialect: Box<dyn Dialect>,
+ unparser_dialect: Box<dyn UnparserDialect>,
+ }
+ let tests: Vec<TestStatementWithDialect> = vec![
+ TestStatementWithDialect {
+ sql: "select ta.j1_id from j1 ta order by j1_id limit 10;",
+ expected:
+ "SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id`
ASC LIMIT 10",
+ parser_dialect: Box::new(MySqlDialect {}),
+ unparser_dialect: Box::new(UnparserMySqlDialect {}),
+ },
+ TestStatementWithDialect {
+ sql: "select ta.j1_id from j1 ta order by j1_id limit 10;",
+ expected: r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC
NULLS LAST LIMIT 10"#,
+ parser_dialect: Box::new(GenericDialect {}),
+ unparser_dialect: Box::new(UnparserDefaultDialect {}),
+ },
+ ];
+
+ for query in tests {
+ let statement = Parser::new(&*query.parser_dialect)
+ .try_with_sql(query.sql)?
+ .parse_statement()?;
+
+ let context = MockContextProvider::default();
+ let sql_to_rel = SqlToRel::new(&context);
+ let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
+
+ let unparser = Unparser::new(&*query.unparser_dialect);
+ let roundtrip_statement = unparser.plan_to_sql(&plan)?;
+
+ let actual = format!("{}", &roundtrip_statement);
+ println!("roundtrip sql: {actual}");
+ println!("plan {}", plan.display_indent());
+
+ assert_eq!(query.expected, actual);
+ }
+
+ Ok(())
+}
+
+#[test]
+fn test_unnest_logical_plan() -> Result<()> {
+ let query = "select unnest(struct_col), unnest(array_col), struct_col,
array_col from unnest_table";
+
+ let dialect = GenericDialect {};
+ let statement = Parser::new(&dialect)
+ .try_with_sql(query)?
+ .parse_statement()?;
+
+ let context = MockContextProvider::default();
+ let sql_to_rel = SqlToRel::new(&context);
+ let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
+
+ let expected = "Projection: unnest(unnest_table.struct_col).field1,
unnest(unnest_table.struct_col).field2, unnest(unnest_table.array_col),
unnest_table.struct_col, unnest_table.array_col\
+ \n Unnest: lists[unnest(unnest_table.array_col)]
structs[unnest(unnest_table.struct_col)]\
+ \n Projection: unnest_table.struct_col AS
unnest(unnest_table.struct_col), unnest_table.array_col AS
unnest(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col\
+ \n TableScan: unnest_table";
+
+ assert_eq!(format!("{plan:?}"), expected);
+
+ Ok(())
+}
+
+#[test]
+fn test_table_references_in_plan_to_sql() {
+ fn test(table_name: &str, expected_sql: &str) {
+ let schema = Schema::new(vec![
+ Field::new("id", DataType::Utf8, false),
+ Field::new("value", DataType::Utf8, false),
+ ]);
+ let plan = table_scan(Some(table_name), &schema, None)
+ .unwrap()
+ .project(vec![col("id"), col("value")])
+ .unwrap()
+ .build()
+ .unwrap();
+ let sql = plan_to_sql(&plan).unwrap();
+
+ assert_eq!(format!("{}", sql), expected_sql)
+ }
+
+ test("catalog.schema.table", "SELECT catalog.\"schema\".\"table\".id,
catalog.\"schema\".\"table\".\"value\" FROM catalog.\"schema\".\"table\"");
+ test("schema.table", "SELECT \"schema\".\"table\".id,
\"schema\".\"table\".\"value\" FROM \"schema\".\"table\"");
+ test(
+ "table",
+ "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"",
+ );
+}
diff --git a/datafusion/sql/tests/common/mod.rs
b/datafusion/sql/tests/common/mod.rs
new file mode 100644
index 0000000000..79de4bc826
--- /dev/null
+++ b/datafusion/sql/tests/common/mod.rs
@@ -0,0 +1,227 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[cfg(test)]
+use std::collections::HashMap;
+use std::{sync::Arc, vec};
+
+use arrow_schema::*;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::{plan_err, Result, TableReference};
+use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
+use datafusion_sql::planner::ContextProvider;
+
+#[derive(Default)]
+pub(crate) struct MockContextProvider {
+ options: ConfigOptions,
+ udfs: HashMap<String, Arc<ScalarUDF>>,
+ udafs: HashMap<String, Arc<AggregateUDF>>,
+}
+
+impl MockContextProvider {
+ // Suppressing dead code warning, as this is used in integration test
crates
+ #[allow(dead_code)]
+ pub(crate) fn options_mut(&mut self) -> &mut ConfigOptions {
+ &mut self.options
+ }
+
+ #[allow(dead_code)]
+ pub(crate) fn with_udf(mut self, udf: ScalarUDF) -> Self {
+ self.udfs.insert(udf.name().to_string(), Arc::new(udf));
+ self
+ }
+}
+
+impl ContextProvider for MockContextProvider {
+ fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn
TableSource>> {
+ let schema = match name.table() {
+ "test" => Ok(Schema::new(vec![
+ Field::new("t_date32", DataType::Date32, false),
+ Field::new("t_date64", DataType::Date64, false),
+ ])),
+ "j1" => Ok(Schema::new(vec![
+ Field::new("j1_id", DataType::Int32, false),
+ Field::new("j1_string", DataType::Utf8, false),
+ ])),
+ "j2" => Ok(Schema::new(vec![
+ Field::new("j2_id", DataType::Int32, false),
+ Field::new("j2_string", DataType::Utf8, false),
+ ])),
+ "j3" => Ok(Schema::new(vec![
+ Field::new("j3_id", DataType::Int32, false),
+ Field::new("j3_string", DataType::Utf8, false),
+ ])),
+ "test_decimal" => Ok(Schema::new(vec![
+ Field::new("id", DataType::Int32, false),
+ Field::new("price", DataType::Decimal128(10, 2), false),
+ ])),
+ "person" => Ok(Schema::new(vec![
+ Field::new("id", DataType::UInt32, false),
+ Field::new("first_name", DataType::Utf8, false),
+ Field::new("last_name", DataType::Utf8, false),
+ Field::new("age", DataType::Int32, false),
+ Field::new("state", DataType::Utf8, false),
+ Field::new("salary", DataType::Float64, false),
+ Field::new(
+ "birth_date",
+ DataType::Timestamp(TimeUnit::Nanosecond, None),
+ false,
+ ),
+ Field::new("😀", DataType::Int32, false),
+ ])),
+ "person_quoted_cols" => Ok(Schema::new(vec![
+ Field::new("id", DataType::UInt32, false),
+ Field::new("First Name", DataType::Utf8, false),
+ Field::new("Last Name", DataType::Utf8, false),
+ Field::new("Age", DataType::Int32, false),
+ Field::new("State", DataType::Utf8, false),
+ Field::new("Salary", DataType::Float64, false),
+ Field::new(
+ "Birth Date",
+ DataType::Timestamp(TimeUnit::Nanosecond, None),
+ false,
+ ),
+ Field::new("😀", DataType::Int32, false),
+ ])),
+ "orders" => Ok(Schema::new(vec![
+ Field::new("order_id", DataType::UInt32, false),
+ Field::new("customer_id", DataType::UInt32, false),
+ Field::new("o_item_id", DataType::Utf8, false),
+ Field::new("qty", DataType::Int32, false),
+ Field::new("price", DataType::Float64, false),
+ Field::new("delivered", DataType::Boolean, false),
+ ])),
+ "array" => Ok(Schema::new(vec![
+ Field::new(
+ "left",
+ DataType::List(Arc::new(Field::new("item",
DataType::Int64, true))),
+ false,
+ ),
+ Field::new(
+ "right",
+ DataType::List(Arc::new(Field::new("item",
DataType::Int64, true))),
+ false,
+ ),
+ ])),
+ "lineitem" => Ok(Schema::new(vec![
+ Field::new("l_item_id", DataType::UInt32, false),
+ Field::new("l_description", DataType::Utf8, false),
+ Field::new("price", DataType::Float64, false),
+ ])),
+ "aggregate_test_100" => Ok(Schema::new(vec![
+ Field::new("c1", DataType::Utf8, false),
+ Field::new("c2", DataType::UInt32, false),
+ Field::new("c3", DataType::Int8, false),
+ Field::new("c4", DataType::Int16, false),
+ Field::new("c5", DataType::Int32, false),
+ Field::new("c6", DataType::Int64, false),
+ Field::new("c7", DataType::UInt8, false),
+ Field::new("c8", DataType::UInt16, false),
+ Field::new("c9", DataType::UInt32, false),
+ Field::new("c10", DataType::UInt64, false),
+ Field::new("c11", DataType::Float32, false),
+ Field::new("c12", DataType::Float64, false),
+ Field::new("c13", DataType::Utf8, false),
+ ])),
+ "UPPERCASE_test" => Ok(Schema::new(vec![
+ Field::new("Id", DataType::UInt32, false),
+ Field::new("lower", DataType::UInt32, false),
+ ])),
+ "unnest_table" => Ok(Schema::new(vec![
+ Field::new(
+ "array_col",
+ DataType::List(Arc::new(Field::new("item",
DataType::Int64, true))),
+ false,
+ ),
+ Field::new(
+ "struct_col",
+ DataType::Struct(Fields::from(vec![
+ Field::new("field1", DataType::Int64, true),
+ Field::new("field2", DataType::Utf8, true),
+ ])),
+ false,
+ ),
+ ])),
+ _ => plan_err!("No table named: {} found", name.table()),
+ };
+
+ match schema {
+ Ok(t) => Ok(Arc::new(EmptyTable::new(Arc::new(t)))),
+ Err(e) => Err(e),
+ }
+ }
+
+ fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
+ self.udfs.get(name).cloned()
+ }
+
+ fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
+ self.udafs.get(name).cloned()
+ }
+
+ fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
+ unimplemented!()
+ }
+
+ fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
+ None
+ }
+
+ fn options(&self) -> &ConfigOptions {
+ &self.options
+ }
+
+ fn create_cte_work_table(
+ &self,
+ _name: &str,
+ schema: SchemaRef,
+ ) -> Result<Arc<dyn TableSource>> {
+ Ok(Arc::new(EmptyTable::new(schema)))
+ }
+
+ fn udf_names(&self) -> Vec<String> {
+ self.udfs.keys().cloned().collect()
+ }
+
+ fn udaf_names(&self) -> Vec<String> {
+ self.udafs.keys().cloned().collect()
+ }
+
+ fn udwf_names(&self) -> Vec<String> {
+ Vec::new()
+ }
+}
+
+struct EmptyTable {
+ table_schema: SchemaRef,
+}
+
+impl EmptyTable {
+ fn new(table_schema: SchemaRef) -> Self {
+ Self { table_schema }
+ }
+}
+
+impl TableSource for EmptyTable {
+ fn as_any(&self) -> &dyn std::any::Any {
+ self
+ }
+
+ fn schema(&self) -> SchemaRef {
+ self.table_schema.clone()
+ }
+}
diff --git a/datafusion/sql/tests/sql_integration.rs
b/datafusion/sql/tests/sql_integration.rs
index a7224805f3..1f064ea0f5 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -18,35 +18,29 @@
use std::any::Any;
#[cfg(test)]
use std::collections::HashMap;
-use std::{sync::Arc, vec};
+use std::vec;
use arrow_schema::TimeUnit::Nanosecond;
use arrow_schema::*;
-use datafusion_common::config::ConfigOptions;
+use common::MockContextProvider;
use datafusion_common::{
- assert_contains, plan_err, DFSchema, DataFusionError, ParamValues, Result,
- ScalarValue, TableReference,
+ assert_contains, DataFusionError, ParamValues, Result, ScalarValue,
};
-use datafusion_expr::{col, table_scan};
use datafusion_expr::{
logical_plan::{LogicalPlan, Prepare},
- AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature,
TableSource,
- Volatility, WindowUDF,
+ ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions::{string, unicode};
-use datafusion_sql::unparser::dialect::{
- DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect,
- MySqlDialect as UnparserMySqlDialect,
-};
-use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser};
use datafusion_sql::{
parser::DFParser,
- planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel},
+ planner::{ParserOptions, SqlToRel},
};
use rstest::rstest;
use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
-use sqlparser::parser::Parser;
+
+mod cases;
+mod common;
#[test]
fn test_schema_support() {
@@ -2797,184 +2791,6 @@ fn prepare_stmt_replace_params_quick_test(
plan
}
-#[derive(Default)]
-struct MockContextProvider {
- options: ConfigOptions,
- udfs: HashMap<String, Arc<ScalarUDF>>,
- udafs: HashMap<String, Arc<AggregateUDF>>,
-}
-
-impl MockContextProvider {
- fn options_mut(&mut self) -> &mut ConfigOptions {
- &mut self.options
- }
-
- fn with_udf(mut self, udf: ScalarUDF) -> Self {
- self.udfs.insert(udf.name().to_string(), Arc::new(udf));
- self
- }
-}
-
-impl ContextProvider for MockContextProvider {
- fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn
TableSource>> {
- let schema = match name.table() {
- "test" => Ok(Schema::new(vec![
- Field::new("t_date32", DataType::Date32, false),
- Field::new("t_date64", DataType::Date64, false),
- ])),
- "j1" => Ok(Schema::new(vec![
- Field::new("j1_id", DataType::Int32, false),
- Field::new("j1_string", DataType::Utf8, false),
- ])),
- "j2" => Ok(Schema::new(vec![
- Field::new("j2_id", DataType::Int32, false),
- Field::new("j2_string", DataType::Utf8, false),
- ])),
- "j3" => Ok(Schema::new(vec![
- Field::new("j3_id", DataType::Int32, false),
- Field::new("j3_string", DataType::Utf8, false),
- ])),
- "test_decimal" => Ok(Schema::new(vec![
- Field::new("id", DataType::Int32, false),
- Field::new("price", DataType::Decimal128(10, 2), false),
- ])),
- "person" => Ok(Schema::new(vec![
- Field::new("id", DataType::UInt32, false),
- Field::new("first_name", DataType::Utf8, false),
- Field::new("last_name", DataType::Utf8, false),
- Field::new("age", DataType::Int32, false),
- Field::new("state", DataType::Utf8, false),
- Field::new("salary", DataType::Float64, false),
- Field::new(
- "birth_date",
- DataType::Timestamp(TimeUnit::Nanosecond, None),
- false,
- ),
- Field::new("😀", DataType::Int32, false),
- ])),
- "person_quoted_cols" => Ok(Schema::new(vec![
- Field::new("id", DataType::UInt32, false),
- Field::new("First Name", DataType::Utf8, false),
- Field::new("Last Name", DataType::Utf8, false),
- Field::new("Age", DataType::Int32, false),
- Field::new("State", DataType::Utf8, false),
- Field::new("Salary", DataType::Float64, false),
- Field::new(
- "Birth Date",
- DataType::Timestamp(TimeUnit::Nanosecond, None),
- false,
- ),
- Field::new("😀", DataType::Int32, false),
- ])),
- "orders" => Ok(Schema::new(vec![
- Field::new("order_id", DataType::UInt32, false),
- Field::new("customer_id", DataType::UInt32, false),
- Field::new("o_item_id", DataType::Utf8, false),
- Field::new("qty", DataType::Int32, false),
- Field::new("price", DataType::Float64, false),
- Field::new("delivered", DataType::Boolean, false),
- ])),
- "array" => Ok(Schema::new(vec![
- Field::new(
- "left",
- DataType::List(Arc::new(Field::new("item",
DataType::Int64, true))),
- false,
- ),
- Field::new(
- "right",
- DataType::List(Arc::new(Field::new("item",
DataType::Int64, true))),
- false,
- ),
- ])),
- "lineitem" => Ok(Schema::new(vec![
- Field::new("l_item_id", DataType::UInt32, false),
- Field::new("l_description", DataType::Utf8, false),
- Field::new("price", DataType::Float64, false),
- ])),
- "aggregate_test_100" => Ok(Schema::new(vec![
- Field::new("c1", DataType::Utf8, false),
- Field::new("c2", DataType::UInt32, false),
- Field::new("c3", DataType::Int8, false),
- Field::new("c4", DataType::Int16, false),
- Field::new("c5", DataType::Int32, false),
- Field::new("c6", DataType::Int64, false),
- Field::new("c7", DataType::UInt8, false),
- Field::new("c8", DataType::UInt16, false),
- Field::new("c9", DataType::UInt32, false),
- Field::new("c10", DataType::UInt64, false),
- Field::new("c11", DataType::Float32, false),
- Field::new("c12", DataType::Float64, false),
- Field::new("c13", DataType::Utf8, false),
- ])),
- "UPPERCASE_test" => Ok(Schema::new(vec![
- Field::new("Id", DataType::UInt32, false),
- Field::new("lower", DataType::UInt32, false),
- ])),
- "unnest_table" => Ok(Schema::new(vec![
- Field::new(
- "array_col",
- DataType::List(Arc::new(Field::new("item",
DataType::Int64, true))),
- false,
- ),
- Field::new(
- "struct_col",
- DataType::Struct(Fields::from(vec![
- Field::new("field1", DataType::Int64, true),
- Field::new("field2", DataType::Utf8, true),
- ])),
- false,
- ),
- ])),
- _ => plan_err!("No table named: {} found", name.table()),
- };
-
- match schema {
- Ok(t) => Ok(Arc::new(EmptyTable::new(Arc::new(t)))),
- Err(e) => Err(e),
- }
- }
-
- fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
- self.udfs.get(name).cloned()
- }
-
- fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
- self.udafs.get(name).cloned()
- }
-
- fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
- unimplemented!()
- }
-
- fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
- None
- }
-
- fn options(&self) -> &ConfigOptions {
- &self.options
- }
-
- fn create_cte_work_table(
- &self,
- _name: &str,
- schema: SchemaRef,
- ) -> Result<Arc<dyn TableSource>> {
- Ok(Arc::new(EmptyTable::new(schema)))
- }
-
- fn udf_names(&self) -> Vec<String> {
- self.udfs.keys().cloned().collect()
- }
-
- fn udaf_names(&self) -> Vec<String> {
- self.udafs.keys().cloned().collect()
- }
-
- fn udwf_names(&self) -> Vec<String> {
- Vec::new()
- }
-}
-
#[test]
fn select_partially_qualified_column() {
let sql = r#"SELECT person.first_name FROM public.person"#;
@@ -4552,283 +4368,6 @@ fn assert_field_not_found(err: DataFusionError, name:
&str) {
}
}
-struct EmptyTable {
- table_schema: SchemaRef,
-}
-
-impl EmptyTable {
- fn new(table_schema: SchemaRef) -> Self {
- Self { table_schema }
- }
-}
-
-impl TableSource for EmptyTable {
- fn as_any(&self) -> &dyn std::any::Any {
- self
- }
-
- fn schema(&self) -> SchemaRef {
- self.table_schema.clone()
- }
-}
-
-#[test]
-fn roundtrip_expr() {
- let tests: Vec<(TableReference, &str, &str)> = vec![
- (TableReference::bare("person"), "age > 35", r#"(age > 35)"#),
- (
- TableReference::bare("person"),
- "id = '10'",
- r#"(id = '10')"#,
- ),
- (
- TableReference::bare("person"),
- "CAST(id AS VARCHAR)",
- r#"CAST(id AS VARCHAR)"#,
- ),
- (
- TableReference::bare("person"),
- "SUM((age * 2))",
- r#"SUM((age * 2))"#,
- ),
- ];
-
- let roundtrip = |table, sql: &str| -> Result<String> {
- let dialect = GenericDialect {};
- let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?;
-
- let context = MockContextProvider::default();
- let schema = context.get_table_source(table)?.schema();
- let df_schema = DFSchema::try_from(schema.as_ref().clone())?;
- let sql_to_rel = SqlToRel::new(&context);
- let expr =
- sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut
PlannerContext::new())?;
-
- let ast = expr_to_sql(&expr)?;
-
- Ok(format!("{}", ast))
- };
-
- for (table, query, expected) in tests {
- let actual = roundtrip(table, query).unwrap();
- assert_eq!(actual, expected);
- }
-}
-
-#[test]
-fn roundtrip_statement() -> Result<()> {
- let tests: Vec<&str> = vec![
- "select ta.j1_id from j1 ta;",
- "select ta.j1_id from j1 ta order by ta.j1_id;",
- "select * from j1 ta order by ta.j1_id, ta.j1_string desc;",
- "select * from j1 limit 10;",
- "select ta.j1_id from j1 ta where ta.j1_id > 1;",
- "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id
= tb.j2_id);",
- "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb
on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);",
- "select * from (select id, first_name from person)",
- "select * from (select id, first_name from (select * from
person))",
- "select id, count(*) as cnt from (select id from person) group by
id",
- "select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from
(select (id-1) as id from person) group by id",
- "select CAST(id/2 as VARCHAR) NOT LIKE 'foo*' from person where
NOT EXISTS (select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id =
tb.j2_id))",
- r#"select "First Name" from person_quoted_cols"#,
- "select DISTINCT id FROM person",
- "select DISTINCT on (id) id, first_name from person",
- "select DISTINCT on (id) id, first_name from person order by id",
- r#"select id, count("First Name") as cnt from (select id, "First
Name" from person_quoted_cols) group by id"#,
- "select id, count(*) as cnt from (select p1.id as id from person
p1 inner join person p2 on p1.id=p2.id) group by id",
- "select id, count(*), first_name from person group by first_name,
id",
- "select id, sum(age), first_name from person group by first_name,
id",
- "select id, count(*), first_name
- from person
- where id!=3 and first_name=='test'
- group by first_name, id
- having count(*)>5 and count(*)<10
- order by count(*)",
- r#"select id, count("First Name") as count_first_name, "Last Name"
- from person_quoted_cols
- where id!=3 and "First Name"=='test'
- group by "Last Name", id
- having count_first_name>5 and count_first_name<10
- order by count_first_name, "Last Name""#,
- r#"select p.id, count("First Name") as count_first_name,
- "Last Name", sum(qp.id/p.id - (select sum(id) from
person_quoted_cols) ) / (select count(*) from person)
- from (select id, "First Name", "Last Name" from
person_quoted_cols) qp
- inner join (select * from person) p
- on p.id = qp.id
- where p.id!=3 and "First Name"=='test' and qp.id in
- (select id from (select id, count(*) from person group by id
having count(*) > 0))
- group by "Last Name", p.id
- having count_first_name>5 and count_first_name<10
- order by count_first_name, "Last Name""#,
- r#"SELECT j1_string as string FROM j1
- UNION ALL
- SELECT j2_string as string FROM j2"#,
- r#"SELECT j1_string as string FROM j1
- UNION ALL
- SELECT j2_string as string FROM j2
- ORDER BY string DESC
- LIMIT 10"#
- ];
-
- // For each test sql string, we transform as follows:
- // sql -> ast::Statement (s1) -> LogicalPlan (p1) -> ast::Statement (s2)
-> LogicalPlan (p2)
- // We test not that s1==s2, but rather p1==p2. This ensures that unparser
preserves the logical
- // query information of the original sql string and disreguards other
differences in syntax or
- // quoting.
- for query in tests {
- let dialect = GenericDialect {};
- let statement = Parser::new(&dialect)
- .try_with_sql(query)?
- .parse_statement()?;
-
- let context = MockContextProvider::default();
- let sql_to_rel = SqlToRel::new(&context);
- let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
-
- let roundtrip_statement = plan_to_sql(&plan)?;
-
- let actual = format!("{}", &roundtrip_statement);
- println!("roundtrip sql: {actual}");
- println!("plan {}", plan.display_indent());
-
- let plan_roundtrip = sql_to_rel
- .sql_statement_to_plan(roundtrip_statement.clone())
- .unwrap();
-
- assert_eq!(plan, plan_roundtrip);
- }
-
- Ok(())
-}
-
-#[test]
-fn roundtrip_crossjoin() -> Result<()> {
- let query = "select j1.j1_id, j2.j2_string from j1, j2";
-
- let dialect = GenericDialect {};
- let statement = Parser::new(&dialect)
- .try_with_sql(query)?
- .parse_statement()?;
-
- let context = MockContextProvider::default();
- let sql_to_rel = SqlToRel::new(&context);
- let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
-
- let roundtrip_statement = plan_to_sql(&plan)?;
-
- let actual = format!("{}", &roundtrip_statement);
- println!("roundtrip sql: {actual}");
- println!("plan {}", plan.display_indent());
-
- let plan_roundtrip = sql_to_rel
- .sql_statement_to_plan(roundtrip_statement.clone())
- .unwrap();
-
- let expected = "Projection: j1.j1_id, j2.j2_string\
- \n Inner Join: Filter: Boolean(true)\
- \n TableScan: j1\
- \n TableScan: j2";
-
- assert_eq!(format!("{plan_roundtrip:?}"), expected);
-
- Ok(())
-}
-
-#[test]
-fn roundtrip_statement_with_dialect() -> Result<()> {
- struct TestStatementWithDialect {
- sql: &'static str,
- expected: &'static str,
- parser_dialect: Box<dyn Dialect>,
- unparser_dialect: Box<dyn UnparserDialect>,
- }
- let tests: Vec<TestStatementWithDialect> = vec![
- TestStatementWithDialect {
- sql: "select ta.j1_id from j1 ta order by j1_id limit 10;",
- expected:
- "SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id`
ASC LIMIT 10",
- parser_dialect: Box::new(MySqlDialect {}),
- unparser_dialect: Box::new(UnparserMySqlDialect {}),
- },
- TestStatementWithDialect {
- sql: "select ta.j1_id from j1 ta order by j1_id limit 10;",
- expected: r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC
NULLS LAST LIMIT 10"#,
- parser_dialect: Box::new(GenericDialect {}),
- unparser_dialect: Box::new(UnparserDefaultDialect {}),
- },
- ];
-
- for query in tests {
- let statement = Parser::new(&*query.parser_dialect)
- .try_with_sql(query.sql)?
- .parse_statement()?;
-
- let context = MockContextProvider::default();
- let sql_to_rel = SqlToRel::new(&context);
- let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
-
- let unparser = Unparser::new(&*query.unparser_dialect);
- let roundtrip_statement = unparser.plan_to_sql(&plan)?;
-
- let actual = format!("{}", &roundtrip_statement);
- println!("roundtrip sql: {actual}");
- println!("plan {}", plan.display_indent());
-
- assert_eq!(query.expected, actual);
- }
-
- Ok(())
-}
-
-#[test]
-fn test_unnest_logical_plan() -> Result<()> {
- let query = "select unnest(struct_col), unnest(array_col), struct_col,
array_col from unnest_table";
-
- let dialect = GenericDialect {};
- let statement = Parser::new(&dialect)
- .try_with_sql(query)?
- .parse_statement()?;
-
- let context = MockContextProvider::default();
- let sql_to_rel = SqlToRel::new(&context);
- let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
-
- let expected = "Projection: unnest(unnest_table.struct_col).field1,
unnest(unnest_table.struct_col).field2, unnest(unnest_table.array_col),
unnest_table.struct_col, unnest_table.array_col\
- \n Unnest: lists[unnest(unnest_table.array_col)]
structs[unnest(unnest_table.struct_col)]\
- \n Projection: unnest_table.struct_col AS
unnest(unnest_table.struct_col), unnest_table.array_col AS
unnest(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col\
- \n TableScan: unnest_table";
-
- assert_eq!(format!("{plan:?}"), expected);
-
- Ok(())
-}
-
-#[test]
-fn test_table_references_in_plan_to_sql() {
- fn test(table_name: &str, expected_sql: &str) {
- let schema = Schema::new(vec![
- Field::new("id", DataType::Utf8, false),
- Field::new("value", DataType::Utf8, false),
- ]);
- let plan = table_scan(Some(table_name), &schema, None)
- .unwrap()
- .project(vec![col("id"), col("value")])
- .unwrap()
- .build()
- .unwrap();
- let sql = plan_to_sql(&plan).unwrap();
-
- assert_eq!(format!("{}", sql), expected_sql)
- }
-
- test("catalog.schema.table", "SELECT catalog.\"schema\".\"table\".id,
catalog.\"schema\".\"table\".\"value\" FROM catalog.\"schema\".\"table\"");
- test("schema.table", "SELECT \"schema\".\"table\".id,
\"schema\".\"table\".\"value\" FROM \"schema\".\"table\"");
- test(
- "table",
- "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"",
- );
-}
-
#[cfg(test)]
#[ctor::ctor]
fn init() {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]