This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 992624a Fix predicate pushdown for outer joins (#1618)
992624a is described below
commit 992624aefcd24b9df6f49476cf8504ab3c59a155
Author: James Katz <[email protected]>
AuthorDate: Mon Jan 24 15:38:00 2022 -0500
Fix predicate pushdown for outer joins (#1618)
* Add failing test for outer join with null filter
* Fix typos in filter_push_down comment
* Don't push predicates to unpreserved join sides
* Do not duplicate filters for joined columns for non inner joins
* Add some more tests
* Only skip filter duplication for non-preserved side
* Only duplicate filters on the join column for inner joins
* Clarify test comment
* Add some sql tests
* Add some data to test tables
* Add more sql tests
* Clarify definition of preserved
* Remove redundant logic
---
datafusion/src/optimizer/filter_push_down.rs | 463 +++++++++++++++++++--------
datafusion/tests/sql/joins.rs | 196 ++++++++++++
datafusion/tests/sql/mod.rs | 49 +++
3 files changed, 567 insertions(+), 141 deletions(-)
diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs
index ef171e1..6141af1 100644
--- a/datafusion/src/optimizer/filter_push_down.rs
+++ b/datafusion/src/optimizer/filter_push_down.rs
@@ -18,7 +18,7 @@ use crate::datasource::datasource::TableProviderFilterPushDown;
use crate::execution::context::ExecutionProps;
use crate::logical_plan::plan::{Aggregate, Filter, Join, Projection};
use crate::logical_plan::{
- and, replace_col, Column, CrossJoin, Limit, LogicalPlan, TableScan,
+ and, replace_col, Column, CrossJoin, JoinType, Limit, LogicalPlan,
TableScan,
};
use crate::logical_plan::{DFSchema, Expr};
use crate::optimizer::optimizer::OptimizerRule;
@@ -50,9 +50,9 @@ use std::{
/// Projection: #a AS b
/// Filter: #a Gt Int64(10) <--- changed from #b to #a
///
-/// This performs a single pass trought the plan. When it passes trought a filter, it stores that filter,
+/// This performs a single pass through the plan. When it passes through a filter, it stores that filter,
/// and when it reaches a node that does not commute with it, it adds the filter to that place.
-/// When it passes through a projection, it re-writes the filter's expression taking into accoun that projection.
+/// When it passes through a projection, it re-writes the filter's expression taking into account that projection.
/// When multiple filters would have been written, it `AND` their expressions into a single expression.
#[derive(Default)]
pub struct FilterPushDown {}
@@ -83,83 +83,6 @@ fn get_predicates<'a>(
.unzip()
}
-// returns 3 (potentially overlaping) sets of predicates:
-// * pushable to left: its columns are all on the left
-// * pushable to right: its columns is all on the right
-// * keep: the set of columns is not in only either left or right
-// Note that a predicate can be both pushed to the left and to the right.
-fn get_join_predicates<'a>(
- state: &'a State,
- left: &DFSchema,
- right: &DFSchema,
-) -> (
- Vec<&'a HashSet<Column>>,
- Vec<&'a HashSet<Column>>,
- Predicates<'a>,
-) {
- let left_columns = &left
- .fields()
- .iter()
- .map(|f| {
- [
- f.qualified_column(),
- // we need to push down filter using unqualified column as well
- f.unqualified_column(),
- ]
- })
- .flatten()
- .collect::<HashSet<_>>();
- let right_columns = &right
- .fields()
- .iter()
- .map(|f| {
- [
- f.qualified_column(),
- // we need to push down filter using unqualified column as well
- f.unqualified_column(),
- ]
- })
- .flatten()
- .collect::<HashSet<_>>();
-
- let filters = state
- .filters
- .iter()
- .map(|(predicate, columns)| {
- (
- (predicate, columns),
- (
- columns,
- left_columns.intersection(columns).collect::<HashSet<_>>(),
-
right_columns.intersection(columns).collect::<HashSet<_>>(),
- ),
- )
- })
- .collect::<Vec<_>>();
-
- let pushable_to_left = filters
- .iter()
- .filter(|(_, (columns, left, _))| left.len() == columns.len())
- .map(|((_, b), _)| *b)
- .collect();
- let pushable_to_right = filters
- .iter()
- .filter(|(_, (columns, _, right))| right.len() == columns.len())
- .map(|((_, b), _)| *b)
- .collect();
- let keep = filters
- .iter()
- .filter(|(_, (columns, left, right))| {
- // predicates whose columns are not in only one side of the join
need to remain
- let all_in_left = left.len() == columns.len();
- let all_in_right = right.len() == columns.len();
- !all_in_left && !all_in_right
- })
- .map(|((a, b), _)| (a, b))
- .unzip();
- (pushable_to_left, pushable_to_right, keep)
-}
-
/// Optimizes the plan
fn push_down(state: &State, plan: &LogicalPlan) -> Result<LogicalPlan> {
let new_inputs = plan
@@ -204,11 +127,11 @@ fn remove_filters(
-// keeps all filters from `filters` that are in `predicate_columns`
+// keeps all filters from `filters` whose expression is one of `relevant_predicates.0`
fn keep_filters(
filters: &[(Expr, HashSet<Column>)],
- predicate_columns: &[&HashSet<Column>],
+ relevant_predicates: &Predicates,
) -> Vec<(Expr, HashSet<Column>)> {
filters
.iter()
- .filter(|(_, columns)| predicate_columns.contains(&columns))
+ .filter(|(expr, _)| relevant_predicates.0.contains(&expr))
.cloned()
.collect::<Vec<_>>()
}
@@ -253,33 +176,122 @@ fn split_members<'a>(predicate: &'a Expr, predicates:
&mut Vec<&'a Expr>) {
}
}
+// For a given JOIN logical plan, determine whether each side of the join is preserved.
+// We say a join side is preserved if the join returns all or a subset of the rows from
+// the relevant side, such that each row of the output table directly maps to a row of
+// the preserved input table. If a table is not preserved, it can provide extra null rows.
+// That is, there may be rows in the output table that don't directly map to a row in the
+// input table.
+//
+// For example:
+// - In an inner join, both sides are preserved, because each row of the output
+// maps directly to a row from each side.
+// - In a left join, the left side is preserved and the right is not, because
+// there may be rows in the output that don't directly map to a row in the
+// right input (due to nulls filling where there is no match on the right).
+//
+// This is important because we can always push down post-join filters to a preserved
+// side of the join, assuming the filter only references columns from that side. For the
+// non-preserved side it can be more tricky.
+//
+// Returns a tuple of booleans - (left_preserved, right_preserved).
+fn lr_is_preserved(plan: &LogicalPlan) -> (bool, bool) {
+ match plan {
+ LogicalPlan::Join(Join { join_type, .. }) => match join_type {
+ JoinType::Inner => (true, true),
+ JoinType::Left => (true, false),
+ JoinType::Right => (false, true),
+ JoinType::Full => (false, false),
+ // No columns from the right side of the join can be referenced in output
+ // predicates for semi/anti joins, so the right-side flag (true or false) doesn't matter.
+ JoinType::Semi | JoinType::Anti => (true, false),
+ },
+ LogicalPlan::CrossJoin(_) => (true, true),
+ _ => unreachable!("lr_is_preserved only valid for JOIN nodes"),
+ }
+}
+
+// Determine which predicates in state can be pushed down to a given side of a join.
+// To determine this, we need to know the schema of the relevant join side and whether
+// or not the side's rows are preserved when joining. If the side is not preserved, we
+// do not push down anything. Otherwise we can push down predicates where all of the
+// relevant columns are contained in the relevant join side's schema.
+fn get_pushable_join_predicates<'a>(
+ state: &'a State,
+ schema: &DFSchema,
+ preserved: bool,
+) -> Predicates<'a> {
+ if !preserved {
+ return (vec![], vec![]);
+ }
+
+ let schema_columns = schema
+ .fields()
+ .iter()
+ .map(|f| {
+ [
+ f.qualified_column(),
+ // we need to push down filter using unqualified column as well
+ f.unqualified_column(),
+ ]
+ })
+ .flatten()
+ .collect::<HashSet<_>>();
+
+ state
+ .filters
+ .iter()
+ .filter(|(_, columns)| {
+ let all_columns_in_schema = schema_columns
+ .intersection(columns)
+ .collect::<HashSet<_>>()
+ .len()
+ == columns.len();
+ all_columns_in_schema
+ })
+ .map(|(a, b)| (a, b))
+ .unzip()
+}
+
fn optimize_join(
mut state: State,
plan: &LogicalPlan,
left: &LogicalPlan,
right: &LogicalPlan,
) -> Result<LogicalPlan> {
- let (pushable_to_left, pushable_to_right, keep) =
- get_join_predicates(&state, left.schema(), right.schema());
+ let (left_preserved, right_preserved) = lr_is_preserved(plan);
+ let to_left = get_pushable_join_predicates(&state, left.schema(),
left_preserved);
+ let to_right = get_pushable_join_predicates(&state, right.schema(),
right_preserved);
+
+ let to_keep: Predicates = state
+ .filters
+ .iter()
+ .filter(|(expr, _)| {
+ let pushed_to_left = to_left.0.contains(&expr);
+ let pushed_to_right = to_right.0.contains(&expr);
+ !pushed_to_left && !pushed_to_right
+ })
+ .map(|(a, b)| (a, b))
+ .unzip();
let mut left_state = state.clone();
- left_state.filters = keep_filters(&left_state.filters, &pushable_to_left);
+ left_state.filters = keep_filters(&left_state.filters, &to_left);
let left = optimize(left, left_state)?;
let mut right_state = state.clone();
- right_state.filters = keep_filters(&right_state.filters,
&pushable_to_right);
+ right_state.filters = keep_filters(&right_state.filters, &to_right);
let right = optimize(right, right_state)?;
// create a new Join with the new `left` and `right`
let expr = plan.expressions();
let plan = utils::from_plan(plan, &expr, &[left, right])?;
- if keep.0.is_empty() {
+ if to_keep.0.is_empty() {
Ok(plan)
} else {
// wrap the join on the filter whose predicates must be kept
- let plan = add_filter(plan, &keep.0);
- state.filters = remove_filters(&state.filters, &keep.1);
+ let plan = add_filter(plan, &to_keep.0);
+ state.filters = remove_filters(&state.filters, &to_keep.1);
Ok(plan)
}
@@ -400,63 +412,68 @@ fn optimize(plan: &LogicalPlan, mut state: State) ->
Result<LogicalPlan> {
optimize_join(state, plan, left, right)
}
LogicalPlan::Join(Join {
- left, right, on, ..
+ left,
+ right,
+ on,
+ join_type,
+ ..
}) => {
- // duplicate filters for joined columns so filters can be pushed
down to both sides.
- // Take the following query as an example:
- //
- // ```sql
- // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
- // ```
- //
- // `t1.id > 1` predicate needs to be pushed down to t1 table scan,
while
- // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
- //
- // Join clauses with `Using` constraints also take advantage of
this logic to make sure
- // predicates reference the shared join columns are pushed to both
sides.
- let join_side_filters = state
- .filters
- .iter()
- .filter_map(|(predicate, columns)| {
- let mut join_cols_to_replace = HashMap::new();
- for col in columns.iter() {
- for (l, r) in on {
- if col == l {
- join_cols_to_replace.insert(col, r);
- break;
- } else if col == r {
- join_cols_to_replace.insert(col, l);
- break;
+ if *join_type == JoinType::Inner {
+ // For inner joins, duplicate filters for joined columns so
filters can be pushed down
+ // to both sides. Take the following query as an example:
+ //
+ // ```sql
+ // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
+ // ```
+ //
+ // `t1.id > 1` predicate needs to be pushed down to t1 table
scan, while
+ // `t2.uid > 1` predicate needs to be pushed down to t2 table
scan.
+ //
+ // Join clauses with `Using` constraints also take advantage
of this logic to make sure
+ // predicates reference the shared join columns are pushed to
both sides.
+ let join_side_filters = state
+ .filters
+ .iter()
+ .filter_map(|(predicate, columns)| {
+ let mut join_cols_to_replace = HashMap::new();
+ for col in columns.iter() {
+ for (l, r) in on {
+ if col == l {
+ join_cols_to_replace.insert(col, r);
+ break;
+ } else if col == r {
+ join_cols_to_replace.insert(col, l);
+ break;
+ }
}
}
- }
- if join_cols_to_replace.is_empty() {
- return None;
- }
-
- let join_side_predicate =
- match replace_col(predicate.clone(),
&join_cols_to_replace) {
- Ok(p) => p,
- Err(e) => {
- return Some(Err(e));
- }
- };
-
- let join_side_columns = columns
- .clone()
- .into_iter()
- // replace keys in join_cols_to_replace with values in
resulting column
- // set
- .filter(|c| !join_cols_to_replace.contains_key(c))
- .chain(join_cols_to_replace.iter().map(|(_, v)|
(*v).clone()))
- .collect();
-
- Some(Ok((join_side_predicate, join_side_columns)))
- })
- .collect::<Result<Vec<_>>>()?;
- state.filters.extend(join_side_filters);
+ if join_cols_to_replace.is_empty() {
+ return None;
+ }
+ let join_side_predicate =
+ match replace_col(predicate.clone(),
&join_cols_to_replace) {
+ Ok(p) => p,
+ Err(e) => {
+ return Some(Err(e));
+ }
+ };
+
+ let join_side_columns = columns
+ .clone()
+ .into_iter()
+ // replace keys in join_cols_to_replace with
values in resulting column
+ // set
+ .filter(|c| !join_cols_to_replace.contains_key(c))
+ .chain(join_cols_to_replace.iter().map(|(_, v)|
(*v).clone()))
+ .collect();
+
+ Some(Ok((join_side_predicate, join_side_columns)))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ state.filters.extend(join_side_filters);
+ }
optimize_join(state, plan, left, right)
}
LogicalPlan::TableScan(TableScan {
@@ -1139,6 +1156,170 @@ mod tests {
Ok(())
}
+ /// post-join predicates on the right (non-preserved) side of a left join are neither pushed down nor duplicated
+ /// TODO: In this case we can sometimes convert the join to an INNER join
+ #[test]
+ fn filter_using_left_join() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let left = LogicalPlanBuilder::from(table_scan).build()?;
+ let right_table_scan = test_table_scan_with_name("test2")?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
+ .project(vec![col("a")])?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(left)
+ .join_using(
+ &right,
+ JoinType::Left,
+ vec![Column::from_name("a".to_string())],
+ )?
+ .filter(col("test2.a").lt_eq(lit(1i64)))?
+ .build()?;
+
+ // not part of the test, just good to know:
+ assert_eq!(
+ format!("{:?}", plan),
+ "\
+ Filter: #test2.a <= Int64(1)\
+ \n Join: Using #test.a = #test2.a\
+ \n TableScan: test projection=None\
+ \n Projection: #test2.a\
+ \n TableScan: test2 projection=None"
+ );
+
+ // filter not duplicated nor pushed down - i.e. noop
+ let expected = "\
+ Filter: #test2.a <= Int64(1)\
+ \n Join: Using #test.a = #test2.a\
+ \n TableScan: test projection=None\
+ \n Projection: #test2.a\
+ \n TableScan: test2 projection=None";
+ assert_optimized_plan_eq(&plan, expected);
+ Ok(())
+ }
+
+ /// post-join predicates on the left (non-preserved) side of a right join are neither pushed down nor duplicated
+ /// TODO: In this case we can sometimes convert the join to an INNER join
+ #[test]
+ fn filter_using_right_join() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let left = LogicalPlanBuilder::from(table_scan).build()?;
+ let right_table_scan = test_table_scan_with_name("test2")?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
+ .project(vec![col("a")])?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(left)
+ .join_using(
+ &right,
+ JoinType::Right,
+ vec![Column::from_name("a".to_string())],
+ )?
+ .filter(col("test.a").lt_eq(lit(1i64)))?
+ .build()?;
+
+ // not part of the test, just good to know:
+ assert_eq!(
+ format!("{:?}", plan),
+ "\
+ Filter: #test.a <= Int64(1)\
+ \n Join: Using #test.a = #test2.a\
+ \n TableScan: test projection=None\
+ \n Projection: #test2.a\
+ \n TableScan: test2 projection=None"
+ );
+
+ // filter not duplicated nor pushed down - i.e. noop
+ let expected = "\
+ Filter: #test.a <= Int64(1)\
+ \n Join: Using #test.a = #test2.a\
+ \n TableScan: test projection=None\
+ \n Projection: #test2.a\
+ \n TableScan: test2 projection=None";
+ assert_optimized_plan_eq(&plan, expected);
+ Ok(())
+ }
+
+ /// post-left-join predicate on a column common to both sides is only pushed to the left side
+ /// i.e. - not duplicated to the right side
+ #[test]
+ fn filter_using_left_join_on_common() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let left = LogicalPlanBuilder::from(table_scan).build()?;
+ let right_table_scan = test_table_scan_with_name("test2")?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
+ .project(vec![col("a")])?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(left)
+ .join_using(
+ &right,
+ JoinType::Left,
+ vec![Column::from_name("a".to_string())],
+ )?
+ .filter(col("a").lt_eq(lit(1i64)))?
+ .build()?;
+
+ // not part of the test, just good to know:
+ assert_eq!(
+ format!("{:?}", plan),
+ "\
+ Filter: #test.a <= Int64(1)\
+ \n Join: Using #test.a = #test2.a\
+ \n TableScan: test projection=None\
+ \n Projection: #test2.a\
+ \n TableScan: test2 projection=None"
+ );
+
+ // filter sent to left side of the join, not the right
+ let expected = "\
+ Join: Using #test.a = #test2.a\
+ \n Filter: #test.a <= Int64(1)\
+ \n TableScan: test projection=None\
+ \n Projection: #test2.a\
+ \n TableScan: test2 projection=None";
+ assert_optimized_plan_eq(&plan, expected);
+ Ok(())
+ }
+
+ /// post-right-join predicate on a column common to both sides is only pushed to the right side
+ /// i.e. - not duplicated to the left side.
+ #[test]
+ fn filter_using_right_join_on_common() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let left = LogicalPlanBuilder::from(table_scan).build()?;
+ let right_table_scan = test_table_scan_with_name("test2")?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
+ .project(vec![col("a")])?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(left)
+ .join_using(
+ &right,
+ JoinType::Right,
+ vec![Column::from_name("a".to_string())],
+ )?
+ .filter(col("test2.a").lt_eq(lit(1i64)))?
+ .build()?;
+
+ // not part of the test, just good to know:
+ assert_eq!(
+ format!("{:?}", plan),
+ "\
+ Filter: #test2.a <= Int64(1)\
+ \n Join: Using #test.a = #test2.a\
+ \n TableScan: test projection=None\
+ \n Projection: #test2.a\
+ \n TableScan: test2 projection=None"
+ );
+
+ // filter sent to right side of join, not duplicated to the left
+ let expected = "\
+ Join: Using #test.a = #test2.a\
+ \n TableScan: test projection=None\
+ \n Projection: #test2.a\
+ \n Filter: #test2.a <= Int64(1)\
+ \n TableScan: test2 projection=None";
+ assert_optimized_plan_eq(&plan, expected);
+ Ok(())
+ }
+
struct PushDownProvider {
pub filter_support: TableProviderFilterPushDown,
}
diff --git a/datafusion/tests/sql/joins.rs b/datafusion/tests/sql/joins.rs
index 85b59e6..70d824b 100644
--- a/datafusion/tests/sql/joins.rs
+++ b/datafusion/tests/sql/joins.rs
@@ -212,6 +212,202 @@ async fn left_join_unbalanced() -> Result<()> {
}
#[tokio::test]
+async fn left_join_null_filter() -> Result<()> {
+ // Since t2 is the non-preserved side of the join, we cannot push down a NULL filter.
+ // Note that this is only true because IS NULL does not remove nulls. For filters that
+ // remove nulls, the join could be rewritten as an inner join and the filter pushed down.
+ let mut ctx = create_join_context_with_nulls()?;
+ let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id =
t2_id WHERE t2_name IS NULL ORDER BY t1_id";
+ let expected = vec![
+ "+-------+-------+---------+",
+ "| t1_id | t2_id | t2_name |",
+ "+-------+-------+---------+",
+ "| 22 | 22 | |",
+ "| 33 | | |",
+ "| 77 | | |",
+ "| 88 | | |",
+ "+-------+-------+---------+",
+ ];
+
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn left_join_null_filter_on_join_column() -> Result<()> {
+ // Again, since t2 is the non-preserved side of the join, we cannot push down a NULL filter.
+ let mut ctx = create_join_context_with_nulls()?;
+ let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id =
t2_id WHERE t2_id IS NULL ORDER BY t1_id";
+ let expected = vec![
+ "+-------+-------+---------+",
+ "| t1_id | t2_id | t2_name |",
+ "+-------+-------+---------+",
+ "| 33 | | |",
+ "| 77 | | |",
+ "| 88 | | |",
+ "+-------+-------+---------+",
+ ];
+
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn left_join_not_null_filter() -> Result<()> {
+ let mut ctx = create_join_context_with_nulls()?;
+ let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id =
t2_id WHERE t2_name IS NOT NULL ORDER BY t1_id";
+ let expected = vec![
+ "+-------+-------+---------+",
+ "| t1_id | t2_id | t2_name |",
+ "+-------+-------+---------+",
+ "| 11 | 11 | z |",
+ "| 44 | 44 | x |",
+ "| 99 | 99 | u |",
+ "+-------+-------+---------+",
+ ];
+
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn left_join_not_null_filter_on_join_column() -> Result<()> {
+ let mut ctx = create_join_context_with_nulls()?;
+ let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id =
t2_id WHERE t2_id IS NOT NULL ORDER BY t1_id";
+ let expected = vec![
+ "+-------+-------+---------+",
+ "| t1_id | t2_id | t2_name |",
+ "+-------+-------+---------+",
+ "| 11 | 11 | z |",
+ "| 22 | 22 | |",
+ "| 44 | 44 | x |",
+ "| 99 | 99 | u |",
+ "+-------+-------+---------+",
+ ];
+
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn right_join_null_filter() -> Result<()> {
+ let mut ctx = create_join_context_with_nulls()?;
+ let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id =
t2_id WHERE t1_name IS NULL ORDER BY t2_id";
+ let expected = vec![
+ "+-------+---------+-------+",
+ "| t1_id | t1_name | t2_id |",
+ "+-------+---------+-------+",
+ "| | | 55 |",
+ "| 99 | | 99 |",
+ "+-------+---------+-------+",
+ ];
+
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn right_join_null_filter_on_join_column() -> Result<()> {
+ let mut ctx = create_join_context_with_nulls()?;
+ let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id =
t2_id WHERE t1_id IS NULL ORDER BY t2_id";
+ let expected = vec![
+ "+-------+---------+-------+",
+ "| t1_id | t1_name | t2_id |",
+ "+-------+---------+-------+",
+ "| | | 55 |",
+ "+-------+---------+-------+",
+ ];
+
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn right_join_not_null_filter() -> Result<()> {
+ let mut ctx = create_join_context_with_nulls()?;
+ let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id =
t2_id WHERE t1_name IS NOT NULL ORDER BY t2_id";
+ let expected = vec![
+ "+-------+---------+-------+",
+ "| t1_id | t1_name | t2_id |",
+ "+-------+---------+-------+",
+ "| 11 | a | 11 |",
+ "| 22 | b | 22 |",
+ "| 44 | d | 44 |",
+ "+-------+---------+-------+",
+ ];
+
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn right_join_not_null_filter_on_join_column() -> Result<()> {
+ let mut ctx = create_join_context_with_nulls()?;
+ let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id =
t2_id WHERE t1_id IS NOT NULL ORDER BY t2_id";
+ let expected = vec![
+ "+-------+---------+-------+",
+ "| t1_id | t1_name | t2_id |",
+ "+-------+---------+-------+",
+ "| 11 | a | 11 |",
+ "| 22 | b | 22 |",
+ "| 44 | d | 44 |",
+ "| 99 | | 99 |",
+ "+-------+---------+-------+",
+ ];
+
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn full_join_null_filter() -> Result<()> {
+ let mut ctx = create_join_context_with_nulls()?;
+ let sql = "SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON
t1_id = t2_id WHERE t1_name IS NULL ORDER BY t1_id";
+ let expected = vec![
+ "+-------+---------+-------+",
+ "| t1_id | t1_name | t2_id |",
+ "+-------+---------+-------+",
+ "| 88 | | |",
+ "| 99 | | 99 |",
+ "| | | 55 |",
+ "+-------+---------+-------+",
+ ];
+
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn full_join_not_null_filter() -> Result<()> {
+ let mut ctx = create_join_context_with_nulls()?;
+ let sql = "SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON
t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t1_id";
+ let expected = vec![
+ "+-------+---------+-------+",
+ "| t1_id | t1_name | t2_id |",
+ "+-------+---------+-------+",
+ "| 11 | a | 11 |",
+ "| 22 | b | 22 |",
+ "| 33 | c | |",
+ "| 44 | d | 44 |",
+ "| 77 | e | |",
+ "+-------+---------+-------+",
+ ];
+
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
async fn right_join() -> Result<()> {
let mut ctx = create_join_context("t1_id", "t2_id")?;
let equivalent_sql = [
diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs
index 46c33f7..55715af 100644
--- a/datafusion/tests/sql/mod.rs
+++ b/datafusion/tests/sql/mod.rs
@@ -297,6 +297,55 @@ fn create_join_context_unbalanced(
Ok(ctx)
}
+// Create in-memory tables t1 (t1_id, t1_name) and t2 (t2_id, t2_name) whose name columns contain nulls
+fn create_join_context_with_nulls() -> Result<ExecutionContext> {
+ let mut ctx = ExecutionContext::new();
+
+ let t1_schema = Arc::new(Schema::new(vec![
+ Field::new("t1_id", DataType::UInt32, true),
+ Field::new("t1_name", DataType::Utf8, true),
+ ]));
+ let t1_data = RecordBatch::try_new(
+ t1_schema.clone(),
+ vec![
+ Arc::new(UInt32Array::from(vec![11, 22, 33, 44, 77, 88, 99])),
+ Arc::new(StringArray::from(vec![
+ Some("a"),
+ Some("b"),
+ Some("c"),
+ Some("d"),
+ Some("e"),
+ None,
+ None,
+ ])),
+ ],
+ )?;
+ let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?;
+ ctx.register_table("t1", Arc::new(t1_table))?;
+
+ let t2_schema = Arc::new(Schema::new(vec![
+ Field::new("t2_id", DataType::UInt32, true),
+ Field::new("t2_name", DataType::Utf8, true),
+ ]));
+ let t2_data = RecordBatch::try_new(
+ t2_schema.clone(),
+ vec![
+ Arc::new(UInt32Array::from(vec![11, 22, 44, 55, 99])),
+ Arc::new(StringArray::from(vec![
+ Some("z"),
+ None,
+ Some("x"),
+ Some("w"),
+ Some("u"),
+ ])),
+ ],
+ )?;
+ let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?;
+ ctx.register_table("t2", Arc::new(t2_table))?;
+
+ Ok(ctx)
+}
+
fn get_tpch_table_schema(table: &str) -> Schema {
match table {
"customer" => Schema::new(vec![