This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 27b15fdcf Support non-tuple expression for exists-subquery to join (#5264)
27b15fdcf is described below

commit 27b15fdcf3234ba050e254ec179f51e5d92e9784
Author: ygf11 <[email protected]>
AuthorDate: Sun Feb 19 00:10:51 2023 +0800

    Support non-tuple expression for exists-subquery to join (#5264)
    
    * Support non-tuple expression for exists-subquery to join
    
    * fix tests
    
    * add tests
    
    * add comments
    
    * fix tests
    
    * fix test comment
---
 benchmarks/expected-plans/q21.txt                  |  12 +-
 benchmarks/expected-plans/q22.txt                  |   3 +-
 benchmarks/expected-plans/q4.txt                   |   5 +-
 datafusion/core/tests/sql/joins.rs                 | 203 ++++++++++++++++--
 .../optimizer/src/decorrelate_where_exists.rs      | 237 ++++++++++++---------
 datafusion/optimizer/src/decorrelate_where_in.rs   |  31 +--
 datafusion/optimizer/src/utils.rs                  |  39 +++-
 datafusion/optimizer/tests/integration-test.rs     |  15 +-
 8 files changed, 389 insertions(+), 156 deletions(-)

diff --git a/benchmarks/expected-plans/q21.txt 
b/benchmarks/expected-plans/q21.txt
index 3ef6269de..a91632df4 100644
--- a/benchmarks/expected-plans/q21.txt
+++ b/benchmarks/expected-plans/q21.txt
@@ -14,8 +14,10 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS 
LAST
                 TableScan: orders projection=[o_orderkey, o_orderstatus]
             Filter: nation.n_name = Utf8("SAUDI ARABIA")
               TableScan: nation projection=[n_nationkey, n_name]
-          SubqueryAlias: l2
-            TableScan: lineitem projection=[l_orderkey, l_suppkey]
-        SubqueryAlias: l3
-          Filter: lineitem.l_receiptdate > lineitem.l_commitdate
-            TableScan: lineitem projection=[l_orderkey, l_suppkey, 
l_commitdate, l_receiptdate]
\ No newline at end of file
+          Projection: l2.l_orderkey, l2.l_suppkey
+            SubqueryAlias: l2
+              TableScan: lineitem projection=[l_orderkey, l_suppkey]
+        Projection: l3.l_orderkey, l3.l_suppkey
+          SubqueryAlias: l3
+            Filter: lineitem.l_receiptdate > lineitem.l_commitdate
+              TableScan: lineitem projection=[l_orderkey, l_suppkey, 
l_commitdate, l_receiptdate]
\ No newline at end of file
diff --git a/benchmarks/expected-plans/q22.txt 
b/benchmarks/expected-plans/q22.txt
index 0fd7a590a..11b438085 100644
--- a/benchmarks/expected-plans/q22.txt
+++ b/benchmarks/expected-plans/q22.txt
@@ -8,7 +8,8 @@ Sort: custsale.cntrycode ASC NULLS LAST
               LeftAnti Join: customer.c_custkey = orders.o_custkey
                 Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN 
([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), 
Utf8("17")])
                   TableScan: customer projection=[c_custkey, c_phone, 
c_acctbal]
-                TableScan: orders projection=[o_custkey]
+                Projection: orders.o_custkey
+                  TableScan: orders projection=[o_custkey]
               SubqueryAlias: __scalar_sq_1
                 Projection: AVG(customer.c_acctbal) AS __value
                   Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]]
diff --git a/benchmarks/expected-plans/q4.txt b/benchmarks/expected-plans/q4.txt
index 3610ae175..e677f3a98 100644
--- a/benchmarks/expected-plans/q4.txt
+++ b/benchmarks/expected-plans/q4.txt
@@ -4,5 +4,6 @@ Sort: orders.o_orderpriority ASC NULLS LAST
       LeftSemi Join: orders.o_orderkey = lineitem.l_orderkey
         Filter: orders.o_orderdate >= Date32("8582") AND orders.o_orderdate < 
Date32("8674")
           TableScan: orders projection=[o_orderkey, o_orderdate, 
o_orderpriority]
-        Filter: lineitem.l_commitdate < lineitem.l_receiptdate
-          TableScan: lineitem projection=[l_orderkey, l_commitdate, 
l_receiptdate]
\ No newline at end of file
+        Projection: lineitem.l_orderkey
+          Filter: lineitem.l_commitdate < lineitem.l_receiptdate
+            TableScan: lineitem projection=[l_orderkey, l_commitdate, 
l_receiptdate]
\ No newline at end of file
diff --git a/datafusion/core/tests/sql/joins.rs 
b/datafusion/core/tests/sql/joins.rs
index 3c0aa8b3f..6d1b1e91b 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -2187,9 +2187,8 @@ async fn left_anti_join() -> Result<()> {
 }
 
 #[tokio::test]
