This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 13addce65 put subquery's equal clause into join on clauses instead of filter clauses (#3862)
13addce65 is described below

commit 13addce657aba922ac467c832baa54ae079284c2
Author: AssHero <[email protected]>
AuthorDate: Wed Oct 19 23:12:20 2022 +0800

    put subquery's equal clause into join on clauses instead of filter clauses (#3862)
    
    * put subquery's equal clause into join on clauses instead of filter clauses
    
    * only do this optimization for correlated subqueries
---
 benchmarks/expected-plans/q2.txt                   |  43 ++++---
 datafusion/core/tests/sql/subqueries.rs            |  43 ++++---
 .../optimizer/src/scalar_subquery_to_join.rs       | 134 ++++++++++++++++-----
 3 files changed, 145 insertions(+), 75 deletions(-)

diff --git a/benchmarks/expected-plans/q2.txt b/benchmarks/expected-plans/q2.txt
index 10d68cd37..c5f6fb0fd 100644
--- a/benchmarks/expected-plans/q2.txt
+++ b/benchmarks/expected-plans/q2.txt
@@ -1,25 +1,24 @@
 Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, 
supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST
   Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, 
part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, 
supplier.s_comment
-    Filter: partsupp.ps_supplycost = __sq_1.__value
-      Inner Join: part.p_partkey = __sq_1.ps_partkey
-        Inner Join: nation.n_regionkey = region.r_regionkey
-          Inner Join: supplier.s_nationkey = nation.n_nationkey
-            Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
-              Inner Join: part.p_partkey = partsupp.ps_partkey
-                Filter: part.p_size = Int32(15) AND part.p_type LIKE 
Utf8("%BRASS")
-                  TableScan: part projection=[p_partkey, p_mfgr, p_type, 
p_size]
+    Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = 
__sq_1.__value
+      Inner Join: nation.n_regionkey = region.r_regionkey
+        Inner Join: supplier.s_nationkey = nation.n_nationkey
+          Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
+            Inner Join: part.p_partkey = partsupp.ps_partkey
+              Filter: part.p_size = Int32(15) AND part.p_type LIKE 
Utf8("%BRASS")
+                TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size]
+              TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
+            TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey, s_phone, s_acctbal, s_comment]
+          TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+        Filter: region.r_name = Utf8("EUROPE")
+          TableScan: region projection=[r_regionkey, r_name]
+      Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, 
alias=__sq_1
+        Aggregate: groupBy=[[partsupp.ps_partkey]], 
aggr=[[MIN(partsupp.ps_supplycost)]]
+          Inner Join: nation.n_regionkey = region.r_regionkey
+            Inner Join: supplier.s_nationkey = nation.n_nationkey
+              Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
                 TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
-              TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey, s_phone, s_acctbal, s_comment]
-            TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
-          Filter: region.r_name = Utf8("EUROPE")
-            TableScan: region projection=[r_regionkey, r_name]
-        Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS 
__value, alias=__sq_1
-          Aggregate: groupBy=[[partsupp.ps_partkey]], 
aggr=[[MIN(partsupp.ps_supplycost)]]
-            Inner Join: nation.n_regionkey = region.r_regionkey
-              Inner Join: supplier.s_nationkey = nation.n_nationkey
-                Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
-                  TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
-                  TableScan: supplier projection=[s_suppkey, s_name, 
s_address, s_nationkey, s_phone, s_acctbal, s_comment]
-                TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
-              Filter: region.r_name = Utf8("EUROPE")
-                TableScan: region projection=[r_regionkey, r_name]
\ No newline at end of file
+                TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey, s_phone, s_acctbal, s_comment]
+              TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+            Filter: region.r_name = Utf8("EUROPE")
+              TableScan: region projection=[r_regionkey, r_name]
\ No newline at end of file
diff --git a/datafusion/core/tests/sql/subqueries.rs 
b/datafusion/core/tests/sql/subqueries.rs
index 803c24995..8c77d860e 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -141,29 +141,28 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#;
     let actual = format!("{}", plan.display_indent());
     let expected = r#"Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name 
ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST
   Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, 
part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, 
supplier.s_comment
-    Filter: partsupp.ps_supplycost = __sq_1.__value
-      Inner Join: part.p_partkey = __sq_1.ps_partkey
-        Inner Join: nation.n_regionkey = region.r_regionkey
-          Inner Join: supplier.s_nationkey = nation.n_nationkey
-            Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
-              Inner Join: part.p_partkey = partsupp.ps_partkey
-                Filter: part.p_size = Int32(15) AND part.p_type LIKE 
Utf8("%BRASS")
-                  TableScan: part projection=[p_partkey, p_mfgr, p_type, 
p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE 
Utf8("%BRASS")]
+    Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = 
__sq_1.__value
+      Inner Join: nation.n_regionkey = region.r_regionkey
+        Inner Join: supplier.s_nationkey = nation.n_nationkey
+          Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
+            Inner Join: part.p_partkey = partsupp.ps_partkey
+              Filter: part.p_size = Int32(15) AND part.p_type LIKE 
Utf8("%BRASS")
+                TableScan: part projection=[p_partkey, p_mfgr, p_type, 
p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE 
Utf8("%BRASS")]
+              TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
+            TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey, s_phone, s_acctbal, s_comment]
+          TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+        Filter: region.r_name = Utf8("EUROPE")
+          TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[region.r_name = Utf8("EUROPE")]
+      Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, 
alias=__sq_1
+        Aggregate: groupBy=[[partsupp.ps_partkey]], 
aggr=[[MIN(partsupp.ps_supplycost)]]
+          Inner Join: nation.n_regionkey = region.r_regionkey
+            Inner Join: supplier.s_nationkey = nation.n_nationkey
+              Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
                 TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
