This is an automated email from the ASF dual-hosted git repository.
goldmedal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d20b6d18b9 fix: unparsing left/right semi/mark join (#15212)
d20b6d18b9 is described below
commit d20b6d18b901708e48f965385af2119fed01a4c7
Author: Chen Chongchen <[email protected]>
AuthorDate: Wed Mar 19 19:16:14 2025 +0800
fix: unparsing left/right semi/mark join (#15212)
* fix: unparse semi/mark join
* recursive
* fix use
* update
* stack overflow
* update stack size
* update test
* fix test
* format
* refine
* refine CI based on goldmedal's suggestion
---
.github/workflows/extended.yml | 2 +-
datafusion/sql/src/unparser/ast.rs | 34 +++++-
datafusion/sql/src/unparser/expr.rs | 3 +-
datafusion/sql/src/unparser/plan.rs | 90 +++++++++++++---
datafusion/sql/tests/cases/plan_to_sql.rs | 173 +++++++++++++++++++++++++++++-
5 files changed, 280 insertions(+), 22 deletions(-)
diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml
index 9ee72653b2..3942e75257 100644
--- a/.github/workflows/extended.yml
+++ b/.github/workflows/extended.yml
@@ -81,7 +81,7 @@ jobs:
- name: Run tests (excluding doctests)
env:
RUST_BACKTRACE: 1
- run: cargo test --profile ci --exclude datafusion-examples --exclude
datafusion-benchmarks --workspace --lib --tests --bins --features
avro,json,backtrace,extended_tests
+ run: cargo test --profile ci --exclude datafusion-examples --exclude
datafusion-benchmarks --workspace --lib --tests --bins --features
avro,json,backtrace,extended_tests,recursive_protection
- name: Verify Working Directory Clean
run: git diff --exit-code
- name: Cleanup
diff --git a/datafusion/sql/src/unparser/ast.rs
b/datafusion/sql/src/unparser/ast.rs
index 6d77c01ea8..211ae84a00 100644
--- a/datafusion/sql/src/unparser/ast.rs
+++ b/datafusion/sql/src/unparser/ast.rs
@@ -16,9 +16,10 @@
// under the License.
use core::fmt;
+use std::ops::ControlFlow;
-use sqlparser::ast;
use sqlparser::ast::helpers::attached_token::AttachedToken;
+use sqlparser::ast::{self, visit_expressions_mut};
#[derive(Clone)]
pub struct QueryBuilder {
@@ -176,6 +177,37 @@ impl SelectBuilder {
self.lateral_views = value;
self
}
+
+ /// Replaces the selection with a new value.
+ ///
+ /// This function is used to replace a specific expression within the
selection.
+ /// Unlike the `selection` method which combines existing and new
selections with AND,
+ /// this method searches for and replaces occurrences of a specific
expression.
+ ///
+ /// This method is primarily used to modify LEFT MARK JOIN expressions.
+ /// When processing a LEFT MARK JOIN, we need to replace the placeholder
expression
+ /// with the actual join condition in the selection clause.
+ ///
+ /// # Arguments
+ ///
+ /// * `existing_expr` - The expression to replace
+ /// * `value` - The new expression to set as the selection
+ pub fn replace_mark(
+ &mut self,
+ existing_expr: &ast::Expr,
+ value: &ast::Expr,
+ ) -> &mut Self {
+ // NOTE(review): despite the `# Arguments` wording above, `value` does not
+ // become the whole selection. Each sub-expression of the current WHERE
+ // selection that compares equal to `existing_expr` is replaced by a clone
+ // of `value`; everything else in the selection is kept as-is.
+ //
+ // Only `self.selection` is searched: projection items and other clauses
+ // are not rewritten. When there is no selection this is a no-op.
+ if let Some(selection) = &mut self.selection {
+ // The closure always returns `Continue`, so the traversal is never cut
+ // short and the `ControlFlow` result of the visit is simply discarded.
+ visit_expressions_mut(selection, |expr| {
+ if expr == existing_expr {
+ *expr = value.clone();
+ }
+ ControlFlow::<()>::Continue(())
+ });
+ }
+ self
+ }
+
pub fn selection(&mut self, value: Option<ast::Expr>) -> &mut Self {
// With filter pushdown optimization, the LogicalPlan can have filters
defined as part of `TableScan` and `Filter` nodes.
// To avoid overwriting one of the filters, we combine the existing
filter with the additional filter.
diff --git a/datafusion/sql/src/unparser/expr.rs
b/datafusion/sql/src/unparser/expr.rs
index 7905207faf..5e74849cd9 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -94,6 +94,7 @@ impl Unparser<'_> {
Ok(root_expr)
}
+ #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn expr_to_sql_inner(&self, expr: &Expr) -> Result<ast::Expr> {
match expr {
Expr::InList(InList {
@@ -674,7 +675,7 @@ impl Unparser<'_> {
}
}
- fn col_to_sql(&self, col: &Column) -> Result<ast::Expr> {
+ pub fn col_to_sql(&self, col: &Column) -> Result<ast::Expr> {
if let Some(table_ref) = &col.relation {
let mut id = if self.dialect.full_qualified_col() {
table_ref.to_vec()
diff --git a/datafusion/sql/src/unparser/plan.rs
b/datafusion/sql/src/unparser/plan.rs
index b14fbdff23..507a6b2761 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -322,6 +322,7 @@ impl Unparser<'_> {
}
}
+ #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn select_to_sql_recursively(
&self,
plan: &LogicalPlan,
@@ -566,14 +567,20 @@ impl Unparser<'_> {
}
LogicalPlan::Join(join) => {
let mut table_scan_filters = vec![];
+ let (left_plan, right_plan) = match join.join_type {
+ JoinType::RightSemi | JoinType::RightAnti => {
+ (&join.right, &join.left)
+ }
+ _ => (&join.left, &join.right),
+ };
let left_plan =
- match
try_transform_to_simple_table_scan_with_filters(&join.left)? {
+ match
try_transform_to_simple_table_scan_with_filters(left_plan)? {
Some((plan, filters)) => {
table_scan_filters.extend(filters);
Arc::new(plan)
}
- None => Arc::clone(&join.left),
+ None => Arc::clone(left_plan),
};
self.select_to_sql_recursively(
@@ -584,12 +591,12 @@ impl Unparser<'_> {
)?;
let right_plan =
- match
try_transform_to_simple_table_scan_with_filters(&join.right)? {
+ match
try_transform_to_simple_table_scan_with_filters(right_plan)? {
Some((plan, filters)) => {
table_scan_filters.extend(filters);
Arc::new(plan)
}
- None => Arc::clone(&join.right),
+ None => Arc::clone(right_plan),
};
let mut right_relation = RelationBuilder::default();
@@ -641,19 +648,70 @@ impl Unparser<'_> {
&mut right_relation,
)?;
- let Ok(Some(relation)) = right_relation.build() else {
- return internal_err!("Failed to build right relation");
- };
-
- let ast_join = ast::Join {
- relation,
- global: false,
- join_operator: self
- .join_operator_to_sql(join.join_type,
join_constraint)?,
+ match join.join_type {
+ JoinType::LeftSemi
+ | JoinType::LeftAnti
+ | JoinType::LeftMark
+ | JoinType::RightSemi
+ | JoinType::RightAnti => {
+ let mut query_builder = QueryBuilder::default();
+ let mut from = TableWithJoinsBuilder::default();
+ let mut exists_select: SelectBuilder =
SelectBuilder::default();
+ from.relation(right_relation);
+ exists_select.push_from(from);
+ if let Some(filter) = &join.filter {
+
exists_select.selection(Some(self.expr_to_sql(filter)?));
+ }
+ for (left, right) in &join.on {
+ exists_select.selection(Some(
+
self.expr_to_sql(&left.clone().eq(right.clone()))?,
+ ));
+ }
+
exists_select.projection(vec![ast::SelectItem::UnnamedExpr(
+
ast::Expr::Value(ast::Value::Number("1".to_string(), false)),
+ )]);
+ query_builder.body(Box::new(SetExpr::Select(Box::new(
+ exists_select.build()?,
+ ))));
+
+ let negated = match join.join_type {
+ JoinType::LeftSemi
+ | JoinType::RightSemi
+ | JoinType::LeftMark => false,
+ JoinType::LeftAnti | JoinType::RightAnti => true,
+ _ => unreachable!(),
+ };
+ let exists_expr = ast::Expr::Exists {
+ subquery: Box::new(query_builder.build()?),
+ negated,
+ };
+ if join.join_type == JoinType::LeftMark {
+ let (table_ref, _) =
right_plan.schema().qualified_field(0);
+ let column = self
+ .col_to_sql(&Column::new(table_ref.cloned(),
"mark"))?;
+ select.replace_mark(&column, &exists_expr);
+ } else {
+ select.selection(Some(exists_expr));
+ }
+ }
+ JoinType::Inner
+ | JoinType::Left
+ | JoinType::Right
+ | JoinType::Full => {
+ let Ok(Some(relation)) = right_relation.build() else {
+ return internal_err!("Failed to build right
relation");
+ };
+ let ast_join = ast::Join {
+ relation,
+ global: false,
+ join_operator: self
+ .join_operator_to_sql(join.join_type,
join_constraint)?,
+ };
+ let mut from = select.pop_from().unwrap();
+ from.push_join(ast_join);
+ select.push_from(from);
+ }
};
- let mut from = select.pop_from().unwrap();
- from.push_join(ast_join);
- select.push_from(from);
Ok(())
}
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs
b/datafusion/sql/tests/cases/plan_to_sql.rs
index 0abc890dfa..c3a28f050f 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -16,7 +16,9 @@
// under the License.
use arrow::datatypes::{DataType, Field, Schema};
-use datafusion_common::{assert_contains, DFSchema, DFSchemaRef, Result,
TableReference};
+use datafusion_common::{
+ assert_contains, Column, DFSchema, DFSchemaRef, Result, TableReference,
+};
use datafusion_expr::test::function_stub::{
count_udaf, max_udaf, min_udaf, sum, sum_udaf,
};
@@ -32,7 +34,8 @@ use datafusion_functions_window::rank::rank_udwf;
use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_sql::unparser::dialect::{
CustomDialectBuilder, DefaultDialect as UnparserDefaultDialect,
DefaultDialect,
- Dialect as UnparserDialect, MySqlDialect as UnparserMySqlDialect,
SqliteDialect,
+ Dialect as UnparserDialect, MySqlDialect as UnparserMySqlDialect,
+ PostgreSqlDialect as UnparserPostgreSqlDialect, SqliteDialect,
};
use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser};
use sqlparser::ast::Statement;
@@ -43,7 +46,7 @@ use std::{fmt, vec};
use crate::common::{MockContextProvider, MockSessionState};
use datafusion_expr::builder::{
- project, table_scan_with_filter_and_fetch, table_scan_with_filters,
+ project, subquery_alias, table_scan_with_filter_and_fetch,
table_scan_with_filters,
};
use datafusion_functions::core::planner::CoreFunctionPlanner;
use datafusion_functions_nested::extract::array_element_udf;
@@ -1746,3 +1749,167 @@ fn test_unparse_subquery_alias_with_table_pushdown() ->
Result<()> {
assert_eq!(sql.to_string(), expected);
Ok(())
}
+
+/// Checks that a `LeftAnti` join is unparsed as
+/// `WHERE NOT EXISTS (SELECT 1 FROM <right> WHERE <join condition>)` instead of
+/// a JOIN clause; the aliased right input keeps its alias inside the subquery.
+#[test]
+fn test_unparse_left_anti_join() -> Result<()> {
+ // select t1.d from t1 where c not in (select c from t2)
+ // (NOT IN and NOT EXISTS only agree here because both `c` columns are
+ // declared non-nullable below.)
+ let schema = Schema::new(vec![
+ Field::new("c", DataType::Int32, false),
+ Field::new("d", DataType::Int32, false),
+ ]);
+
+ // Plan actually built below:
+ // LeftAnti Join: Filter: t1.c = __correlated_sq_1.c
+ //   Projection: t1.d
+ //     TableScan: t1 projection=[c, d]
+ //   SubqueryAlias: __correlated_sq_1
+ //     TableScan: t2 projection=[c]
+ // NOTE(review): the projection is applied to t1 before the join, so `t1.c`
+ // is not in the join's left input schema; the test relies on the builder
+ // and unparser tolerating that rather than on a fully valid plan.
+
+ let table_scan1 = table_scan(Some("t1"), &schema, Some(vec![0,
1]))?.build()?;
+ let table_scan2 = table_scan(Some("t2"), &schema, Some(vec![0]))?.build()?;
+ let subquery = subquery_alias(table_scan2, "__correlated_sq_1")?;
+ let plan = LogicalPlanBuilder::from(table_scan1)
+ .project(vec![col("t1.d")])?
+ .join_on(
+ subquery,
+ datafusion_expr::JoinType::LeftAnti,
+ vec![col("t1.c").eq(col("__correlated_sq_1.c"))],
+ )?
+ .build()?;
+
+ let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
+ let sql = unparser.plan_to_sql(&plan)?;
+ assert_eq!("SELECT \"t1\".\"d\" FROM \"t1\" WHERE NOT EXISTS (SELECT 1
FROM \"t2\" AS \"__correlated_sq_1\" WHERE (\"t1\".\"c\" =
\"__correlated_sq_1\".\"c\"))", sql.to_string());
+ Ok(())
+}
+
+/// Checks that a `LeftSemi` join is unparsed as
+/// `WHERE EXISTS (SELECT 1 FROM <right> WHERE <join condition>)` instead of a
+/// JOIN clause; the aliased right input keeps its alias inside the subquery.
+#[test]
+fn test_unparse_left_semi_join() -> Result<()> {
+ // select t1.d from t1 where c in (select c from t2)
+ let schema = Schema::new(vec![
+ Field::new("c", DataType::Int32, false),
+ Field::new("d", DataType::Int32, false),
+ ]);
+
+ // Plan actually built below:
+ // LeftSemi Join: Filter: t1.c = __correlated_sq_1.c
+ //   Projection: t1.d
+ //     TableScan: t1 projection=[c, d]
+ //   SubqueryAlias: __correlated_sq_1
+ //     TableScan: t2 projection=[c]
+ // NOTE(review): the projection is applied to t1 before the join, so `t1.c`
+ // is not in the join's left input schema; the test relies on the builder
+ // and unparser tolerating that rather than on a fully valid plan.
+
+ let table_scan1 = table_scan(Some("t1"), &schema, Some(vec![0,
1]))?.build()?;
+ let table_scan2 = table_scan(Some("t2"), &schema, Some(vec![0]))?.build()?;
+ let subquery = subquery_alias(table_scan2, "__correlated_sq_1")?;
+ let plan = LogicalPlanBuilder::from(table_scan1)
+ .project(vec![col("t1.d")])?
+ .join_on(
+ subquery,
+ datafusion_expr::JoinType::LeftSemi,
+ vec![col("t1.c").eq(col("__correlated_sq_1.c"))],
+ )?
+ .build()?;
+
+ let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
+ let sql = unparser.plan_to_sql(&plan)?;
+ assert_eq!("SELECT \"t1\".\"d\" FROM \"t1\" WHERE EXISTS (SELECT 1 FROM
\"t2\" AS \"__correlated_sq_1\" WHERE (\"t1\".\"c\" =
\"__correlated_sq_1\".\"c\"))", sql.to_string());
+ Ok(())
+}
+
+/// Checks that a `LeftMark` join is not emitted as a JOIN clause: the reference
+/// to the mark column in the outer filter is replaced by
+/// `EXISTS (SELECT 1 FROM <right> WHERE <join condition>)`, and the rest of the
+/// predicate (`t1.d < 0`) is left intact.
+#[test]
+fn test_unparse_left_mark_join() -> Result<()> {
+ // select t1.d from t1 where t1.d < 0 OR exists (select 1 from t2 where
t1.c = t2.c)
+ let schema = Schema::new(vec![
+ Field::new("c", DataType::Int32, false),
+ Field::new("d", DataType::Int32, false),
+ ]);
+ // Filter: __correlated_sq_1.mark OR t1.d < Int32(0)
+ // Projection: t1.d
+ // LeftMark Join: Filter: t1.c = __correlated_sq_1.c
+ // TableScan: t1 projection=[c, d]
+ // SubqueryAlias: __correlated_sq_1
+ // TableScan: t2 projection=[c]
+ // NOTE(review): the unqualified `mark` below is expected to end up as
+ // `__correlated_sq_1.mark` (as shown in the plan above) even though the
+ // projection beneath the filter only outputs t1.d; the unparser substitutes
+ // the EXISTS subquery only where it finds that qualified column.
+ let table_scan1 = table_scan(Some("t1"), &schema, Some(vec![0,
1]))?.build()?;
+ let table_scan2 = table_scan(Some("t2"), &schema, Some(vec![0]))?.build()?;
+ let subquery = subquery_alias(table_scan2, "__correlated_sq_1")?;
+ let plan = LogicalPlanBuilder::from(table_scan1)
+ .join_on(
+ subquery,
+ datafusion_expr::JoinType::LeftMark,
+ vec![col("t1.c").eq(col("__correlated_sq_1.c"))],
+ )?
+ .project(vec![col("t1.d")])?
+ .filter(col("mark").or(col("t1.d").lt(lit(0))))?
+ .build()?;
+
+ let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
+ let sql = unparser.plan_to_sql(&plan)?;
+ assert_eq!("SELECT \"t1\".\"d\" FROM \"t1\" WHERE (EXISTS (SELECT 1 FROM
\"t2\" AS \"__correlated_sq_1\" WHERE (\"t1\".\"c\" =
\"__correlated_sq_1\".\"c\")) OR (\"t1\".\"d\" < 0))", sql.to_string());
+ Ok(())
+}
+
+/// Checks that a `RightSemi` join is unparsed from the right input's side:
+/// t2 becomes the outer `FROM`/`SELECT`, t1 moves into an `EXISTS` subquery, and
+/// that subquery is ANDed after the outer filter on t2.
+#[test]
+fn test_unparse_right_semi_join() -> Result<()> {
+ // select t2.c, t2.d from t1 right semi join t2 on t1.c = t2.c where t2.c
<= 1
+ let schema = Schema::new(vec![
+ Field::new("c", DataType::Int32, false),
+ Field::new("d", DataType::Int32, false),
+ ]);
+ // Filter: t2.c <= Int64(1)
+ // RightSemi Join: t1.c = t2.c
+ // TableScan: t1 projection=[c, d]
+ // Projection: t2.c, t2.d
+ // TableScan: t2 projection=[c, d]
+ let left = table_scan(Some("t1"), &schema, Some(vec![0, 1]))?.build()?;
+ let right_table_scan = table_scan(Some("t2"), &schema, Some(vec![0,
1]))?.build()?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
+ .project(vec![col("c"), col("d")])?
+ .build()?;
+ // The equi-join key pair (t1.c, t2.c) is rendered as `t1.c = t2.c` in the
+ // EXISTS subquery's WHERE clause.
+ let plan = LogicalPlanBuilder::from(left)
+ .join(
+ right,
+ datafusion_expr::JoinType::RightSemi,
+ (
+ vec![Column::from_qualified_name("t1.c")],
+ vec![Column::from_qualified_name("t2.c")],
+ ),
+ None,
+ )?
+ .filter(col("t2.c").lt_eq(lit(1i64)))?
+ .build()?;
+ let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
+ let sql = unparser.plan_to_sql(&plan)?;
+ assert_eq!("SELECT \"t2\".\"c\", \"t2\".\"d\" FROM \"t2\" WHERE
(\"t2\".\"c\" <= 1) AND EXISTS (SELECT 1 FROM \"t1\" WHERE (\"t1\".\"c\" =
\"t2\".\"c\"))", sql.to_string());
+ Ok(())
+}
+
+/// Checks that a `RightAnti` join is unparsed from the right input's side:
+/// t2 becomes the outer `FROM`/`SELECT`, t1 moves into a `NOT EXISTS` subquery,
+/// and that subquery is ANDed after the outer filter on t2.
+#[test]
+fn test_unparse_right_anti_join() -> Result<()> {
+ // select t2.c, t2.d from t1 right anti join t2 on t1.c = t2.c where t2.c
<= 1
+ let schema = Schema::new(vec![
+ Field::new("c", DataType::Int32, false),
+ Field::new("d", DataType::Int32, false),
+ ]);
+ // Filter: t2.c <= Int64(1)
+ // RightAnti Join: t1.c = t2.c
+ // TableScan: t1 projection=[c, d]
+ // Projection: t2.c, t2.d
+ // TableScan: t2 projection=[c, d]
+ let left = table_scan(Some("t1"), &schema, Some(vec![0, 1]))?.build()?;
+ let right_table_scan = table_scan(Some("t2"), &schema, Some(vec![0,
1]))?.build()?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
+ .project(vec![col("c"), col("d")])?
+ .build()?;
+ // The equi-join key pair (t1.c, t2.c) is rendered as `t1.c = t2.c` in the
+ // NOT EXISTS subquery's WHERE clause.
+ let plan = LogicalPlanBuilder::from(left)
+ .join(
+ right,
+ datafusion_expr::JoinType::RightAnti,
+ (
+ vec![Column::from_qualified_name("t1.c")],
+ vec![Column::from_qualified_name("t2.c")],
+ ),
+ None,
+ )?
+ .filter(col("t2.c").lt_eq(lit(1i64)))?
+ .build()?;
+ let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
+ let sql = unparser.plan_to_sql(&plan)?;
+ assert_eq!("SELECT \"t2\".\"c\", \"t2\".\"d\" FROM \"t2\" WHERE
(\"t2\".\"c\" <= 1) AND NOT EXISTS (SELECT 1 FROM \"t1\" WHERE (\"t1\".\"c\" =
\"t2\".\"c\"))", sql.to_string());
+ Ok(())
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]