-#[ignore = "Test ignored, will be enabled after fixing the anti join plan bug"]
-// https://github.com/apache/arrow-datafusion/issues/4366
 async fn error_left_anti_join() -> Result<()> {
+    // https://github.com/apache/arrow-datafusion/issues/4366
     let test_repartition_joins = vec![true, false];
     for repartition_joins in test_repartition_joins {
         let ctx = create_left_semi_anti_join_context_with_null_ids(
@@ -2255,19 +2254,20 @@ async fn right_semi_join() -> Result<()> {
         let dataframe = ctx.sql(sql).await.expect(&msg);
         let physical_plan = dataframe.create_physical_plan().await?;
         let expected = if repartition_joins {
-            vec![ "SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]",
-                  "  SortExec: expr=[t1_id@0 ASC NULLS LAST]",
-                  "    ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t1_int@2 as t1_int]",
-                  "      CoalesceBatchesExec: target_batch_size=4096",
-                  "        HashJoinExec: mode=Partitioned, 
join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: 
\"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", 
index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }",
-                  "          CoalesceBatchesExec: target_batch_size=4096",
-                  "            RepartitionExec: partitioning=Hash([Column { 
name: \"t2_id\", index: 0 }], 2), input_partitions=2",
-                  "              RepartitionExec: 
partitioning=RoundRobinBatch(2), input_partitions=1",
-                  "                MemoryExec: partitions=1, 
partition_sizes=[1]",
-                  "          CoalesceBatchesExec: target_batch_size=4096",
-                  "            RepartitionExec: partitioning=Hash([Column { 
name: \"t1_id\", index: 0 }], 2), input_partitions=2",
-                  "              RepartitionExec: 
partitioning=RoundRobinBatch(2), input_partitions=1",
-                  "                MemoryExec: partitions=1, 
partition_sizes=[1]",
+            vec!["SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]",
+                 "  SortExec: expr=[t1_id@0 ASC NULLS LAST]",
+                 "    ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t1_int@2 as t1_int]",
+                 "      CoalesceBatchesExec: target_batch_size=4096",
+                 "        HashJoinExec: mode=Partitioned, join_type=RightSemi, 
on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 
})], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: 
NotEq, right: Column { name: \"t1_name\", index: 0 } }",
+                 "          CoalesceBatchesExec: target_batch_size=4096",
+                 "            RepartitionExec: partitioning=Hash([Column { 
name: \"t2_id\", index: 0 }], 2), input_partitions=2",
+                 "              RepartitionExec: 
partitioning=RoundRobinBatch(2), input_partitions=1",
+                 "                ProjectionExec: expr=[t2_id@0 as t2_id, 
t2_name@1 as t2_name]",
+                 "                  MemoryExec: partitions=1, 
partition_sizes=[1]",
+                 "          CoalesceBatchesExec: target_batch_size=4096",
+                 "            RepartitionExec: partitioning=Hash([Column { 
name: \"t1_id\", index: 0 }], 2), input_partitions=2",
+                 "              RepartitionExec: 
partitioning=RoundRobinBatch(2), input_partitions=1",
+                 "                MemoryExec: partitions=1, 
partition_sizes=[1]",
             ]
         } else {
             vec![
@@ -2275,7 +2275,8 @@ async fn right_semi_join() -> Result<()> {
                 "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t1_int@2 as t1_int]",
                 "    CoalesceBatchesExec: target_batch_size=4096",
                 "      HashJoinExec: mode=CollectLeft, join_type=RightSemi, 
on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 
})], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: 
NotEq, right: Column { name: \"t1_name\", index: 0 } }",
-                "        MemoryExec: partitions=1, partition_sizes=[1]",
+                "        ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as 
t2_name]",
+                "          MemoryExec: partitions=1, partition_sizes=[1]",
                 "        MemoryExec: partitions=1, partition_sizes=[1]",
             ]
         };
