This is an automated email from the ASF dual-hosted git repository.

goldmedal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d20b6d18b9 fix: unparsing left/ right semi/mark join (#15212)
d20b6d18b9 is described below

commit d20b6d18b901708e48f965385af2119fed01a4c7
Author: Chen Chongchen <[email protected]>
AuthorDate: Wed Mar 19 19:16:14 2025 +0800

    fix: unparsing left/ right semi/mark join (#15212)
    
    * fix: unparse semi/mark join
    
    * recursive
    
    * fix use
    
    * update
    
    * stackoverflow
    
    * update stack size
    
    * update test
    
    * fix test
    
    * format
    
    * refine
    
    * refine ci based on goldmedal's suggestion
---
 .github/workflows/extended.yml            |   2 +-
 datafusion/sql/src/unparser/ast.rs        |  34 +++++-
 datafusion/sql/src/unparser/expr.rs       |   3 +-
 datafusion/sql/src/unparser/plan.rs       |  90 +++++++++++++---
 datafusion/sql/tests/cases/plan_to_sql.rs | 173 +++++++++++++++++++++++++++++-
 5 files changed, 280 insertions(+), 22 deletions(-)

diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml
index 9ee72653b2..3942e75257 100644
--- a/.github/workflows/extended.yml
+++ b/.github/workflows/extended.yml
@@ -81,7 +81,7 @@ jobs:
       - name: Run tests (excluding doctests)
         env:
           RUST_BACKTRACE: 1
-        run: cargo test --profile ci --exclude datafusion-examples --exclude 
datafusion-benchmarks --workspace --lib --tests --bins --features 
avro,json,backtrace,extended_tests
+        run: cargo test --profile ci --exclude datafusion-examples --exclude 
datafusion-benchmarks --workspace --lib --tests --bins --features 
avro,json,backtrace,extended_tests,recursive_protection
       - name: Verify Working Directory Clean
         run: git diff --exit-code
       - name: Cleanup
diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs
index 6d77c01ea8..211ae84a00 100644
--- a/datafusion/sql/src/unparser/ast.rs
+++ b/datafusion/sql/src/unparser/ast.rs
@@ -16,9 +16,10 @@
 // under the License.
 
 use core::fmt;
+use std::ops::ControlFlow;
 
-use sqlparser::ast;
 use sqlparser::ast::helpers::attached_token::AttachedToken;
+use sqlparser::ast::{self, visit_expressions_mut};
 
 #[derive(Clone)]
 pub struct QueryBuilder {
@@ -176,6 +177,37 @@ impl SelectBuilder {
         self.lateral_views = value;
         self
     }
+
+    /// Replaces the selection with a new value.
+    ///
+    /// This function is used to replace a specific expression within the 
selection.
+    /// Unlike the `selection` method which combines existing and new 
selections with AND,
+    /// this method searches for and replaces occurrences of a specific 
expression.
+    ///
+    /// This method is primarily used to modify LEFT MARK JOIN expressions.
+    /// When processing a LEFT MARK JOIN, we need to replace the placeholder 
expression
+    /// with the actual join condition in the selection clause.
+    ///
+    /// # Arguments
+    ///
+    /// * `existing_expr` - The expression to replace
+    /// * `value` - The new expression to set as the selection
+    pub fn replace_mark(
+        &mut self,
+        existing_expr: &ast::Expr,
+        value: &ast::Expr,
+    ) -> &mut Self {
+        if let Some(selection) = &mut self.selection {
+            visit_expressions_mut(selection, |expr| {
+                if expr == existing_expr {
+                    *expr = value.clone();
+                }
+                ControlFlow::<()>::Continue(())
+            });
+        }
+        self
+    }
+
     pub fn selection(&mut self, value: Option<ast::Expr>) -> &mut Self {
         // With filter pushdown optimization, the LogicalPlan can have filters 
defined as part of `TableScan` and `Filter` nodes.
         // To avoid overwriting one of the filters, we combine the existing 
filter with the additional filter.
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index 7905207faf..5e74849cd9 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -94,6 +94,7 @@ impl Unparser<'_> {
         Ok(root_expr)
     }
 
+    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
     fn expr_to_sql_inner(&self, expr: &Expr) -> Result<ast::Expr> {
         match expr {
             Expr::InList(InList {
@@ -674,7 +675,7 @@ impl Unparser<'_> {
         }
     }
 
-    fn col_to_sql(&self, col: &Column) -> Result<ast::Expr> {
+    pub fn col_to_sql(&self, col: &Column) -> Result<ast::Expr> {
         if let Some(table_ref) = &col.relation {
             let mut id = if self.dialect.full_qualified_col() {
                 table_ref.to_vec()
diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs
index b14fbdff23..507a6b2761 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -322,6 +322,7 @@ impl Unparser<'_> {
         }
     }
 
+    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
     fn select_to_sql_recursively(
         &self,
         plan: &LogicalPlan,
@@ -566,14 +567,20 @@ impl Unparser<'_> {
             }
             LogicalPlan::Join(join) => {
                 let mut table_scan_filters = vec![];
+                let (left_plan, right_plan) = match join.join_type {
+                    JoinType::RightSemi | JoinType::RightAnti => {
+                        (&join.right, &join.left)
+                    }
+                    _ => (&join.left, &join.right),
+                };
 
                 let left_plan =
-                    match 
try_transform_to_simple_table_scan_with_filters(&join.left)? {
+                    match 
try_transform_to_simple_table_scan_with_filters(left_plan)? {
                         Some((plan, filters)) => {
                             table_scan_filters.extend(filters);
                             Arc::new(plan)
                         }
-                        None => Arc::clone(&join.left),
+                        None => Arc::clone(left_plan),
                     };
 
                 self.select_to_sql_recursively(
@@ -584,12 +591,12 @@ impl Unparser<'_> {
                 )?;
 
                 let right_plan =
-                    match 
try_transform_to_simple_table_scan_with_filters(&join.right)? {
+                    match 
try_transform_to_simple_table_scan_with_filters(right_plan)? {
                         Some((plan, filters)) => {
                             table_scan_filters.extend(filters);
                             Arc::new(plan)
                         }
-                        None => Arc::clone(&join.right),
+                        None => Arc::clone(right_plan),
                     };
 
                 let mut right_relation = RelationBuilder::default();
@@ -641,19 +648,70 @@ impl Unparser<'_> {
                     &mut right_relation,
                 )?;
 
-                let Ok(Some(relation)) = right_relation.build() else {
-                    return internal_err!("Failed to build right relation");
-                };
-
-                let ast_join = ast::Join {
-                    relation,
-                    global: false,
-                    join_operator: self
-                        .join_operator_to_sql(join.join_type, 
join_constraint)?,
+                match join.join_type {
+                    JoinType::LeftSemi
+                    | JoinType::LeftAnti
+                    | JoinType::LeftMark
+                    | JoinType::RightSemi
+                    | JoinType::RightAnti => {
+                        let mut query_builder = QueryBuilder::default();
+                        let mut from = TableWithJoinsBuilder::default();
+                        let mut exists_select: SelectBuilder = 
SelectBuilder::default();
+                        from.relation(right_relation);
+                        exists_select.push_from(from);
+                        if let Some(filter) = &join.filter {
+                            
exists_select.selection(Some(self.expr_to_sql(filter)?));
+                        }
+                        for (left, right) in &join.on {
+                            exists_select.selection(Some(
+                                
self.expr_to_sql(&left.clone().eq(right.clone()))?,
+                            ));
+                        }
+                        
exists_select.projection(vec![ast::SelectItem::UnnamedExpr(
+                            
ast::Expr::Value(ast::Value::Number("1".to_string(), false)),
+                        )]);
+                        query_builder.body(Box::new(SetExpr::Select(Box::new(
+                            exists_select.build()?,
+                        ))));
+
+                        let negated = match join.join_type {
+                            JoinType::LeftSemi
+                            | JoinType::RightSemi
+                            | JoinType::LeftMark => false,
+                            JoinType::LeftAnti | JoinType::RightAnti => true,
+                            _ => unreachable!(),
+                        };
+                        let exists_expr = ast::Expr::Exists {
+                            subquery: Box::new(query_builder.build()?),
+                            negated,
+                        };
+                        if join.join_type == JoinType::LeftMark {
+                            let (table_ref, _) = 
right_plan.schema().qualified_field(0);
+                            let column = self
+                                .col_to_sql(&Column::new(table_ref.cloned(), 
"mark"))?;
+                            select.replace_mark(&column, &exists_expr);
+                        } else {
+                            select.selection(Some(exists_expr));
+                        }
+                    }
+                    JoinType::Inner
+                    | JoinType::Left
+                    | JoinType::Right
+                    | JoinType::Full => {
+                        let Ok(Some(relation)) = right_relation.build() else {
+                            return internal_err!("Failed to build right 
relation");
+                        };
+                        let ast_join = ast::Join {
+                            relation,
+                            global: false,
+                            join_operator: self
+                                .join_operator_to_sql(join.join_type, 
join_constraint)?,
+                        };
+                        let mut from = select.pop_from().unwrap();
+                        from.push_join(ast_join);
+                        select.push_from(from);
+                    }
                 };
-                let mut from = select.pop_from().unwrap();
-                from.push_join(ast_join);
-                select.push_from(from);
 
                 Ok(())
             }
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs
index 0abc890dfa..c3a28f050f 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -16,7 +16,9 @@
 // under the License.
 
 use arrow::datatypes::{DataType, Field, Schema};
-use datafusion_common::{assert_contains, DFSchema, DFSchemaRef, Result, 
TableReference};
+use datafusion_common::{
+    assert_contains, Column, DFSchema, DFSchemaRef, Result, TableReference,
+};
 use datafusion_expr::test::function_stub::{
     count_udaf, max_udaf, min_udaf, sum, sum_udaf,
 };
@@ -32,7 +34,8 @@ use datafusion_functions_window::rank::rank_udwf;
 use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
 use datafusion_sql::unparser::dialect::{
     CustomDialectBuilder, DefaultDialect as UnparserDefaultDialect, 
DefaultDialect,
-    Dialect as UnparserDialect, MySqlDialect as UnparserMySqlDialect, 
SqliteDialect,
+    Dialect as UnparserDialect, MySqlDialect as UnparserMySqlDialect,
+    PostgreSqlDialect as UnparserPostgreSqlDialect, SqliteDialect,
 };
 use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser};
 use sqlparser::ast::Statement;
@@ -43,7 +46,7 @@ use std::{fmt, vec};
 
 use crate::common::{MockContextProvider, MockSessionState};
 use datafusion_expr::builder::{
-    project, table_scan_with_filter_and_fetch, table_scan_with_filters,
+    project, subquery_alias, table_scan_with_filter_and_fetch, 
table_scan_with_filters,
 };
 use datafusion_functions::core::planner::CoreFunctionPlanner;
 use datafusion_functions_nested::extract::array_element_udf;
@@ -1746,3 +1749,167 @@ fn test_unparse_subquery_alias_with_table_pushdown() -> Result<()> {
     assert_eq!(sql.to_string(), expected);
     Ok(())
 }
+
+#[test]
+fn test_unparse_left_anti_join() -> Result<()> {
+    // select t1.d from t1 where c not in (select c from t2)
+    let schema = Schema::new(vec![
+        Field::new("c", DataType::Int32, false),
+        Field::new("d", DataType::Int32, false),
+    ]);
+
+    // LeftAnti Join: t1.c = __correlated_sq_1.c
+    //   TableScan: t1 projection=[c]
+    //   SubqueryAlias: __correlated_sq_1
+    //     TableScan: t2 projection=[c]
+
+    let table_scan1 = table_scan(Some("t1"), &schema, Some(vec![0, 
1]))?.build()?;
+    let table_scan2 = table_scan(Some("t2"), &schema, Some(vec![0]))?.build()?;
+    let subquery = subquery_alias(table_scan2, "__correlated_sq_1")?;
+    let plan = LogicalPlanBuilder::from(table_scan1)
+        .project(vec![col("t1.d")])?
+        .join_on(
+            subquery,
+            datafusion_expr::JoinType::LeftAnti,
+            vec![col("t1.c").eq(col("__correlated_sq_1.c"))],
+        )?
+        .build()?;
+
+    let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
+    let sql = unparser.plan_to_sql(&plan)?;
+    assert_eq!("SELECT \"t1\".\"d\" FROM \"t1\" WHERE NOT EXISTS (SELECT 1 
FROM \"t2\" AS \"__correlated_sq_1\" WHERE (\"t1\".\"c\" = 
\"__correlated_sq_1\".\"c\"))", sql.to_string());
+    Ok(())
+}
+
+#[test]
+fn test_unparse_left_semi_join() -> Result<()> {
+    // select t1.d from t1 where c in (select c from t2)
+    let schema = Schema::new(vec![
+        Field::new("c", DataType::Int32, false),
+        Field::new("d", DataType::Int32, false),
+    ]);
+
+    // LeftSemi Join: t1.c = __correlated_sq_1.c
+    //   TableScan: t1 projection=[c]
+    //   SubqueryAlias: __correlated_sq_1
+    //     TableScan: t2 projection=[c]
+
+    let table_scan1 = table_scan(Some("t1"), &schema, Some(vec![0, 
1]))?.build()?;
+    let table_scan2 = table_scan(Some("t2"), &schema, Some(vec![0]))?.build()?;
+    let subquery = subquery_alias(table_scan2, "__correlated_sq_1")?;
+    let plan = LogicalPlanBuilder::from(table_scan1)
+        .project(vec![col("t1.d")])?
+        .join_on(
+            subquery,
+            datafusion_expr::JoinType::LeftSemi,
+            vec![col("t1.c").eq(col("__correlated_sq_1.c"))],
+        )?
+        .build()?;
+
+    let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
+    let sql = unparser.plan_to_sql(&plan)?;
+    assert_eq!("SELECT \"t1\".\"d\" FROM \"t1\" WHERE EXISTS (SELECT 1 FROM 
\"t2\" AS \"__correlated_sq_1\" WHERE (\"t1\".\"c\" = 
\"__correlated_sq_1\".\"c\"))", sql.to_string());
+    Ok(())
+}
+
+#[test]
+fn test_unparse_left_mark_join() -> Result<()> {
+    // select t1.d from t1 where t1.d < 0 OR exists (select 1 from t2 where 
t1.c = t2.c)
+    let schema = Schema::new(vec![
+        Field::new("c", DataType::Int32, false),
+        Field::new("d", DataType::Int32, false),
+    ]);
+    // Filter: __correlated_sq_1.mark OR t1.d < Int32(0)
+    //   Projection: t1.d
+    //     LeftMark Join:  Filter: t1.c = __correlated_sq_1.c
+    //       TableScan: t1 projection=[c, d]
+    //       SubqueryAlias: __correlated_sq_1
+    //         TableScan: t2 projection=[c]
+    let table_scan1 = table_scan(Some("t1"), &schema, Some(vec![0, 
1]))?.build()?;
+    let table_scan2 = table_scan(Some("t2"), &schema, Some(vec![0]))?.build()?;
+    let subquery = subquery_alias(table_scan2, "__correlated_sq_1")?;
+    let plan = LogicalPlanBuilder::from(table_scan1)
+        .join_on(
+            subquery,
+            datafusion_expr::JoinType::LeftMark,
+            vec![col("t1.c").eq(col("__correlated_sq_1.c"))],
+        )?
+        .project(vec![col("t1.d")])?
+        .filter(col("mark").or(col("t1.d").lt(lit(0))))?
+        .build()?;
+
+    let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
+    let sql = unparser.plan_to_sql(&plan)?;
+    assert_eq!("SELECT \"t1\".\"d\" FROM \"t1\" WHERE (EXISTS (SELECT 1 FROM 
\"t2\" AS \"__correlated_sq_1\" WHERE (\"t1\".\"c\" = 
\"__correlated_sq_1\".\"c\")) OR (\"t1\".\"d\" < 0))", sql.to_string());
+    Ok(())
+}
+
+#[test]
+fn test_unparse_right_semi_join() -> Result<()> {
+    // select t2.c, t2.d from t1 right semi join t2 on t1.c = t2.c where t2.c 
<= 1
+    let schema = Schema::new(vec![
+        Field::new("c", DataType::Int32, false),
+        Field::new("d", DataType::Int32, false),
+    ]);
+    // Filter: t2.c <= Int64(1)
+    //   RightSemi Join: t1.c = t2.c
+    //     TableScan: t1 projection=[c, d]
+    //     Projection: t2.c, t2.d
+    //       TableScan: t2 projection=[c, d]
+    let left = table_scan(Some("t1"), &schema, Some(vec![0, 1]))?.build()?;
+    let right_table_scan = table_scan(Some("t2"), &schema, Some(vec![0, 
1]))?.build()?;
+    let right = LogicalPlanBuilder::from(right_table_scan)
+        .project(vec![col("c"), col("d")])?
+        .build()?;
+    let plan = LogicalPlanBuilder::from(left)
+        .join(
+            right,
+            datafusion_expr::JoinType::RightSemi,
+            (
+                vec![Column::from_qualified_name("t1.c")],
+                vec![Column::from_qualified_name("t2.c")],
+            ),
+            None,
+        )?
+        .filter(col("t2.c").lt_eq(lit(1i64)))?
+        .build()?;
+    let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
+    let sql = unparser.plan_to_sql(&plan)?;
+    assert_eq!("SELECT \"t2\".\"c\", \"t2\".\"d\" FROM \"t2\" WHERE 
(\"t2\".\"c\" <= 1) AND EXISTS (SELECT 1 FROM \"t1\" WHERE (\"t1\".\"c\" = 
\"t2\".\"c\"))", sql.to_string());
+    Ok(())
+}
+
+#[test]
+fn test_unparse_right_anti_join() -> Result<()> {
+    // select t2.c, t2.d from t1 right anti join t2 on t1.c = t2.c where t2.c 
<= 1
+    let schema = Schema::new(vec![
+        Field::new("c", DataType::Int32, false),
+        Field::new("d", DataType::Int32, false),
+    ]);
+    // Filter: t2.c <= Int64(1)
+    //   RightAnti Join: t1.c = t2.c
+    //     TableScan: t1 projection=[c, d]
+    //     Projection: t2.c, t2.d
+    //       TableScan: t2 projection=[c, d]
+    let left = table_scan(Some("t1"), &schema, Some(vec![0, 1]))?.build()?;
+    let right_table_scan = table_scan(Some("t2"), &schema, Some(vec![0, 
1]))?.build()?;
+    let right = LogicalPlanBuilder::from(right_table_scan)
+        .project(vec![col("c"), col("d")])?
+        .build()?;
+    let plan = LogicalPlanBuilder::from(left)
+        .join(
+            right,
+            datafusion_expr::JoinType::RightAnti,
+            (
+                vec![Column::from_qualified_name("t1.c")],
+                vec![Column::from_qualified_name("t2.c")],
+            ),
+            None,
+        )?
+        .filter(col("t2.c").lt_eq(lit(1i64)))?
+        .build()?;
+    let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
+    let sql = unparser.plan_to_sql(&plan)?;
+    assert_eq!("SELECT \"t2\".\"c\", \"t2\".\"d\" FROM \"t2\" WHERE 
(\"t2\".\"c\" <= 1) AND NOT EXISTS (SELECT 1 FROM \"t1\" WHERE (\"t1\".\"c\" = 
\"t2\".\"c\"))", sql.to_string());
+    Ok(())
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to