This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 13addce65 put subquery's equal clause into join on clauses instead of
filter clauses (#3862)
13addce65 is described below
commit 13addce657aba922ac467c832baa54ae079284c2
Author: AssHero <[email protected]>
AuthorDate: Wed Oct 19 23:12:20 2022 +0800
put subquery's equal clause into join on clauses instead of filter clauses
(#3862)
* put subquery's equal clause into join on clauses instead of filter clauses
* only do this optimization for correlated subqueries
---
benchmarks/expected-plans/q2.txt | 43 ++++---
datafusion/core/tests/sql/subqueries.rs | 43 ++++---
.../optimizer/src/scalar_subquery_to_join.rs | 134 ++++++++++++++++-----
3 files changed, 145 insertions(+), 75 deletions(-)
diff --git a/benchmarks/expected-plans/q2.txt b/benchmarks/expected-plans/q2.txt
index 10d68cd37..c5f6fb0fd 100644
--- a/benchmarks/expected-plans/q2.txt
+++ b/benchmarks/expected-plans/q2.txt
@@ -1,25 +1,24 @@
Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST,
supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST
Projection: supplier.s_acctbal, supplier.s_name, nation.n_name,
part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone,
supplier.s_comment
- Filter: partsupp.ps_supplycost = __sq_1.__value
- Inner Join: part.p_partkey = __sq_1.ps_partkey
- Inner Join: nation.n_regionkey = region.r_regionkey
- Inner Join: supplier.s_nationkey = nation.n_nationkey
- Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
- Inner Join: part.p_partkey = partsupp.ps_partkey
- Filter: part.p_size = Int32(15) AND part.p_type LIKE
Utf8("%BRASS")
- TableScan: part projection=[p_partkey, p_mfgr, p_type,
p_size]
+ Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost =
__sq_1.__value
+ Inner Join: nation.n_regionkey = region.r_regionkey
+ Inner Join: supplier.s_nationkey = nation.n_nationkey
+ Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
+ Inner Join: part.p_partkey = partsupp.ps_partkey
+ Filter: part.p_size = Int32(15) AND part.p_type LIKE
Utf8("%BRASS")
+ TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size]
+ TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_supplycost]
+ TableScan: supplier projection=[s_suppkey, s_name, s_address,
s_nationkey, s_phone, s_acctbal, s_comment]
+ TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+ Filter: region.r_name = Utf8("EUROPE")
+ TableScan: region projection=[r_regionkey, r_name]
+ Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value,
alias=__sq_1
+ Aggregate: groupBy=[[partsupp.ps_partkey]],
aggr=[[MIN(partsupp.ps_supplycost)]]
+ Inner Join: nation.n_regionkey = region.r_regionkey
+ Inner Join: supplier.s_nationkey = nation.n_nationkey
+ Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_supplycost]
- TableScan: supplier projection=[s_suppkey, s_name, s_address,
s_nationkey, s_phone, s_acctbal, s_comment]
- TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
- Filter: region.r_name = Utf8("EUROPE")
- TableScan: region projection=[r_regionkey, r_name]
- Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS
__value, alias=__sq_1
- Aggregate: groupBy=[[partsupp.ps_partkey]],
aggr=[[MIN(partsupp.ps_supplycost)]]
- Inner Join: nation.n_regionkey = region.r_regionkey
- Inner Join: supplier.s_nationkey = nation.n_nationkey
- Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
- TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_supplycost]
- TableScan: supplier projection=[s_suppkey, s_name,
s_address, s_nationkey, s_phone, s_acctbal, s_comment]
- TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
- Filter: region.r_name = Utf8("EUROPE")
- TableScan: region projection=[r_regionkey, r_name]
\ No newline at end of file
+ TableScan: supplier projection=[s_suppkey, s_name, s_address,
s_nationkey, s_phone, s_acctbal, s_comment]
+ TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+ Filter: region.r_name = Utf8("EUROPE")
+ TableScan: region projection=[r_regionkey, r_name]
\ No newline at end of file
diff --git a/datafusion/core/tests/sql/subqueries.rs
b/datafusion/core/tests/sql/subqueries.rs
index 803c24995..8c77d860e 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -141,29 +141,28 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#;
let actual = format!("{}", plan.display_indent());
let expected = r#"Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name
ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST
Projection: supplier.s_acctbal, supplier.s_name, nation.n_name,
part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone,
supplier.s_comment
- Filter: partsupp.ps_supplycost = __sq_1.__value
- Inner Join: part.p_partkey = __sq_1.ps_partkey
- Inner Join: nation.n_regionkey = region.r_regionkey
- Inner Join: supplier.s_nationkey = nation.n_nationkey
- Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
- Inner Join: part.p_partkey = partsupp.ps_partkey
- Filter: part.p_size = Int32(15) AND part.p_type LIKE
Utf8("%BRASS")
- TableScan: part projection=[p_partkey, p_mfgr, p_type,
p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE
Utf8("%BRASS")]
+ Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost =
__sq_1.__value
+ Inner Join: nation.n_regionkey = region.r_regionkey
+ Inner Join: supplier.s_nationkey = nation.n_nationkey
+ Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
+ Inner Join: part.p_partkey = partsupp.ps_partkey
+ Filter: part.p_size = Int32(15) AND part.p_type LIKE
Utf8("%BRASS")
+ TableScan: part projection=[p_partkey, p_mfgr, p_type,
p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE
Utf8("%BRASS")]
+ TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_supplycost]
+ TableScan: supplier projection=[s_suppkey, s_name, s_address,
s_nationkey, s_phone, s_acctbal, s_comment]
+ TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+ Filter: region.r_name = Utf8("EUROPE")
+ TableScan: region projection=[r_regionkey, r_name],
partial_filters=[region.r_name = Utf8("EUROPE")]
+ Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value,
alias=__sq_1
+ Aggregate: groupBy=[[partsupp.ps_partkey]],
aggr=[[MIN(partsupp.ps_supplycost)]]
+ Inner Join: nation.n_regionkey = region.r_regionkey
+ Inner Join: supplier.s_nationkey = nation.n_nationkey
+ Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_supplycost]
- TableScan: supplier projection=[s_suppkey, s_name, s_address,
s_nationkey, s_phone, s_acctbal, s_comment]
- TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
- Filter: region.r_name = Utf8("EUROPE")
- TableScan: region projection=[r_regionkey, r_name],
partial_filters=[region.r_name = Utf8("EUROPE")]
- Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS
__value, alias=__sq_1
- Aggregate: groupBy=[[partsupp.ps_partkey]],
aggr=[[MIN(partsupp.ps_supplycost)]]
- Inner Join: nation.n_regionkey = region.r_regionkey
- Inner Join: supplier.s_nationkey = nation.n_nationkey
- Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
- TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_supplycost]
- TableScan: supplier projection=[s_suppkey, s_name,
s_address, s_nationkey, s_phone, s_acctbal, s_comment]
- TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
- Filter: region.r_name = Utf8("EUROPE")
- TableScan: region projection=[r_regionkey, r_name],
partial_filters=[region.r_name = Utf8("EUROPE")]"#
+ TableScan: supplier projection=[s_suppkey, s_name, s_address,
s_nationkey, s_phone, s_acctbal, s_comment]
+ TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+ Filter: region.r_name = Utf8("EUROPE")
+ TableScan: region projection=[r_regionkey, r_name],
partial_filters=[region.r_name = Utf8("EUROPE")]"#
.to_string();
assert_eq!(actual, expected);
diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs
b/datafusion/optimizer/src/scalar_subquery_to_join.rs
index d8ff6583c..0e53ecfae 100644
--- a/datafusion/optimizer/src/scalar_subquery_to_join.rs
+++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs
@@ -244,7 +244,7 @@ fn optimize_scalar(
// Grab column names to join on
let (col_exprs, other_subqry_exprs) =
find_join_exprs(subqry_filter_exprs, input.schema())?;
- let (outer_cols, subqry_cols, join_filters) =
+ let (mut outer_cols, subqry_cols, join_filters) =
exprs_to_join_cols(&col_exprs, input.schema(), false)?;
if join_filters.is_some() {
plan_err!("only joins on column equality are presently supported")?;
@@ -275,13 +275,31 @@ fn optimize_scalar(
.build()?;
// qualify the join columns for outside the subquery
- let subqry_cols: Vec<_> = subqry_cols
+ let mut subqry_cols: Vec<_> = subqry_cols
.iter()
.map(|it| Column {
relation: Some(subqry_alias.clone()),
name: it.name.clone(),
})
.collect();
+
+ let qry_expr = Expr::Column(Column {
+ relation: Some(subqry_alias),
+ name: "__value".to_string(),
+ });
+
+ // if correlated subquery's operation is column equality, put the clause
into join on clause.
+ let mut restore_where_clause = true;
+
+ if let (Operator::Eq, Expr::Column(column)) = (query_info.op,
&query_info.expr) {
+ // only do this optimization for correlated subquery
+ if !outer_cols.is_empty() {
+ outer_cols.push(column.clone());
+ subqry_cols.push(qry_expr.try_into_col().unwrap());
+ restore_where_clause = false;
+ }
+ }
+
let join_keys = (outer_cols, subqry_cols);
// join our sub query into the main plan
@@ -295,24 +313,22 @@ fn optimize_scalar(
};
// restore where in condition
- let qry_expr = Box::new(Expr::Column(Column {
- relation: Some(subqry_alias),
- name: "__value".to_string(),
- }));
- let filter_expr = if query_info.expr_on_left {
- Expr::BinaryExpr(BinaryExpr::new(
- Box::new(query_info.expr.clone()),
- query_info.op,
- qry_expr,
- ))
- } else {
- Expr::BinaryExpr(BinaryExpr::new(
- qry_expr,
- query_info.op,
- Box::new(query_info.expr.clone()),
- ))
- };
- new_plan = new_plan.filter(filter_expr)?;
+ if restore_where_clause {
+ let filter_expr = if query_info.expr_on_left {
+ Expr::BinaryExpr(BinaryExpr::new(
+ Box::new(query_info.expr.clone()),
+ query_info.op,
+ Box::new(qry_expr),
+ ))
+ } else {
+ Expr::BinaryExpr(BinaryExpr::new(
+ Box::new(qry_expr),
+ query_info.op,
+ Box::new(query_info.expr.clone()),
+ ))
+ };
+ new_plan = new_plan.filter(filter_expr)?;
+ }
// if the main query had additional expressions, restore them
if let Some(expr) = conjunction(outer_others.to_vec()) {
@@ -461,13 +477,12 @@ mod tests {
.build()?;
let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
- Filter: customer.c_custkey = __sq_1.__value [c_custkey:Int64, c_name:Utf8,
o_custkey:Int64, __value:Int64;N]
- Inner Join: customer.c_custkey = __sq_1.o_custkey [c_custkey:Int64,
c_name:Utf8, o_custkey:Int64, __value:Int64;N]
- TableScan: customer [c_custkey:Int64, c_name:Utf8]
- Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value,
alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
- Aggregate: groupBy=[[orders.o_custkey]],
aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
- Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
- TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+ Inner Join: customer.c_custkey = __sq_1.o_custkey, customer.c_custkey =
__sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value,
alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
+ Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]]
[o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
+ Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan,
expected);
Ok(())
@@ -677,7 +692,7 @@ mod tests {
/// Test for correlated scalar subquery filter with additional filters
#[test]
- fn scalar_subquery_additional_filters() -> Result<()> {
+ fn scalar_subquery_additional_filters_with_non_equal_clause() ->
Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
@@ -689,7 +704,7 @@ mod tests {
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(
col("customer.c_custkey")
- .eq(scalar_subquery(sq))
+ .gt_eq(scalar_subquery(sq))
.and(col("c_custkey").eq(lit(1))),
)?
.project(vec![col("customer.c_custkey")])?
@@ -697,7 +712,7 @@ mod tests {
let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8,
o_custkey:Int64, __value:Int64;N]
- Filter: customer.c_custkey = __sq_1.__value [c_custkey:Int64, c_name:Utf8,
o_custkey:Int64, __value:Int64;N]
+ Filter: customer.c_custkey >= __sq_1.__value [c_custkey:Int64,
c_name:Utf8, o_custkey:Int64, __value:Int64;N]
Inner Join: customer.c_custkey = __sq_1.o_custkey [c_custkey:Int64,
c_name:Utf8, o_custkey:Int64, __value:Int64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]
Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value,
alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
@@ -708,6 +723,37 @@ mod tests {
Ok(())
}
+ #[test]
+ fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .aggregate(Vec::<Expr>::new(),
vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(
+ col("customer.c_custkey")
+ .eq(scalar_subquery(sq))
+ .and(col("c_custkey").eq(lit(1))),
+ )?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
+ Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8,
o_custkey:Int64, __value:Int64;N]
+ Inner Join: customer.c_custkey = __sq_1.o_custkey, customer.c_custkey =
__sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value,
alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
+ Aggregate: groupBy=[[orders.o_custkey]],
aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan,
expected);
+ Ok(())
+ }
+
/// Test for correlated scalar subquery filter with disjustions
#[test]
fn scalar_subquery_disjunction() -> Result<()> {
@@ -771,7 +817,33 @@ mod tests {
/// Test for non-correlated scalar subquery with no filters
#[test]
- fn scalar_subquery_non_correlated_no_filters() -> Result<()> {
+ fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() ->
Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .aggregate(Vec::<Expr>::new(),
vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(col("customer.c_custkey").lt(scalar_subquery(sq)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
+ Filter: customer.c_custkey < __sq_1.__value [c_custkey:Int64, c_name:Utf8,
__value:Int64;N]
+ CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: MAX(orders.o_custkey) AS __value, alias=__sq_1
[__value:Int64;N]
+ Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]]
[MAX(orders.o_custkey):Int64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan,
expected);
+ Ok(())
+ }
+
+ #[test]
+ fn scalar_subquery_non_correlated_no_filters_with_equal_clause() ->
Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.aggregate(Vec::<Expr>::new(),
vec![max(col("orders.o_custkey"))])?