@@ -3393,3 +3394,173 @@ async fn left_as_inner_table_nested_loop_join() -> 
Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn exists_subquery_to_join_expr_filter() -> Result<()> {
+    let test_repartition_joins = vec![true, false];
+    for repartition_joins in test_repartition_joins {
+        let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+        // exists subquery to LeftSemi join
+        let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE 
t1.t1_id + 1 > t2.t2_id * 2)";
+        let msg = format!("Creating logical plan for '{sql}'");
+        let dataframe = ctx.sql(&("explain ".to_owned() + 
sql)).await.expect(&msg);
+        let plan = dataframe.into_optimized_plan()?;
+
+        let expected = vec![
+            "Explain [plan_type:Utf8, plan:Utf8]",
+            "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+            "    LeftSemi Join:  Filter: CAST(t1.t1_id AS Int64) + Int64(1) > 
CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, 
t1_int:UInt32;N]",
+            "      TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+            "      Projection: t2.t2_id [t2_id:UInt32;N]",
+            "        TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+        ];
+        let formatted = plan.display_indent_schema().to_string();
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        assert_eq!(
+            expected, actual,
+            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+        );
+        let expected = vec![
+            "+-------+---------+--------+",
+            "| t1_id | t1_name | t1_int |",
+            "+-------+---------+--------+",
+            "| 22    | b       | 2      |",
+            "| 33    | c       | 3      |",
+            "| 44    | d       | 4      |",
+            "+-------+---------+--------+",
+        ];
+
+        let results = execute_to_batches(&ctx, sql).await;
+        assert_batches_sorted_eq!(expected, &results);
+    }
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn exists_subquery_to_join_inner_filter() -> Result<()> {
+    let test_repartition_joins = vec![true, false];
+    for repartition_joins in test_repartition_joins {
+        let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+        // exists subquery to LeftSemi join
+        let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE 
t1.t1_id + 1 > t2.t2_id * 2 AND t2.t2_int < 3)";
+        let msg = format!("Creating logical plan for '{sql}'");
+        let dataframe = ctx.sql(&("explain ".to_owned() + 
sql)).await.expect(&msg);
+        let plan = dataframe.into_optimized_plan()?;
+
+        // `t2.t2_int < 3` will be kept in the subquery filter.
+        let expected = vec![
+            "Explain [plan_type:Utf8, plan:Utf8]",
+            "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+            "    LeftSemi Join:  Filter: CAST(t1.t1_id AS Int64) + Int64(1) > 
CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, 
t1_int:UInt32;N]",
+            "      TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+            "      Projection: t2.t2_id [t2_id:UInt32;N]",
+            "        Filter: t2.t2_int < UInt32(3) [t2_id:UInt32;N, 
t2_int:UInt32;N]",
+            "          TableScan: t2 projection=[t2_id, t2_int] 
[t2_id:UInt32;N, t2_int:UInt32;N]",
+        ];
+        let formatted = plan.display_indent_schema().to_string();
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        assert_eq!(
+            expected, actual,
+            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+        );
+        let expected = vec![
+            "+-------+---------+--------+",
+            "| t1_id | t1_name | t1_int |",
+            "+-------+---------+--------+",
+            "| 44    | d       | 4      |",
+            "+-------+---------+--------+",
+        ];
+
+        let results = execute_to_batches(&ctx, sql).await;
+        assert_batches_sorted_eq!(expected, &results);
+    }
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn exists_subquery_to_join_outer_filter() -> Result<()> {
+    let test_repartition_joins = vec![true, false];
+    for repartition_joins in test_repartition_joins {
+        let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+        // exists subquery to LeftSemi join
+        let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE 
t1.t1_id + 1 > t2.t2_id * 2 AND t1.t1_int < 3)";
+        let msg = format!("Creating logical plan for '{sql}'");
+        let dataframe = ctx.sql(&("explain ".to_owned() + 
sql)).await.expect(&msg);
+        let plan = dataframe.into_optimized_plan()?;
+
+        // `t1.t1_int < 3` will be moved to the filter of t1.
+        let expected = vec![
+            "Explain [plan_type:Utf8, plan:Utf8]",
+            "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+            "    LeftSemi Join:  Filter: CAST(t1.t1_id AS Int64) + Int64(1) > 
CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, 
t1_int:UInt32;N]",
+            "      Filter: t1.t1_int < UInt32(3) [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+            "        TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+            "      Projection: t2.t2_id [t2_id:UInt32;N]",
+            "        TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+        ];
+        let formatted = plan.display_indent_schema().to_string();
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        assert_eq!(
+            expected, actual,
+            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+        );
+        let expected = vec![
+            "+-------+---------+--------+",
+            "| t1_id | t1_name | t1_int |",
+            "+-------+---------+--------+",
+            "| 22    | b       | 2      |",
+            "+-------+---------+--------+",
+        ];
+
+        let results = execute_to_batches(&ctx, sql).await;
+        assert_batches_sorted_eq!(expected, &results);
+    }
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn not_exists_subquery_to_join_expr_filter() -> Result<()> {
+    let test_repartition_joins = vec![true, false];
+    for repartition_joins in test_repartition_joins {
+        let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+        // not exists subquery to LeftAnti join
+        let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT t2_id FROM t2 
WHERE t1.t1_id + 1 > t2.t2_id * 2)";
+        let msg = format!("Creating logical plan for '{sql}'");
+        let dataframe = ctx.sql(&("explain ".to_owned() + 
sql)).await.expect(&msg);
+        let plan = dataframe.into_optimized_plan()?;
+
+        let expected = vec![
+            "Explain [plan_type:Utf8, plan:Utf8]",
+            "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+            "    LeftAnti Join:  Filter: CAST(t1.t1_id AS Int64) + Int64(1) > 
CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, 
t1_int:UInt32;N]",
+            "      TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+            "      Projection: t2.t2_id [t2_id:UInt32;N]",
+            "        TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+        ];
+        let formatted = plan.display_indent_schema().to_string();
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        assert_eq!(
+            expected, actual,
+            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+        );
+        let expected = vec![
+            "+-------+---------+--------+",
+            "| t1_id | t1_name | t1_int |",
+            "+-------+---------+--------+",
+            "| 11    | a       | 1      |",
+            "+-------+---------+--------+",
+        ];
+
+        let results = execute_to_batches(&ctx, sql).await;
+        assert_batches_sorted_eq!(expected, &results);
+    }
+
+    Ok(())
+}
diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs 
b/datafusion/optimizer/src/decorrelate_where_exists.rs
index bdc4afe90..3f6b160fa 100644
--- a/datafusion/optimizer/src/decorrelate_where_exists.rs
+++ b/datafusion/optimizer/src/decorrelate_where_exists.rs
@@ -16,16 +16,14 @@
 // under the License.
 
 use crate::optimizer::ApplyOrder;
