This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 27b15fdcf Support non-tuple expression for exists-subquery to join
(#5264)
27b15fdcf is described below
commit 27b15fdcf3234ba050e254ec179f51e5d92e9784
Author: ygf11 <[email protected]>
AuthorDate: Sun Feb 19 00:10:51 2023 +0800
Support non-tuple expression for exists-subquery to join (#5264)
* Support non-tuple expression for exists-subquery to join
* fix tests
* add tests
* add comments
* fix tests
* fix test comment
---
benchmarks/expected-plans/q21.txt | 12 +-
benchmarks/expected-plans/q22.txt | 3 +-
benchmarks/expected-plans/q4.txt | 5 +-
datafusion/core/tests/sql/joins.rs | 203 ++++++++++++++++--
.../optimizer/src/decorrelate_where_exists.rs | 237 ++++++++++++---------
datafusion/optimizer/src/decorrelate_where_in.rs | 31 +--
datafusion/optimizer/src/utils.rs | 39 +++-
datafusion/optimizer/tests/integration-test.rs | 15 +-
8 files changed, 389 insertions(+), 156 deletions(-)
diff --git a/benchmarks/expected-plans/q21.txt
b/benchmarks/expected-plans/q21.txt
index 3ef6269de..a91632df4 100644
--- a/benchmarks/expected-plans/q21.txt
+++ b/benchmarks/expected-plans/q21.txt
@@ -14,8 +14,10 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS
LAST
TableScan: orders projection=[o_orderkey, o_orderstatus]
Filter: nation.n_name = Utf8("SAUDI ARABIA")
TableScan: nation projection=[n_nationkey, n_name]
- SubqueryAlias: l2
- TableScan: lineitem projection=[l_orderkey, l_suppkey]
- SubqueryAlias: l3
- Filter: lineitem.l_receiptdate > lineitem.l_commitdate
- TableScan: lineitem projection=[l_orderkey, l_suppkey,
l_commitdate, l_receiptdate]
\ No newline at end of file
+ Projection: l2.l_orderkey, l2.l_suppkey
+ SubqueryAlias: l2
+ TableScan: lineitem projection=[l_orderkey, l_suppkey]
+ Projection: l3.l_orderkey, l3.l_suppkey
+ SubqueryAlias: l3
+ Filter: lineitem.l_receiptdate > lineitem.l_commitdate
+ TableScan: lineitem projection=[l_orderkey, l_suppkey,
l_commitdate, l_receiptdate]
\ No newline at end of file
diff --git a/benchmarks/expected-plans/q22.txt
b/benchmarks/expected-plans/q22.txt
index 0fd7a590a..11b438085 100644
--- a/benchmarks/expected-plans/q22.txt
+++ b/benchmarks/expected-plans/q22.txt
@@ -8,7 +8,8 @@ Sort: custsale.cntrycode ASC NULLS LAST
LeftAnti Join: customer.c_custkey = orders.o_custkey
Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN
([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"),
Utf8("17")])
TableScan: customer projection=[c_custkey, c_phone,
c_acctbal]
- TableScan: orders projection=[o_custkey]
+ Projection: orders.o_custkey
+ TableScan: orders projection=[o_custkey]
SubqueryAlias: __scalar_sq_1
Projection: AVG(customer.c_acctbal) AS __value
Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]]
diff --git a/benchmarks/expected-plans/q4.txt b/benchmarks/expected-plans/q4.txt
index 3610ae175..e677f3a98 100644
--- a/benchmarks/expected-plans/q4.txt
+++ b/benchmarks/expected-plans/q4.txt
@@ -4,5 +4,6 @@ Sort: orders.o_orderpriority ASC NULLS LAST
LeftSemi Join: orders.o_orderkey = lineitem.l_orderkey
Filter: orders.o_orderdate >= Date32("8582") AND orders.o_orderdate <
Date32("8674")
TableScan: orders projection=[o_orderkey, o_orderdate,
o_orderpriority]
- Filter: lineitem.l_commitdate < lineitem.l_receiptdate
- TableScan: lineitem projection=[l_orderkey, l_commitdate,
l_receiptdate]
\ No newline at end of file
+ Projection: lineitem.l_orderkey
+ Filter: lineitem.l_commitdate < lineitem.l_receiptdate
+ TableScan: lineitem projection=[l_orderkey, l_commitdate,
l_receiptdate]
\ No newline at end of file
diff --git a/datafusion/core/tests/sql/joins.rs
b/datafusion/core/tests/sql/joins.rs
index 3c0aa8b3f..6d1b1e91b 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -2187,9 +2187,8 @@ async fn left_anti_join() -> Result<()> {
}
#[tokio::test]
-#[ignore = "Test ignored, will be enabled after fixing the anti join plan bug"]
-// https://github.com/apache/arrow-datafusion/issues/4366
async fn error_left_anti_join() -> Result<()> {
+ // https://github.com/apache/arrow-datafusion/issues/4366
let test_repartition_joins = vec![true, false];
for repartition_joins in test_repartition_joins {
let ctx = create_left_semi_anti_join_context_with_null_ids(
@@ -2255,19 +2254,20 @@ async fn right_semi_join() -> Result<()> {
let dataframe = ctx.sql(sql).await.expect(&msg);
let physical_plan = dataframe.create_physical_plan().await?;
let expected = if repartition_joins {
- vec![ "SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]",
- " SortExec: expr=[t1_id@0 ASC NULLS LAST]",
- " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t1_int@2 as t1_int]",
- " CoalesceBatchesExec: target_batch_size=4096",
- " HashJoinExec: mode=Partitioned,
join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name:
\"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\",
index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }",
- " CoalesceBatchesExec: target_batch_size=4096",
- " RepartitionExec: partitioning=Hash([Column {
name: \"t2_id\", index: 0 }], 2), input_partitions=2",
- " RepartitionExec:
partitioning=RoundRobinBatch(2), input_partitions=1",
- " MemoryExec: partitions=1,
partition_sizes=[1]",
- " CoalesceBatchesExec: target_batch_size=4096",
- " RepartitionExec: partitioning=Hash([Column {
name: \"t1_id\", index: 0 }], 2), input_partitions=2",
- " RepartitionExec:
partitioning=RoundRobinBatch(2), input_partitions=1",
- " MemoryExec: partitions=1,
partition_sizes=[1]",
+ vec!["SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]",
+ " SortExec: expr=[t1_id@0 ASC NULLS LAST]",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t1_int@2 as t1_int]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " HashJoinExec: mode=Partitioned, join_type=RightSemi,
on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0
})], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op:
NotEq, right: Column { name: \"t1_name\", index: 0 } }",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column {
name: \"t2_id\", index: 0 }], 2), input_partitions=2",
+ " RepartitionExec:
partitioning=RoundRobinBatch(2), input_partitions=1",
+ " ProjectionExec: expr=[t2_id@0 as t2_id,
t2_name@1 as t2_name]",
+ " MemoryExec: partitions=1,
partition_sizes=[1]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column {
name: \"t1_id\", index: 0 }], 2), input_partitions=2",
+ " RepartitionExec:
partitioning=RoundRobinBatch(2), input_partitions=1",
+ " MemoryExec: partitions=1,
partition_sizes=[1]",
]
} else {
vec![
@@ -2275,7 +2275,8 @@ async fn right_semi_join() -> Result<()> {
" ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t1_int@2 as t1_int]",
" CoalesceBatchesExec: target_batch_size=4096",
" HashJoinExec: mode=CollectLeft, join_type=RightSemi,
on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0
})], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op:
NotEq, right: Column { name: \"t1_name\", index: 0 } }",
- " MemoryExec: partitions=1, partition_sizes=[1]",
+ " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as
t2_name]",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
" MemoryExec: partitions=1, partition_sizes=[1]",
]
};
@@ -3393,3 +3394,173 @@ async fn left_as_inner_table_nested_loop_join() ->
Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn exists_subquery_to_join_expr_filter() -> Result<()> {
+ let test_repartition_joins = vec![true, false];
+ for repartition_joins in test_repartition_joins {
+ let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+ // exists subquery to LeftSemi join
+ let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE
t1.t1_id + 1 > t2.t2_id * 2)";
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() +
sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan()?;
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) >
CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N,
t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Projection: t2.t2_id [t2_id:UInt32;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 22 | b | 2 |",
+ "| 33 | c | 3 |",
+ "| 44 | d | 4 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+ }
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn exists_subquery_to_join_inner_filter() -> Result<()> {
+ let test_repartition_joins = vec![true, false];
+ for repartition_joins in test_repartition_joins {
+ let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+ // exists subquery to LeftSemi join
+ let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE
t1.t1_id + 1 > t2.t2_id * 2 AND t2.t2_int < 3)";
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() +
sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan()?;
+
+ // `t2.t2_int < 3` will be kept in the subquery filter.
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) >
CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N,
t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Projection: t2.t2_id [t2_id:UInt32;N]",
+ " Filter: t2.t2_int < UInt32(3) [t2_id:UInt32;N,
t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_int]
[t2_id:UInt32;N, t2_int:UInt32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 44 | d | 4 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+ }
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn exists_subquery_to_join_outer_filter() -> Result<()> {
+ let test_repartition_joins = vec![true, false];
+ for repartition_joins in test_repartition_joins {
+ let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+ // exists subquery to LeftSemi join
+ let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE
t1.t1_id + 1 > t2.t2_id * 2 AND t1.t1_int < 3)";
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() +
sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan()?;
+
+ // `t1.t1_int < 3` will be moved to the filter of t1.
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) >
CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N,
t1_int:UInt32;N]",
+ " Filter: t1.t1_int < UInt32(3) [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Projection: t2.t2_id [t2_id:UInt32;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 22 | b | 2 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+ }
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn not_exists_subquery_to_join_expr_filter() -> Result<()> {
+ let test_repartition_joins = vec![true, false];
+ for repartition_joins in test_repartition_joins {
+ let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+ // not exists subquery to LeftAnti join
+ let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT t2_id FROM t2
WHERE t1.t1_id + 1 > t2.t2_id * 2)";
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() +
sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan()?;
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) >
CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N,
t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Projection: t2.t2_id [t2_id:UInt32;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 11 | a | 1 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+ }
+
+ Ok(())
+}
diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs
b/datafusion/optimizer/src/decorrelate_where_exists.rs
index bdc4afe90..3f6b160fa 100644
--- a/datafusion/optimizer/src/decorrelate_where_exists.rs
+++ b/datafusion/optimizer/src/decorrelate_where_exists.rs
@@ -16,16 +16,14 @@
// under the License.
use crate::optimizer::ApplyOrder;
-use crate::utils::{
- conjunction, exprs_to_join_cols, find_join_exprs, split_conjunction,
- verify_not_disjunction,
-};
+use crate::utils::{conjunction, extract_join_filters, split_conjunction};
use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::{context, Result};
+use datafusion_common::{Column, DataFusionError, Result};
use datafusion_expr::{
logical_plan::{Filter, JoinType, Subquery},
Expr, LogicalPlan, LogicalPlanBuilder,
};
+use std::collections::BTreeSet;
use std::sync::Arc;
/// Optimizer rule for rewriting subquery filters to joins
@@ -144,55 +142,68 @@ fn optimize_exists(
query_info: &SubqueryInfo,
outer_input: &LogicalPlan,
) -> Result<Option<LogicalPlan>> {
- let subqry_filter = match query_info.query.subquery.as_ref() {
+ let maybe_subqury_filter = match query_info.query.subquery.as_ref() {
LogicalPlan::Distinct(subqry_distinct) => match
subqry_distinct.input.as_ref() {
- LogicalPlan::Projection(subqry_proj) => {
- Filter::try_from_plan(&subqry_proj.input)
- }
+ LogicalPlan::Projection(subqry_proj) => &subqry_proj.input,
_ => {
- // Subquery currently only supports distinct or projection
return Ok(None);
}
},
- LogicalPlan::Projection(subqry_proj) =>
Filter::try_from_plan(&subqry_proj.input),
+ LogicalPlan::Projection(subqry_proj) => &subqry_proj.input,
_ => {
// Subquery currently only supports distinct or projection
return Ok(None);
}
}
- .map_err(|e| context!("cannot optimize non-correlated subquery", e))?;
-
- // split into filters
- let subqry_filter_exprs = split_conjunction(&subqry_filter.predicate);
- verify_not_disjunction(&subqry_filter_exprs)?;
-
- // Grab column names to join on
- let (col_exprs, other_subqry_exprs) =
- find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema())?;
- let (outer_cols, subqry_cols, join_filters) =
- exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false)?;
- if subqry_cols.is_empty() || outer_cols.is_empty() {
- // cannot optimize non-correlated subquery
+ .as_ref();
+
+ // extract join filters
+ let (join_filters, subquery_input) =
extract_join_filters(maybe_subqury_filter)?;
+ // cannot optimize non-correlated subquery
+ if join_filters.is_empty() {
return Ok(None);
}
- // build subquery side of join - the thing the subquery was querying
- let mut subqry_plan =
LogicalPlanBuilder::from(subqry_filter.input.as_ref().clone());
- if let Some(expr) = conjunction(other_subqry_exprs) {
- subqry_plan = subqry_plan.filter(expr)? // if the subquery had
additional expressions, restore them
- }
- let subqry_plan = subqry_plan.build()?;
+ let input_schema = subquery_input.schema();
+ let subquery_cols: BTreeSet<Column> =
+ join_filters
+ .iter()
+ .try_fold(BTreeSet::new(), |mut cols, expr| {
+ let using_cols: Vec<Column> = expr
+ .to_columns()?
+ .into_iter()
+ .filter(|col| input_schema.field_from_column(col).is_ok())
+ .collect::<_>();
+
+ cols.extend(using_cols);
+ Result::<_, DataFusionError>::Ok(cols)
+ })?;
+
+ let projection_exprs: Vec<Expr> =
+ subquery_cols.into_iter().map(Expr::Column).collect();
+
+ let right = LogicalPlanBuilder::from(subquery_input)
+ .project(projection_exprs)?
+ .build()?;
- let join_keys = (subqry_cols, outer_cols);
+ let join_filter = conjunction(join_filters);
// join our sub query into the main plan
let join_type = match query_info.negated {
true => JoinType::LeftAnti,
false => JoinType::LeftSemi,
};
+
+ // TODO: add Distinct if the original plan is a Distinct.
let new_plan = LogicalPlanBuilder::from(outer_input.clone())
- .join(subqry_plan, join_type, join_keys, join_filters)?
+ .join(
+ right,
+ join_type,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ join_filter,
+ )?
.build()?;
+
Ok(Some(new_plan))
}
@@ -241,13 +252,14 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
- LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64,
c_name:Utf8]
- LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64,
c_name:Utf8]
- TableScan: customer [c_custkey:Int64, c_name:Utf8]
- TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]
- TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8,
o_totalprice:Float64;N]"#;
-
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: orders.o_custkey =
customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: orders.o_custkey =
customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64,
c_name:Utf8]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_plan_eq(&plan, expected)
}
@@ -276,13 +288,14 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
- LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64,
c_name:Utf8]
- TableScan: customer [c_custkey:Int64, c_name:Utf8]
- LeftSemi Join: orders.o_orderkey = lineitem.l_orderkey [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
- TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]
- TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64,
l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"#;
-
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: orders.o_custkey =
customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64,
c_name:Utf8]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n LeftSemi Join: Filter: lineitem.l_orderkey =
orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8,
o_totalprice:Float64;N]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n Projection: lineitem.l_orderkey
[l_orderkey:Int64]\
+ \n TableScan: lineitem [l_orderkey:Int64,
l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64,
l_extendedprice:Float64]";
assert_plan_eq(&plan, expected)
}
@@ -305,21 +318,21 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
- LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64,
c_name:Utf8]
- TableScan: customer [c_custkey:Int64, c_name:Utf8]
- Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]
- TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey =
orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64,
c_name:Utf8]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n Filter: orders.o_orderkey = Int32(1)
[o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_plan_eq(&plan, expected)
}
- /// Test for correlated exists subquery with no columns in schema
#[test]
fn exists_subquery_no_cols() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
-
.filter(col("customer.c_custkey").eq(col("customer.c_custkey")))?
+ .filter(col("customer.c_custkey").eq(lit(1u32)))?
.project(vec![col("orders.o_custkey")])?
.build()?,
);
@@ -329,7 +342,14 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()),
&plan)
+ // Other rules will push down `customer.c_custkey = 1`.
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey =
UInt32(1) [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64,
c_name:Utf8]\
+ \n Projection: []\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_plan_eq(&plan, expected)
}
/// Test for exists subquery with both columns in schema
@@ -365,7 +385,13 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()),
&plan)
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey !=
orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64,
c_name:Utf8]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_plan_eq(&plan, expected)
}
/// Test for correlated exists subquery less than
@@ -383,10 +409,13 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- let expected = r#"can't optimize < column comparison"#;
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey <
orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64,
c_name:Utf8]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan,
expected);
- Ok(())
+ assert_plan_eq(&plan, expected)
}
/// Test for correlated exists subquery filter with subquery disjunction
@@ -408,10 +437,13 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- let expected = r#"Optimizing disjunctions not supported!"#;
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey =
orders.o_custkey OR orders.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64,
c_name:Utf8]\
+ \n Projection: orders.o_custkey, orders.o_orderkey
[o_custkey:Int64, o_orderkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan,
expected);
- Ok(())
+ assert_plan_eq(&plan, expected)
}
/// Test for correlated exists without projection
@@ -446,11 +478,11 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- // Doesn't matter we projected an expression, just that we returned a
result
- let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
- LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64,
c_name:Utf8]
- TableScan: customer [c_custkey:Int64, c_name:Utf8]
- TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8,
o_totalprice:Float64;N]"#;
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey =
orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64,
c_name:Utf8]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_plan_eq(&plan, expected)
}
@@ -469,11 +501,12 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
- Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
- LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64,
c_name:Utf8]
- TableScan: customer [c_custkey:Int64, c_name:Utf8]
- TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n Filter: customer.c_custkey = Int32(1)
[c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: customer.c_custkey =
orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64,
c_name:Utf8]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_plan_eq(&plan, expected)
}
@@ -520,10 +553,11 @@ mod tests {
.project(vec![col("test.c")])?
.build()?;
- let expected = r#"Projection: test.c [c:UInt32]
- LeftSemi Join: test.a = sq.a [a:UInt32, b:UInt32, c:UInt32]
- TableScan: test [a:UInt32, b:UInt32, c:UInt32]
- TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#;
+ let expected = "Projection: test.c [c:UInt32]\
+ \n LeftSemi Join: Filter: test.a = sq.a [a:UInt32,
b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n Projection: sq.a [a:UInt32]\
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(&plan, expected)
}
@@ -537,10 +571,7 @@ mod tests {
.project(vec![col("test.b")])?
.build()?;
- let expected = "cannot optimize non-correlated subquery";
-
- assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan,
expected);
- Ok(())
+ assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()),
&plan)
}
/// Test for single NOT exists subquery filter
@@ -552,10 +583,7 @@ mod tests {
.project(vec![col("test.b")])?
.build()?;
- let expected = "cannot optimize non-correlated subquery";
-
- assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan,
expected);
- Ok(())
+ assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()),
&plan)
}
#[test]
@@ -583,18 +611,37 @@ mod tests {
.build()?;
let expected = "Projection: test.b [b:UInt32]\
- \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
- \n LeftSemi Join: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\
- \n LeftSemi Join: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";
+ \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32,
c:UInt32]\
+ \n LeftSemi Join: Filter: test.a = sq2.a
[a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.a = sq1.a
[a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]\
+ \n Projection: sq1.a [a:UInt32]\
+ \n TableScan: sq1 [a:UInt32, b:UInt32,
c:UInt32]\
+ \n Projection: sq2.a [a:UInt32]\
+ \n TableScan: sq2 [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereExists::new()),
- &plan,
- expected,
- );
- Ok(())
+ assert_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn exists_subquery_expr_filter() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let subquery_scan = test_table_scan_with_name("sq")?;
+ let subquery = LogicalPlanBuilder::from(subquery_scan)
+ .filter((lit(1u32) + col("sq.a")).gt(col("test.a") * lit(2u32)))?
+ .project(vec![lit(1u32)])?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(exists(Arc::new(subquery)))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = "Projection: test.b [b:UInt32]\
+ \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a
* UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n Projection: sq.a [a:UInt32]\
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_plan_eq(&plan, expected)
}
}
diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs
b/datafusion/optimizer/src/decorrelate_where_in.rs
index 7a9a75ff4..c8ff65f12 100644
--- a/datafusion/optimizer/src/decorrelate_where_in.rs
+++ b/datafusion/optimizer/src/decorrelate_where_in.rs
@@ -17,12 +17,11 @@
use crate::alias::AliasGenerator;
use crate::optimizer::ApplyOrder;
-use crate::utils::{conjunction, only_or_err, split_conjunction};
+use crate::utils::{conjunction, extract_join_filters, only_or_err,
split_conjunction};
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{context, Column, DataFusionError, Result};
use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col};
use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
-use datafusion_expr::utils::check_all_columns_from_schema;
use datafusion_expr::{Expr, Filter, LogicalPlan, LogicalPlanBuilder};
use log::debug;
use std::collections::{BTreeSet, HashMap};
@@ -220,34 +219,6 @@ fn optimize_where_in(
Ok(new_plan)
}
-fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec<Expr>,
LogicalPlan)> {
- if let LogicalPlan::Filter(plan_filter) = maybe_filter {
- let input_schema = plan_filter.input.schema();
- let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
-
- let mut join_filters: Vec<Expr> = vec![];
- let mut subquery_filters: Vec<Expr> = vec![];
- for expr in subquery_filter_exprs {
- let cols = expr.to_columns()?;
- if check_all_columns_from_schema(&cols, input_schema.clone())? {
- subquery_filters.push(expr.clone());
- } else {
- join_filters.push(expr.clone())
- }
- }
-
- // if the subquery still has filter expressions, restore them.
- let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone());
- if let Some(expr) = conjunction(subquery_filters) {
- plan = plan.filter(expr)?
- }
-
- Ok((join_filters, plan.build()?))
- } else {
- Ok((vec![], maybe_filter.clone()))
- }
-}
-
fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: Expr) ->
Vec<Expr> {
filters
.into_iter()
diff --git a/datafusion/optimizer/src/utils.rs
b/datafusion/optimizer/src/utils.rs
index 4d9d10d51..747fa6208 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -23,10 +23,11 @@ use datafusion_common::{DFSchema, Result};
use datafusion_expr::expr::{BinaryExpr, Sort};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::expr_visitor::inspect_expr_pre;
+use datafusion_expr::logical_plan::LogicalPlanBuilder;
+use datafusion_expr::utils::{check_all_columns_from_schema, from_plan};
use datafusion_expr::{
and,
logical_plan::{Filter, LogicalPlan},
- utils::from_plan,
Expr, Operator,
};
use std::collections::HashSet;
@@ -468,6 +469,42 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema
{
})
}
+/// Extract join predicates from the correlated subquery.
+/// A join predicate is an expression that references columns from both
+/// the subquery and the outer table, or only from the outer table.
+///
+/// Returns the join predicates and the subquery plan with those
+/// predicates removed.
+pub(crate) fn extract_join_filters(
+ maybe_filter: &LogicalPlan,
+) -> Result<(Vec<Expr>, LogicalPlan)> {
+ if let LogicalPlan::Filter(plan_filter) = maybe_filter {
+ let input_schema = plan_filter.input.schema();
+ let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
+
+ let mut join_filters: Vec<Expr> = vec![];
+ let mut subquery_filters: Vec<Expr> = vec![];
+ for expr in subquery_filter_exprs {
+ let cols = expr.to_columns()?;
+ if check_all_columns_from_schema(&cols, input_schema.clone())? {
+ subquery_filters.push(expr.clone());
+ } else {
+ join_filters.push(expr.clone())
+ }
+ }
+
+ // if the subquery still has filter expressions, restore them.
+ let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone());
+ if let Some(expr) = conjunction(subquery_filters) {
+ plan = plan.filter(expr)?
+ }
+
+ Ok((join_filters, plan.build()?))
+ } else {
+ Ok((vec![], maybe_filter.clone()))
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/optimizer/tests/integration-test.rs
b/datafusion/optimizer/tests/integration-test.rs
index f901e33d4..eac849e34 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -121,8 +121,9 @@ fn semi_join_with_join_filter() -> Result<()> {
let expected = "Projection: test.col_utf8\
\n LeftSemi Join: test.col_int32 = t2.col_int32 Filter:
test.col_uint32 != t2.col_uint32\
\n TableScan: test projection=[col_int32, col_uint32,
col_utf8]\
- \n SubqueryAlias: t2\
- \n TableScan: test projection=[col_int32, col_uint32,
col_utf8]";
+ \n Projection: t2.col_int32, t2.col_uint32\
+ \n SubqueryAlias: t2\
+ \n TableScan: test projection=[col_int32,
col_uint32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}
@@ -137,8 +138,9 @@ fn anti_join_with_join_filter() -> Result<()> {
let expected = "Projection: test.col_utf8\
\n LeftAnti Join: test.col_int32 = t2.col_int32 Filter:
test.col_uint32 != t2.col_uint32\
\n TableScan: test projection=[col_int32, col_uint32,
col_utf8]\
- \n SubqueryAlias: t2\
- \n TableScan: test projection=[col_int32, col_uint32,
col_utf8]";
+ \n Projection: t2.col_int32, t2.col_uint32\
+ \n SubqueryAlias: t2\
+ \n TableScan: test projection=[col_int32,
col_uint32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}
@@ -152,8 +154,9 @@ fn where_exists_distinct() -> Result<()> {
let expected = "Projection: test.col_int32\
\n LeftSemi Join: test.col_int32 = t2.col_int32\
\n TableScan: test projection=[col_int32]\
- \n SubqueryAlias: t2\
- \n TableScan: test projection=[col_int32]";
+ \n Projection: t2.col_int32\
+ \n SubqueryAlias: t2\
+ \n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}