-              TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey, s_phone, s_acctbal, s_comment]
-            TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
-          Filter: region.r_name = Utf8("EUROPE")
-            TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[region.r_name = Utf8("EUROPE")]
-        Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS 
__value, alias=__sq_1
-          Aggregate: groupBy=[[partsupp.ps_partkey]], 
aggr=[[MIN(partsupp.ps_supplycost)]]
-            Inner Join: nation.n_regionkey = region.r_regionkey
-              Inner Join: supplier.s_nationkey = nation.n_nationkey
-                Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
-                  TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
-                  TableScan: supplier projection=[s_suppkey, s_name, 
s_address, s_nationkey, s_phone, s_acctbal, s_comment]
-                TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
-              Filter: region.r_name = Utf8("EUROPE")
-                TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[region.r_name = Utf8("EUROPE")]"#
+                TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey, s_phone, s_acctbal, s_comment]
+              TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+            Filter: region.r_name = Utf8("EUROPE")
+              TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[region.r_name = Utf8("EUROPE")]"#
         .to_string();
     assert_eq!(actual, expected);
 
diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs 
b/datafusion/optimizer/src/scalar_subquery_to_join.rs
index d8ff6583c..0e53ecfae 100644
--- a/datafusion/optimizer/src/scalar_subquery_to_join.rs
+++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs
@@ -244,7 +244,7 @@ fn optimize_scalar(
     // Grab column names to join on
     let (col_exprs, other_subqry_exprs) =
         find_join_exprs(subqry_filter_exprs, input.schema())?;
-    let (outer_cols, subqry_cols, join_filters) =
+    let (mut outer_cols, subqry_cols, join_filters) =
         exprs_to_join_cols(&col_exprs, input.schema(), false)?;
     if join_filters.is_some() {
         plan_err!("only joins on column equality are presently supported")?;
@@ -275,13 +275,31 @@ fn optimize_scalar(
         .build()?;
 
     // qualify the join columns for outside the subquery
-    let subqry_cols: Vec<_> = subqry_cols
+    let mut subqry_cols: Vec<_> = subqry_cols
         .iter()
         .map(|it| Column {
             relation: Some(subqry_alias.clone()),
             name: it.name.clone(),
         })
         .collect();
+
+    let qry_expr = Expr::Column(Column {
+        relation: Some(subqry_alias),
+        name: "__value".to_string(),
+    });
+
+    // if correlated subquery's operation is column equality, put the clause 
into join on clause.
+    let mut restore_where_clause = true;
+
+    if let (Operator::Eq, Expr::Column(column)) = (query_info.op, 
&query_info.expr) {
+        // only do this optimization for correlated subquery
+        if !outer_cols.is_empty() {
+            outer_cols.push(column.clone());
+            subqry_cols.push(qry_expr.try_into_col().unwrap());
+            restore_where_clause = false;
+        }
+    }
+
     let join_keys = (outer_cols, subqry_cols);
 
     // join our sub query into the main plan
@@ -295,24 +313,22 @@ fn optimize_scalar(
     };
 
     // restore where in condition
-    let qry_expr = Box::new(Expr::Column(Column {
-        relation: Some(subqry_alias),
-        name: "__value".to_string(),
-    }));
-    let filter_expr = if query_info.expr_on_left {
-        Expr::BinaryExpr(BinaryExpr::new(
-            Box::new(query_info.expr.clone()),
-            query_info.op,
-            qry_expr,
-        ))
-    } else {
-        Expr::BinaryExpr(BinaryExpr::new(
-            qry_expr,
-            query_info.op,
-            Box::new(query_info.expr.clone()),
-        ))
-    };
-    new_plan = new_plan.filter(filter_expr)?;
+    if restore_where_clause {
+        let filter_expr = if query_info.expr_on_left {
+            Expr::BinaryExpr(BinaryExpr::new(
+                Box::new(query_info.expr.clone()),
+                query_info.op,
+                Box::new(qry_expr),
+            ))
+        } else {
+            Expr::BinaryExpr(BinaryExpr::new(
+                Box::new(qry_expr),
+                query_info.op,
+                Box::new(query_info.expr.clone()),
+            ))
+        };
+        new_plan = new_plan.filter(filter_expr)?;
+    }
 
     // if the main query had additional expressions, restore them
     if let Some(expr) = conjunction(outer_others.to_vec()) {
@@ -461,13 +477,12 @@ mod tests {
             .build()?;
 
         let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
-  Filter: customer.c_custkey = __sq_1.__value [c_custkey:Int64, c_name:Utf8, 
o_custkey:Int64, __value:Int64;N]
-    Inner Join: customer.c_custkey = __sq_1.o_custkey [c_custkey:Int64, 
c_name:Utf8, o_custkey:Int64, __value:Int64;N]
-      TableScan: customer [c_custkey:Int64, c_name:Utf8]
-      Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value, 
alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
-        Aggregate: groupBy=[[orders.o_custkey]], 
aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
-          Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
-            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+  Inner Join: customer.c_custkey = __sq_1.o_custkey, customer.c_custkey = 
__sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
+    TableScan: customer [c_custkey:Int64, c_name:Utf8]
+    Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value, 
alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
+      Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] 
[o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
+        Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, 
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+          TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
 
         assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, 
expected);
         Ok(())
@@ -677,7 +692,7 @@ mod tests {
 
     /// Test for correlated scalar subquery filter with additional filters
     #[test]
-    fn scalar_subquery_additional_filters() -> Result<()> {
+    fn scalar_subquery_additional_filters_with_non_equal_clause() -> 
Result<()> {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
@@ -689,7 +704,7 @@ mod tests {
         let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
             .filter(
                 col("customer.c_custkey")
-                    .eq(scalar_subquery(sq))
+                    .gt_eq(scalar_subquery(sq))
                     .and(col("c_custkey").eq(lit(1))),
             )?
             .project(vec![col("customer.c_custkey")])?
@@ -697,7 +712,7 @@ mod tests {
 
         let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
   Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, 
o_custkey:Int64, __value:Int64;N]
-    Filter: customer.c_custkey = __sq_1.__value [c_custkey:Int64, c_name:Utf8, 
o_custkey:Int64, __value:Int64;N]
+    Filter: customer.c_custkey >= __sq_1.__value [c_custkey:Int64, 
c_name:Utf8, o_custkey:Int64, __value:Int64;N]
       Inner Join: customer.c_custkey = __sq_1.o_custkey [c_custkey:Int64, 
c_name:Utf8, o_custkey:Int64, __value:Int64;N]
         TableScan: customer [c_custkey:Int64, c_name:Utf8]
         Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value, 
alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
@@ -708,6 +723,37 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> {
+        let sq = Arc::new(
+            LogicalPlanBuilder::from(scan_tpch_table("orders"))
+                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+                .aggregate(Vec::<Expr>::new(), 
vec![max(col("orders.o_custkey"))])?
+                .project(vec![max(col("orders.o_custkey"))])?
+                .build()?,
+        );
+
+        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+            .filter(
+                col("customer.c_custkey")
+                    .eq(scalar_subquery(sq))
+                    .and(col("c_custkey").eq(lit(1))),
+            )?
+            .project(vec![col("customer.c_custkey")])?
+            .build()?;
+
+        let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
+  Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, 
o_custkey:Int64, __value:Int64;N]
+    Inner Join: customer.c_custkey = __sq_1.o_custkey, customer.c_custkey = 
__sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
+      TableScan: customer [c_custkey:Int64, c_name:Utf8]
+      Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value, 
alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
+        Aggregate: groupBy=[[orders.o_custkey]], 
aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
+          TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+        assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, 
expected);
+        Ok(())
+    }
+
     /// Test for correlated scalar subquery filter with disjustions
     #[test]
     fn scalar_subquery_disjunction() -> Result<()> {
@@ -771,7 +817,33 @@ mod tests {
 
     /// Test for non-correlated scalar subquery with no filters
     #[test]
-    fn scalar_subquery_non_correlated_no_filters() -> Result<()> {
+    fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() -> 
Result<()> {
+        let sq = Arc::new(
+            LogicalPlanBuilder::from(scan_tpch_table("orders"))
+                .aggregate(Vec::<Expr>::new(), 
vec![max(col("orders.o_custkey"))])?
+                .project(vec![max(col("orders.o_custkey"))])?
+                .build()?,
+        );
+
+        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+            .filter(col("customer.c_custkey").lt(scalar_subquery(sq)))?
+            .project(vec![col("customer.c_custkey")])?
+            .build()?;
+
+        let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
+  Filter: customer.c_custkey < __sq_1.__value [c_custkey:Int64, c_name:Utf8, 
__value:Int64;N]
+    CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N]
+      TableScan: customer [c_custkey:Int64, c_name:Utf8]
+      Projection: MAX(orders.o_custkey) AS __value, alias=__sq_1 
[__value:Int64;N]
+        Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] 
[MAX(orders.o_custkey):Int64;N]
+          TableScan: orders [o_orderkey:Int64, o_custkey:Int64, 
o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+        assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, 
expected);
+        Ok(())
+    }
+
+    #[test]
+    fn scalar_subquery_non_correlated_no_filters_with_equal_clause() -> 
Result<()> {
         let sq = Arc::new(
             LogicalPlanBuilder::from(scan_tpch_table("orders"))
                 .aggregate(Vec::<Expr>::new(), 
vec![max(col("orders.o_custkey"))])?

Reply via email to