-use crate::utils::{
-    conjunction, exprs_to_join_cols, find_join_exprs, split_conjunction,
-    verify_not_disjunction,
-};
+use crate::utils::{conjunction, extract_join_filters, split_conjunction};
 use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::{context, Result};
+use datafusion_common::{Column, DataFusionError, Result};
 use datafusion_expr::{
     logical_plan::{Filter, JoinType, Subquery},
     Expr, LogicalPlan, LogicalPlanBuilder,
 };
+use std::collections::BTreeSet;
 use std::sync::Arc;
 
 /// Optimizer rule for rewriting subquery filters to joins
@@ -144,55 +142,68 @@ fn optimize_exists(
     query_info: &SubqueryInfo,
     outer_input: &LogicalPlan,
 ) -> Result<Option<LogicalPlan>> {
-    let subqry_filter = match query_info.query.subquery.as_ref() {
+    let maybe_subqury_filter = match query_info.query.subquery.as_ref() {
         LogicalPlan::Distinct(subqry_distinct) => match 
subqry_distinct.input.as_ref() {
-            LogicalPlan::Projection(subqry_proj) => {
-                Filter::try_from_plan(&subqry_proj.input)
-            }
+            LogicalPlan::Projection(subqry_proj) => &subqry_proj.input,
             _ => {
-                // Subquery currently only supports distinct or projection
                 return Ok(None);
             }
         },
-        LogicalPlan::Projection(subqry_proj) => 
Filter::try_from_plan(&subqry_proj.input),
+        LogicalPlan::Projection(subqry_proj) => &subqry_proj.input,
         _ => {
             // Subquery currently only supports distinct or projection
             return Ok(None);
         }
     }
-    .map_err(|e| context!("cannot optimize non-correlated subquery", e))?;
-
-    // split into filters
-    let subqry_filter_exprs = split_conjunction(&subqry_filter.predicate);
-    verify_not_disjunction(&subqry_filter_exprs)?;
-
-    // Grab column names to join on
-    let (col_exprs, other_subqry_exprs) =
-        find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema())?;
-    let (outer_cols, subqry_cols, join_filters) =
-        exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false)?;
-    if subqry_cols.is_empty() || outer_cols.is_empty() {
-        // cannot optimize non-correlated subquery
+    .as_ref();
+
+    // extract join filters
+    let (join_filters, subquery_input) = 
extract_join_filters(maybe_subqury_filter)?;
+    // cannot optimize non-correlated subquery
+    if join_filters.is_empty() {
         return Ok(None);
     }
 
-    // build subquery side of join - the thing the subquery was querying
-    let mut subqry_plan = 
LogicalPlanBuilder::from(subqry_filter.input.as_ref().clone());
-    if let Some(expr) = conjunction(other_subqry_exprs) {
-        subqry_plan = subqry_plan.filter(expr)? // if the subquery had 
additional expressions, restore them
-    }
-    let subqry_plan = subqry_plan.build()?;
+    let input_schema = subquery_input.schema();
+    let subquery_cols: BTreeSet<Column> =
+        join_filters
+            .iter()
+            .try_fold(BTreeSet::new(), |mut cols, expr| {
+                let using_cols: Vec<Column> = expr
+                    .to_columns()?
+                    .into_iter()
+                    .filter(|col| input_schema.field_from_column(col).is_ok())
+                    .collect::<_>();
+
+                cols.extend(using_cols);
+                Result::<_, DataFusionError>::Ok(cols)
+            })?;
+
+    let projection_exprs: Vec<Expr> =
+        subquery_cols.into_iter().map(Expr::Column).collect();
+
+    let right = LogicalPlanBuilder::from(subquery_input)
+        .project(projection_exprs)?
+        .build()?;
 
-    let join_keys = (subqry_cols, outer_cols);
+    let join_filter = conjunction(join_filters);
 
     // join our sub query into the main plan
     let join_type = match query_info.negated {
         true => JoinType::LeftAnti,
         false => JoinType::LeftSemi,
     };
+
+    // TODO: add Distinct if the original plan is a Distinct.
     let new_plan = LogicalPlanBuilder::from(outer_input.clone())
-        .join(subqry_plan, join_type, join_keys, join_filters)?
+        .join(
+            right,
+            join_type,
+            (Vec::<Column>::new(), Vec::<Column>::new()),
+            join_filter,
+        )?
         .build()?;
+
     Ok(Some(new_plan))
 }
 
@@ -241,13 +252,14 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
-  LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, 
c_name:Utf8]
-    LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, 
c_name:Utf8]
-      TableScan: customer [c_custkey:Int64, c_name:Utf8]
-      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]
-    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, 
o_totalprice:Float64;N]"#;
-
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+                        \n  LeftSemi Join:  Filter: orders.o_custkey = 
customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
+                        \n    LeftSemi Join:  Filter: orders.o_custkey = 
customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
+                        \n      TableScan: customer [c_custkey:Int64, 
c_name:Utf8]\
+                        \n      Projection: orders.o_custkey [o_custkey:Int64]\
+                        \n        TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+                        \n    Projection: orders.o_custkey [o_custkey:Int64]\
+                        \n      TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
         assert_plan_eq(&plan, expected)
     }
 
@@ -276,13 +288,14 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
-  LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, 
c_name:Utf8]
-    TableScan: customer [c_custkey:Int64, c_name:Utf8]
-    LeftSemi Join: orders.o_orderkey = lineitem.l_orderkey [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
-      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]
-      TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, 
l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"#;
-
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+                        \n  LeftSemi Join:  Filter: orders.o_custkey = 
customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
+                        \n    TableScan: customer [c_custkey:Int64, 
c_name:Utf8]\
+                        \n    Projection: orders.o_custkey [o_custkey:Int64]\
+                        \n      LeftSemi Join:  Filter: lineitem.l_orderkey = 
orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, 
o_totalprice:Float64;N]\
+                        \n        TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+                        \n        Projection: lineitem.l_orderkey 
[l_orderkey:Int64]\
+                        \n          TableScan: lineitem [l_orderkey:Int64, 
l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, 
l_extendedprice:Float64]";
         assert_plan_eq(&plan, expected)
     }
 
@@ -305,21 +318,21 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
-  LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, 
c_name:Utf8]
-    TableScan: customer [c_custkey:Int64, c_name:Utf8]
-    Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]
-      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+                        \n  LeftSemi Join:  Filter: customer.c_custkey = 
orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
+                        \n    TableScan: customer [c_custkey:Int64, 
c_name:Utf8]\
+                        \n    Projection: orders.o_custkey [o_custkey:Int64]\
+                        \n      Filter: orders.o_orderkey = Int32(1) 
[o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+                        \n        TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
 
         assert_plan_eq(&plan, expected)
     }
 
-    /// Test for correlated exists subquery with no columns in schema
     #[test]
     fn exists_subquery_no_cols() -> Result<()> {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
-                
.filter(col("customer.c_custkey").eq(col("customer.c_custkey")))?
+                .filter(col("customer.c_custkey").eq(lit(1u32)))?
                 .project(vec![col("orders.o_custkey")])?
                 .build()?,
         );
@@ -329,7 +342,14 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), 
&plan)
+        // Other rule will pushdown `customer.c_custkey = 1`,
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+                        \n  LeftSemi Join:  Filter: customer.c_custkey = 
UInt32(1) [c_custkey:Int64, c_name:Utf8]\
+                        \n    TableScan: customer [c_custkey:Int64, 
c_name:Utf8]\
+                        \n    Projection:  []\
+                        \n      TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+        assert_plan_eq(&plan, expected)
     }
 
     /// Test for exists subquery with both columns in schema
@@ -365,7 +385,13 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), 
&plan)
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+                        \n  LeftSemi Join:  Filter: customer.c_custkey != 
orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
+                        \n    TableScan: customer [c_custkey:Int64, 
c_name:Utf8]\
+                        \n    Projection: orders.o_custkey [o_custkey:Int64]\
+                        \n      TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+        assert_plan_eq(&plan, expected)
     }
 
     /// Test for correlated exists subquery less than
@@ -383,10 +409,13 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        let expected = r#"can't optimize < column comparison"#;
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+                        \n  LeftSemi Join:  Filter: customer.c_custkey < 
orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
+                        \n    TableScan: customer [c_custkey:Int64, 
c_name:Utf8]\
+                        \n    Projection: orders.o_custkey [o_custkey:Int64]\
+                        \n      TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
 
-        assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, 
expected);
-        Ok(())
+        assert_plan_eq(&plan, expected)
     }
 
     /// Test for correlated exists subquery filter with subquery disjunction
@@ -408,10 +437,13 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        let expected = r#"Optimizing disjunctions not supported!"#;
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+                        \n  LeftSemi Join:  Filter: customer.c_custkey = 
orders.o_custkey OR orders.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\
+                        \n    TableScan: customer [c_custkey:Int64, 
c_name:Utf8]\
+                        \n    Projection: orders.o_custkey, orders.o_orderkey 
[o_custkey:Int64, o_orderkey:Int64]\
+                        \n      TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
 
-        assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, 
expected);
-        Ok(())
+        assert_plan_eq(&plan, expected)
     }
 
     /// Test for correlated exists without projection
@@ -446,11 +478,11 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        // Doesn't matter we projected an expression, just that we returned a 
result
-        let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
-  LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, 
c_name:Utf8]
-    TableScan: customer [c_custkey:Int64, c_name:Utf8]
-    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, 
o_totalprice:Float64;N]"#;
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+                        \n  LeftSemi Join:  Filter: customer.c_custkey = 
orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
+                        \n    TableScan: customer [c_custkey:Int64, 
c_name:Utf8]\
+                        \n    Projection: orders.o_custkey [o_custkey:Int64]\
+                        \n      TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
 
         assert_plan_eq(&plan, expected)
     }
@@ -469,11 +501,12 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
-  Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
-    LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, 
c_name:Utf8]
-      TableScan: customer [c_custkey:Int64, c_name:Utf8]
-      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+                        \n  Filter: customer.c_custkey = Int32(1) 
[c_custkey:Int64, c_name:Utf8]\
+                        \n    LeftSemi Join:  Filter: customer.c_custkey = 
orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
+                        \n      TableScan: customer [c_custkey:Int64, 
c_name:Utf8]\
+                        \n      Projection: orders.o_custkey [o_custkey:Int64]\
+                        \n        TableScan: orders [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
 
         assert_plan_eq(&plan, expected)
     }
@@ -520,10 +553,11 @@ mod tests {
             .project(vec![col("test.c")])?
             .build()?;
 
-        let expected = r#"Projection: test.c [c:UInt32]
-  LeftSemi Join: test.a = sq.a [a:UInt32, b:UInt32, c:UInt32]
-    TableScan: test [a:UInt32, b:UInt32, c:UInt32]
-    TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#;
+        let expected  = "Projection: test.c [c:UInt32]\
+                        \n  LeftSemi Join:  Filter: test.a = sq.a [a:UInt32, 
b:UInt32, c:UInt32]\
+                        \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+                        \n    Projection: sq.a [a:UInt32]\
+                        \n      TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_plan_eq(&plan, expected)
     }
@@ -537,10 +571,7 @@ mod tests {
             .project(vec![col("test.b")])?
             .build()?;
 
-        let expected = "cannot optimize non-correlated subquery";
-
-        assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, 
expected);
-        Ok(())
+        assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), 
&plan)
     }
 
     /// Test for single NOT exists subquery filter
@@ -552,10 +583,7 @@ mod tests {
             .project(vec![col("test.b")])?
             .build()?;
 
-        let expected = "cannot optimize non-correlated subquery";
-
-        assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, 
expected);
-        Ok(())
+        assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), 
&plan)
     }
 
     #[test]
@@ -583,18 +611,37 @@ mod tests {
             .build()?;
 
         let expected = "Projection: test.b [b:UInt32]\
-        \n  Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
-        \n    LeftSemi Join: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\
-        \n      LeftSemi Join: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\
-        \n        TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
-        \n        TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
-        \n      TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";
+                        \n  Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, 
c:UInt32]\
+                        \n    LeftSemi Join:  Filter: test.a = sq2.a 
[a:UInt32, b:UInt32, c:UInt32]\
+                        \n      LeftSemi Join:  Filter: test.a = sq1.a 
[a:UInt32, b:UInt32, c:UInt32]\
+                        \n        TableScan: test [a:UInt32, b:UInt32, 
c:UInt32]\
+                        \n        Projection: sq1.a [a:UInt32]\
+                        \n          TableScan: sq1 [a:UInt32, b:UInt32, 
c:UInt32]\
+                        \n      Projection: sq2.a [a:UInt32]\
+                        \n        TableScan: sq2 [a:UInt32, b:UInt32, 
c:UInt32]";
 
-        assert_optimized_plan_eq_display_indent(
-            Arc::new(DecorrelateWhereExists::new()),
-            &plan,
-            expected,
-        );
-        Ok(())
+        assert_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn exists_subquery_expr_filter() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let subquery_scan = test_table_scan_with_name("sq")?;
+        let subquery = LogicalPlanBuilder::from(subquery_scan)
+            .filter((lit(1u32) + col("sq.a")).gt(col("test.a") * lit(2u32)))?
+            .project(vec![lit(1u32)])?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(exists(Arc::new(subquery)))?
+            .project(vec![col("test.b")])?
+            .build()?;
+
+        let expected = "Projection: test.b [b:UInt32]\
+                        \n  LeftSemi Join:  Filter: UInt32(1) + sq.a > test.a 
* UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
+                        \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+                        \n    Projection: sq.a [a:UInt32]\
+                        \n      TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_plan_eq(&plan, expected)
     }
 }
diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs 
b/datafusion/optimizer/src/decorrelate_where_in.rs
index 7a9a75ff4..c8ff65f12 100644
--- a/datafusion/optimizer/src/decorrelate_where_in.rs
+++ b/datafusion/optimizer/src/decorrelate_where_in.rs
@@ -17,12 +17,11 @@
 
 use crate::alias::AliasGenerator;
 use crate::optimizer::ApplyOrder;
-use crate::utils::{conjunction, only_or_err, split_conjunction};
+use crate::utils::{conjunction, extract_join_filters, only_or_err, 
split_conjunction};
 use crate::{OptimizerConfig, OptimizerRule};
 use datafusion_common::{context, Column, DataFusionError, Result};
 use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col};
 use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
-use datafusion_expr::utils::check_all_columns_from_schema;
 use datafusion_expr::{Expr, Filter, LogicalPlan, LogicalPlanBuilder};
 use log::debug;
 use std::collections::{BTreeSet, HashMap};
@@ -220,34 +219,6 @@ fn optimize_where_in(
     Ok(new_plan)
 }
 
-fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec<Expr>, 
LogicalPlan)> {
-    if let LogicalPlan::Filter(plan_filter) = maybe_filter {
-        let input_schema = plan_filter.input.schema();
-        let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
-
-        let mut join_filters: Vec<Expr> = vec![];
-        let mut subquery_filters: Vec<Expr> = vec![];
-        for expr in subquery_filter_exprs {
-            let cols = expr.to_columns()?;
-            if check_all_columns_from_schema(&cols, input_schema.clone())? {
-                subquery_filters.push(expr.clone());
-            } else {
-                join_filters.push(expr.clone())
-            }
-        }
-
-        // if the subquery still has filter expressions, restore them.
-        let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone());
-        if let Some(expr) = conjunction(subquery_filters) {
-            plan = plan.filter(expr)?
-        }
-
-        Ok((join_filters, plan.build()?))
-    } else {
-        Ok((vec![], maybe_filter.clone()))
-    }
-}
-
 fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: Expr) -> 
Vec<Expr> {
     filters
         .into_iter()
diff --git a/datafusion/optimizer/src/utils.rs 
b/datafusion/optimizer/src/utils.rs
index 4d9d10d51..747fa6208 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -23,10 +23,11 @@ use datafusion_common::{DFSchema, Result};
 use datafusion_expr::expr::{BinaryExpr, Sort};
 use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
 use datafusion_expr::expr_visitor::inspect_expr_pre;
+use datafusion_expr::logical_plan::LogicalPlanBuilder;
+use datafusion_expr::utils::{check_all_columns_from_schema, from_plan};
 use datafusion_expr::{
     and,
     logical_plan::{Filter, LogicalPlan},
-    utils::from_plan,
     Expr, Operator,
 };
 use std::collections::HashSet;
@@ -468,6 +469,42 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema 
{
         })
 }
 
+/// Extract join predicates from the correlated subquery.
+/// A join predicate is an expression that references columns from both
+/// the subquery and the outer table, or only from the outer table.
+///
+/// Returns the join predicates and the subquery with them removed.
+/// A plan that is not a `Filter` is returned unchanged with no join predicates.
+pub(crate) fn extract_join_filters(
+    maybe_filter: &LogicalPlan,
+) -> Result<(Vec<Expr>, LogicalPlan)> {
+    if let LogicalPlan::Filter(plan_filter) = maybe_filter {
+        let input_schema = plan_filter.input.schema();
+        let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
+
+        let mut join_filters: Vec<Expr> = vec![];
+        let mut subquery_filters: Vec<Expr> = vec![];
+        for expr in subquery_filter_exprs {
+            let cols = expr.to_columns()?;
+            if check_all_columns_from_schema(&cols, input_schema.clone())? {
+                subquery_filters.push(expr.clone());
+            } else {
+                join_filters.push(expr.clone())
+            }
+        }
+
+        // if the subquery still has filter expressions, restore them.
+        let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone());
+        if let Some(expr) = conjunction(subquery_filters) {
+            plan = plan.filter(expr)?
+        }
+
+        Ok((join_filters, plan.build()?))
+    } else {
+        Ok((vec![], maybe_filter.clone()))
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/optimizer/tests/integration-test.rs 
b/datafusion/optimizer/tests/integration-test.rs
index f901e33d4..eac849e34 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -121,8 +121,9 @@ fn semi_join_with_join_filter() -> Result<()> {
     let expected = "Projection: test.col_utf8\
                     \n  LeftSemi Join: test.col_int32 = t2.col_int32 Filter: 
test.col_uint32 != t2.col_uint32\
                     \n    TableScan: test projection=[col_int32, col_uint32, 
col_utf8]\
-                    \n    SubqueryAlias: t2\
-                    \n      TableScan: test projection=[col_int32, col_uint32, 
col_utf8]";
+                    \n    Projection: t2.col_int32, t2.col_uint32\
+                    \n      SubqueryAlias: t2\
+                    \n        TableScan: test projection=[col_int32, 
col_uint32]";
     assert_eq!(expected, format!("{plan:?}"));
     Ok(())
 }
@@ -137,8 +138,9 @@ fn anti_join_with_join_filter() -> Result<()> {
     let expected = "Projection: test.col_utf8\
                     \n  LeftAnti Join: test.col_int32 = t2.col_int32 Filter: 
test.col_uint32 != t2.col_uint32\
                     \n    TableScan: test projection=[col_int32, col_uint32, 
col_utf8]\
-                    \n    SubqueryAlias: t2\
-                    \n      TableScan: test projection=[col_int32, col_uint32, 
col_utf8]";
+                    \n    Projection: t2.col_int32, t2.col_uint32\
+                    \n      SubqueryAlias: t2\
+                    \n        TableScan: test projection=[col_int32, 
col_uint32]";
     assert_eq!(expected, format!("{plan:?}"));
     Ok(())
 }
@@ -152,8 +154,9 @@ fn where_exists_distinct() -> Result<()> {
     let expected = "Projection: test.col_int32\
                     \n  LeftSemi Join: test.col_int32 = t2.col_int32\
                     \n    TableScan: test projection=[col_int32]\
-                    \n    SubqueryAlias: t2\
-                    \n      TableScan: test projection=[col_int32]";
+                    \n    Projection: t2.col_int32\
+                    \n      SubqueryAlias: t2\
+                    \n        TableScan: test projection=[col_int32]";
     assert_eq!(expected, format!("{plan:?}"));
     Ok(())
 }


Reply via email to