This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 992624a  Fix predicate pushdown for outer joins (#1618)
992624a is described below

commit 992624aefcd24b9df6f49476cf8504ab3c59a155
Author: James Katz <[email protected]>
AuthorDate: Mon Jan 24 15:38:00 2022 -0500

    Fix predicate pushdown for outer joins (#1618)
    
    * Add failing test for outer join with null filter
    
    * Fix typos in filter_push_down comment
    
    * Don't push predicates to unpreserved join sides
    
    * Do not duplicate filters for joined columns for non inner joins
    
    * Add some more tests
    
    * Only skip filter duplication for non-preserved side
    
    * Only duplicate filters on the join column for inner joins
    
    * Clarify test comment
    
    * Add some sql tests
    
    * Add some data to test tables
    
    * Add more sql tests
    
    * Clarify definition of preserved
    
    * Remove redundant logic
---
 datafusion/src/optimizer/filter_push_down.rs | 463 +++++++++++++++++++--------
 datafusion/tests/sql/joins.rs                | 196 ++++++++++++
 datafusion/tests/sql/mod.rs                  |  49 +++
 3 files changed, 567 insertions(+), 141 deletions(-)

diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs
index ef171e1..6141af1 100644
--- a/datafusion/src/optimizer/filter_push_down.rs
+++ b/datafusion/src/optimizer/filter_push_down.rs
@@ -18,7 +18,7 @@ use crate::datasource::datasource::TableProviderFilterPushDown;
 use crate::execution::context::ExecutionProps;
 use crate::logical_plan::plan::{Aggregate, Filter, Join, Projection};
 use crate::logical_plan::{
-    and, replace_col, Column, CrossJoin, Limit, LogicalPlan, TableScan,
+    and, replace_col, Column, CrossJoin, JoinType, Limit, LogicalPlan, TableScan,
 };
 use crate::logical_plan::{DFSchema, Expr};
 use crate::optimizer::optimizer::OptimizerRule;
@@ -50,9 +50,9 @@ use std::{
 /// Projection: #a AS b
 ///     Filter: #a Gt Int64(10)  <--- changed from #b to #a
 ///
-/// This performs a single pass trought the plan. When it passes trought a filter, it stores that filter,
+/// This performs a single pass through the plan. When it passes through a filter, it stores that filter,
 /// and when it reaches a node that does not commute with it, it adds the filter to that place.
-/// When it passes through a projection, it re-writes the filter's expression taking into accoun that projection.
+/// When it passes through a projection, it re-writes the filter's expression taking into account that projection.
 /// When multiple filters would have been written, it `AND` their expressions into a single expression.
 #[derive(Default)]
 pub struct FilterPushDown {}
@@ -83,83 +83,6 @@ fn get_predicates<'a>(
         .unzip()
 }
 
-// returns 3 (potentially overlaping) sets of predicates:
-// * pushable to left: its columns are all on the left
-// * pushable to right: its columns is all on the right
-// * keep: the set of columns is not in only either left or right
-// Note that a predicate can be both pushed to the left and to the right.
-fn get_join_predicates<'a>(
-    state: &'a State,
-    left: &DFSchema,
-    right: &DFSchema,
-) -> (
-    Vec<&'a HashSet<Column>>,
-    Vec<&'a HashSet<Column>>,
-    Predicates<'a>,
-) {
-    let left_columns = &left
-        .fields()
-        .iter()
-        .map(|f| {
-            [
-                f.qualified_column(),
-                // we need to push down filter using unqualified column as well
-                f.unqualified_column(),
-            ]
-        })
-        .flatten()
-        .collect::<HashSet<_>>();
-    let right_columns = &right
-        .fields()
-        .iter()
-        .map(|f| {
-            [
-                f.qualified_column(),
-                // we need to push down filter using unqualified column as well
-                f.unqualified_column(),
-            ]
-        })
-        .flatten()
-        .collect::<HashSet<_>>();
-
-    let filters = state
-        .filters
-        .iter()
-        .map(|(predicate, columns)| {
-            (
-                (predicate, columns),
-                (
-                    columns,
-                    left_columns.intersection(columns).collect::<HashSet<_>>(),
-                    right_columns.intersection(columns).collect::<HashSet<_>>(),
-                ),
-            )
-        })
-        .collect::<Vec<_>>();
-
-    let pushable_to_left = filters
-        .iter()
-        .filter(|(_, (columns, left, _))| left.len() == columns.len())
-        .map(|((_, b), _)| *b)
-        .collect();
-    let pushable_to_right = filters
-        .iter()
-        .filter(|(_, (columns, _, right))| right.len() == columns.len())
-        .map(|((_, b), _)| *b)
-        .collect();
-    let keep = filters
-        .iter()
-        .filter(|(_, (columns, left, right))| {
-            // predicates whose columns are not in only one side of the join need to remain
-            let all_in_left = left.len() == columns.len();
-            let all_in_right = right.len() == columns.len();
-            !all_in_left && !all_in_right
-        })
-        .map(|((a, b), _)| (a, b))
-        .unzip();
-    (pushable_to_left, pushable_to_right, keep)
-}
-
 /// Optimizes the plan
 fn push_down(state: &State, plan: &LogicalPlan) -> Result<LogicalPlan> {
     let new_inputs = plan
@@ -204,11 +127,11 @@ fn remove_filters(
 // keeps all filters from `filters` that are in `predicate_columns`
 fn keep_filters(
     filters: &[(Expr, HashSet<Column>)],
-    predicate_columns: &[&HashSet<Column>],
+    relevant_predicates: &Predicates,
 ) -> Vec<(Expr, HashSet<Column>)> {
     filters
         .iter()
-        .filter(|(_, columns)| predicate_columns.contains(&columns))
+        .filter(|(expr, _)| relevant_predicates.0.contains(&expr))
         .cloned()
         .collect::<Vec<_>>()
 }
@@ -253,33 +176,122 @@ fn split_members<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>) {
     }
 }
 
+// For a given JOIN logical plan, determine whether each side of the join is preserved.
+// We say a join side is preserved if the join returns all or a subset of the rows from
+// the relevant side, such that each row of the output table directly maps to a row of
+// the preserved input table. If a table is not preserved, it can provide extra null rows.
+// That is, there may be rows in the output table that don't directly map to a row in the
+// input table.
+//
+// For example:
+//   - In an inner join, both sides are preserved, because each row of the output
+//     maps directly to a row from each side.
+//   - In a left join, the left side is preserved and the right is not, because
+//     there may be rows in the output that don't directly map to a row in the
+//     right input (due to nulls filling where there is no match on the right).
+//
+// This is important because we can always push down post-join filters to a preserved
+// side of the join, assuming the filter only references columns from that side. For the
+// non-preserved side it can be more tricky.
+//
+// Returns a tuple of booleans - (left_preserved, right_preserved).
+fn lr_is_preserved(plan: &LogicalPlan) -> (bool, bool) {
+    match plan {
+        LogicalPlan::Join(Join { join_type, .. }) => match join_type {
+            JoinType::Inner => (true, true),
+            JoinType::Left => (true, false),
+            JoinType::Right => (false, true),
+            JoinType::Full => (false, false),
+            // No columns from the right side of the join can be referenced in output
+            // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
+            JoinType::Semi | JoinType::Anti => (true, false),
+        },
+        LogicalPlan::CrossJoin(_) => (true, true),
+        _ => unreachable!("lr_is_preserved only valid for JOIN nodes"),
+    }
+}
+
+// Determine which predicates in state can be pushed down to a given side of a join.
+// To determine this, we need to know the schema of the relevant join side and whether
+// or not the side's rows are preserved when joining. If the side is not preserved, we
+// do not push down anything. Otherwise we can push down predicates where all of the
+// relevant columns are contained on the relevant join side's schema.
+fn get_pushable_join_predicates<'a>(
+    state: &'a State,
+    schema: &DFSchema,
+    preserved: bool,
+) -> Predicates<'a> {
+    if !preserved {
+        return (vec![], vec![]);
+    }
+
+    let schema_columns = schema
+        .fields()
+        .iter()
+        .map(|f| {
+            [
+                f.qualified_column(),
+                // we need to push down filter using unqualified column as well
+                f.unqualified_column(),
+            ]
+        })
+        .flatten()
+        .collect::<HashSet<_>>();
+
+    state
+        .filters
+        .iter()
+        .filter(|(_, columns)| {
+            let all_columns_in_schema = schema_columns
+                .intersection(columns)
+                .collect::<HashSet<_>>()
+                .len()
+                == columns.len();
+            all_columns_in_schema
+        })
+        .map(|(a, b)| (a, b))
+        .unzip()
+}
+
 fn optimize_join(
     mut state: State,
     plan: &LogicalPlan,
     left: &LogicalPlan,
     right: &LogicalPlan,
 ) -> Result<LogicalPlan> {
-    let (pushable_to_left, pushable_to_right, keep) =
-        get_join_predicates(&state, left.schema(), right.schema());
+    let (left_preserved, right_preserved) = lr_is_preserved(plan);
+    let to_left = get_pushable_join_predicates(&state, left.schema(), left_preserved);
+    let to_right = get_pushable_join_predicates(&state, right.schema(), right_preserved);
+
+    let to_keep: Predicates = state
+        .filters
+        .iter()
+        .filter(|(expr, _)| {
+            let pushed_to_left = to_left.0.contains(&expr);
+            let pushed_to_right = to_right.0.contains(&expr);
+            !pushed_to_left && !pushed_to_right
+        })
+        .map(|(a, b)| (a, b))
+        .unzip();
 
     let mut left_state = state.clone();
-    left_state.filters = keep_filters(&left_state.filters, &pushable_to_left);
+    left_state.filters = keep_filters(&left_state.filters, &to_left);
     let left = optimize(left, left_state)?;
 
     let mut right_state = state.clone();
-    right_state.filters = keep_filters(&right_state.filters, &pushable_to_right);
+    right_state.filters = keep_filters(&right_state.filters, &to_right);
     let right = optimize(right, right_state)?;
 
     // create a new Join with the new `left` and `right`
     let expr = plan.expressions();
     let plan = utils::from_plan(plan, &expr, &[left, right])?;
 
-    if keep.0.is_empty() {
+    if to_keep.0.is_empty() {
         Ok(plan)
     } else {
         // wrap the join on the filter whose predicates must be kept
-        let plan = add_filter(plan, &keep.0);
-        state.filters = remove_filters(&state.filters, &keep.1);
+        let plan = add_filter(plan, &to_keep.0);
+        state.filters = remove_filters(&state.filters, &to_keep.1);
 
         Ok(plan)
     }
@@ -400,63 +412,68 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
             optimize_join(state, plan, left, right)
         }
         LogicalPlan::Join(Join {
-            left, right, on, ..
+            left,
+            right,
+            on,
+            join_type,
+            ..
         }) => {
-            // duplicate filters for joined columns so filters can be pushed down to both sides.
-            // Take the following query as an example:
-            //
-            // ```sql
-            // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
-            // ```
-            //
-            // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
-            // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
-            //
-            // Join clauses with `Using` constraints also take advantage of this logic to make sure
-            // predicates reference the shared join columns are pushed to both sides.
-            let join_side_filters = state
-                .filters
-                .iter()
-                .filter_map(|(predicate, columns)| {
-                    let mut join_cols_to_replace = HashMap::new();
-                    for col in columns.iter() {
-                        for (l, r) in on {
-                            if col == l {
-                                join_cols_to_replace.insert(col, r);
-                                break;
-                            } else if col == r {
-                                join_cols_to_replace.insert(col, l);
-                                break;
+            if *join_type == JoinType::Inner {
+                // For inner joins, duplicate filters for joined columns so filters can be pushed down
+                // to both sides. Take the following query as an example:
+                //
+                // ```sql
+                // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
+                // ```
+                //
+                // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
+                // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
+                //
+                // Join clauses with `Using` constraints also take advantage of this logic to make sure
+                // predicates reference the shared join columns are pushed to both sides.
+                let join_side_filters = state
+                    .filters
+                    .iter()
+                    .filter_map(|(predicate, columns)| {
+                        let mut join_cols_to_replace = HashMap::new();
+                        for col in columns.iter() {
+                            for (l, r) in on {
+                                if col == l {
+                                    join_cols_to_replace.insert(col, r);
+                                    break;
+                                } else if col == r {
+                                    join_cols_to_replace.insert(col, l);
+                                    break;
+                                }
                             }
                         }
-                    }
 
-                    if join_cols_to_replace.is_empty() {
-                        return None;
-                    }
-
-                    let join_side_predicate =
-                        match replace_col(predicate.clone(), &join_cols_to_replace) {
-                            Ok(p) => p,
-                            Err(e) => {
-                                return Some(Err(e));
-                            }
-                        };
-
-                    let join_side_columns = columns
-                        .clone()
-                        .into_iter()
-                        // replace keys in join_cols_to_replace with values in resulting column
-                        // set
-                        .filter(|c| !join_cols_to_replace.contains_key(c))
-                        .chain(join_cols_to_replace.iter().map(|(_, v)| (*v).clone()))
-                        .collect();
-
-                    Some(Ok((join_side_predicate, join_side_columns)))
-                })
-                .collect::<Result<Vec<_>>>()?;
-            state.filters.extend(join_side_filters);
+                        if join_cols_to_replace.is_empty() {
+                            return None;
+                        }
 
+                        let join_side_predicate =
+                            match replace_col(predicate.clone(), &join_cols_to_replace) {
+                                Ok(p) => p,
+                                Err(e) => {
+                                    return Some(Err(e));
+                                }
+                            };
+
+                        let join_side_columns = columns
+                            .clone()
+                            .into_iter()
+                            // replace keys in join_cols_to_replace with values in resulting column
+                            // set
+                            .filter(|c| !join_cols_to_replace.contains_key(c))
+                            .chain(join_cols_to_replace.iter().map(|(_, v)| (*v).clone()))
+                            .collect();
+
+                        Some(Ok((join_side_predicate, join_side_columns)))
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                state.filters.extend(join_side_filters);
+            }
             optimize_join(state, plan, left, right)
         }
         LogicalPlan::TableScan(TableScan {
@@ -1139,6 +1156,170 @@ mod tests {
         Ok(())
     }
 
+    /// post-join predicates on the right side of a left join are not duplicated
+    /// TODO: In this case we can sometimes convert the join to an INNER join
+    #[test]
+    fn filter_using_left_join() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let left = LogicalPlanBuilder::from(table_scan).build()?;
+        let right_table_scan = test_table_scan_with_name("test2")?;
+        let right = LogicalPlanBuilder::from(right_table_scan)
+            .project(vec![col("a")])?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(left)
+            .join_using(
+                &right,
+                JoinType::Left,
+                vec![Column::from_name("a".to_string())],
+            )?
+            .filter(col("test2.a").lt_eq(lit(1i64)))?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{:?}", plan),
+            "\
+            Filter: #test2.a <= Int64(1)\
+            \n  Join: Using #test.a = #test2.a\
+            \n    TableScan: test projection=None\
+            \n    Projection: #test2.a\
+            \n      TableScan: test2 projection=None"
+        );
+
+        // filter not duplicated nor pushed down - i.e. noop
+        let expected = "\
+        Filter: #test2.a <= Int64(1)\
+        \n  Join: Using #test.a = #test2.a\
+        \n    TableScan: test projection=None\
+        \n    Projection: #test2.a\
+        \n      TableScan: test2 projection=None";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    /// post-join predicates on the left side of a right join are not duplicated
+    /// TODO: In this case we can sometimes convert the join to an INNER join
+    #[test]
+    fn filter_using_right_join() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let left = LogicalPlanBuilder::from(table_scan).build()?;
+        let right_table_scan = test_table_scan_with_name("test2")?;
+        let right = LogicalPlanBuilder::from(right_table_scan)
+            .project(vec![col("a")])?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(left)
+            .join_using(
+                &right,
+                JoinType::Right,
+                vec![Column::from_name("a".to_string())],
+            )?
+            .filter(col("test.a").lt_eq(lit(1i64)))?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{:?}", plan),
+            "\
+            Filter: #test.a <= Int64(1)\
+            \n  Join: Using #test.a = #test2.a\
+            \n    TableScan: test projection=None\
+            \n    Projection: #test2.a\
+            \n      TableScan: test2 projection=None"
+        );
+
+        // filter not duplicated nor pushed down - i.e. noop
+        let expected = "\
+        Filter: #test.a <= Int64(1)\
+        \n  Join: Using #test.a = #test2.a\
+        \n    TableScan: test projection=None\
+        \n    Projection: #test2.a\
+        \n      TableScan: test2 projection=None";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    /// post-left-join predicate on a column common to both sides is only pushed to the left side
+    /// i.e. - not duplicated to the right side
+    #[test]
+    fn filter_using_left_join_on_common() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let left = LogicalPlanBuilder::from(table_scan).build()?;
+        let right_table_scan = test_table_scan_with_name("test2")?;
+        let right = LogicalPlanBuilder::from(right_table_scan)
+            .project(vec![col("a")])?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(left)
+            .join_using(
+                &right,
+                JoinType::Left,
+                vec![Column::from_name("a".to_string())],
+            )?
+            .filter(col("a").lt_eq(lit(1i64)))?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{:?}", plan),
+            "\
+            Filter: #test.a <= Int64(1)\
+            \n  Join: Using #test.a = #test2.a\
+            \n    TableScan: test projection=None\
+            \n    Projection: #test2.a\
+            \n      TableScan: test2 projection=None"
+        );
+
+        // filter sent to left side of the join, not the right
+        let expected = "\
+        Join: Using #test.a = #test2.a\
+        \n  Filter: #test.a <= Int64(1)\
+        \n    TableScan: test projection=None\
+        \n  Projection: #test2.a\
+        \n    TableScan: test2 projection=None";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    /// post-right-join predicate on a column common to both sides is only pushed to the right side
+    /// i.e. - not duplicated to the left side.
+    #[test]
+    fn filter_using_right_join_on_common() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let left = LogicalPlanBuilder::from(table_scan).build()?;
+        let right_table_scan = test_table_scan_with_name("test2")?;
+        let right = LogicalPlanBuilder::from(right_table_scan)
+            .project(vec![col("a")])?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(left)
+            .join_using(
+                &right,
+                JoinType::Right,
+                vec![Column::from_name("a".to_string())],
+            )?
+            .filter(col("test2.a").lt_eq(lit(1i64)))?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{:?}", plan),
+            "\
+            Filter: #test2.a <= Int64(1)\
+            \n  Join: Using #test.a = #test2.a\
+            \n    TableScan: test projection=None\
+            \n    Projection: #test2.a\
+            \n      TableScan: test2 projection=None"
+        );
+
+        // filter sent to right side of join, not duplicated to the left
+        let expected = "\
+        Join: Using #test.a = #test2.a\
+        \n  TableScan: test projection=None\
+        \n  Projection: #test2.a\
+        \n    Filter: #test2.a <= Int64(1)\
+        \n      TableScan: test2 projection=None";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
     struct PushDownProvider {
         pub filter_support: TableProviderFilterPushDown,
     }
diff --git a/datafusion/tests/sql/joins.rs b/datafusion/tests/sql/joins.rs
index 85b59e6..70d824b 100644
--- a/datafusion/tests/sql/joins.rs
+++ b/datafusion/tests/sql/joins.rs
@@ -212,6 +212,202 @@ async fn left_join_unbalanced() -> Result<()> {
 }
 
 #[tokio::test]
+async fn left_join_null_filter() -> Result<()> {
+    // Since t2 is the non-preserved side of the join, we cannot push down a NULL filter.
+    // Note that this is only true because IS NULL does not remove nulls. For filters that
+    // remove nulls, we can rewrite the join as an inner join and then push down the filter.
+    let mut ctx = create_join_context_with_nulls()?;
+    let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_name IS NULL ORDER BY t1_id";
+    let expected = vec![
+        "+-------+-------+---------+",
+        "| t1_id | t2_id | t2_name |",
+        "+-------+-------+---------+",
+        "| 22    | 22    |         |",
+        "| 33    |       |         |",
+        "| 77    |       |         |",
+        "| 88    |       |         |",
+        "+-------+-------+---------+",
+    ];
+
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn left_join_null_filter_on_join_column() -> Result<()> {
+    // Again, since t2 is the non-preserved side of the join, we cannot push down a NULL filter.
+    let mut ctx = create_join_context_with_nulls()?;
+    let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_id IS NULL ORDER BY t1_id";
+    let expected = vec![
+        "+-------+-------+---------+",
+        "| t1_id | t2_id | t2_name |",
+        "+-------+-------+---------+",
+        "| 33    |       |         |",
+        "| 77    |       |         |",
+        "| 88    |       |         |",
+        "+-------+-------+---------+",
+    ];
+
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn left_join_not_null_filter() -> Result<()> {
+    let mut ctx = create_join_context_with_nulls()?;
+    let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_name IS NOT NULL ORDER BY t1_id";
+    let expected = vec![
+        "+-------+-------+---------+",
+        "| t1_id | t2_id | t2_name |",
+        "+-------+-------+---------+",
+        "| 11    | 11    | z       |",
+        "| 44    | 44    | x       |",
+        "| 99    | 99    | u       |",
+        "+-------+-------+---------+",
+    ];
+
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn left_join_not_null_filter_on_join_column() -> Result<()> {
+    let mut ctx = create_join_context_with_nulls()?;
+    let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_id IS NOT NULL ORDER BY t1_id";
+    let expected = vec![
+        "+-------+-------+---------+",
+        "| t1_id | t2_id | t2_name |",
+        "+-------+-------+---------+",
+        "| 11    | 11    | z       |",
+        "| 22    | 22    |         |",
+        "| 44    | 44    | x       |",
+        "| 99    | 99    | u       |",
+        "+-------+-------+---------+",
+    ];
+
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn right_join_null_filter() -> Result<()> {
+    let mut ctx = create_join_context_with_nulls()?;
+    let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_name IS NULL ORDER BY t2_id";
+    let expected = vec![
+        "+-------+---------+-------+",
+        "| t1_id | t1_name | t2_id |",
+        "+-------+---------+-------+",
+        "|       |         | 55    |",
+        "| 99    |         | 99    |",
+        "+-------+---------+-------+",
+    ];
+
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn right_join_null_filter_on_join_column() -> Result<()> {
+    let mut ctx = create_join_context_with_nulls()?;
+    let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_id IS NULL ORDER BY t2_id";
+    let expected = vec![
+        "+-------+---------+-------+",
+        "| t1_id | t1_name | t2_id |",
+        "+-------+---------+-------+",
+        "|       |         | 55    |",
+        "+-------+---------+-------+",
+    ];
+
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn right_join_not_null_filter() -> Result<()> {
+    let mut ctx = create_join_context_with_nulls()?;
+    let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t2_id";
+    let expected = vec![
+        "+-------+---------+-------+",
+        "| t1_id | t1_name | t2_id |",
+        "+-------+---------+-------+",
+        "| 11    | a       | 11    |",
+        "| 22    | b       | 22    |",
+        "| 44    | d       | 44    |",
+        "+-------+---------+-------+",
+    ];
+
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn right_join_not_null_filter_on_join_column() -> Result<()> {
+    let mut ctx = create_join_context_with_nulls()?;
+    let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_id IS NOT NULL ORDER BY t2_id";
+    let expected = vec![
+        "+-------+---------+-------+",
+        "| t1_id | t1_name | t2_id |",
+        "+-------+---------+-------+",
+        "| 11    | a       | 11    |",
+        "| 22    | b       | 22    |",
+        "| 44    | d       | 44    |",
+        "| 99    |         | 99    |",
+        "+-------+---------+-------+",
+    ];
+
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn full_join_null_filter() -> Result<()> {
+    let mut ctx = create_join_context_with_nulls()?;
+    let sql = "SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id WHERE t1_name IS NULL ORDER BY t1_id";
+    let expected = vec![
+        "+-------+---------+-------+",
+        "| t1_id | t1_name | t2_id |",
+        "+-------+---------+-------+",
+        "| 88    |         |       |",
+        "| 99    |         | 99    |",
+        "|       |         | 55    |",
+        "+-------+---------+-------+",
+    ];
+
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn full_join_not_null_filter() -> Result<()> {
+    let mut ctx = create_join_context_with_nulls()?;
+    let sql = "SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t1_id";
+    let expected = vec![
+        "+-------+---------+-------+",
+        "| t1_id | t1_name | t2_id |",
+        "+-------+---------+-------+",
+        "| 11    | a       | 11    |",
+        "| 22    | b       | 22    |",
+        "| 33    | c       |       |",
+        "| 44    | d       | 44    |",
+        "| 77    | e       |       |",
+        "+-------+---------+-------+",
+    ];
+
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
 async fn right_join() -> Result<()> {
     let mut ctx = create_join_context("t1_id", "t2_id")?;
     let equivalent_sql = [
diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs
index 46c33f7..55715af 100644
--- a/datafusion/tests/sql/mod.rs
+++ b/datafusion/tests/sql/mod.rs
@@ -297,6 +297,55 @@ fn create_join_context_unbalanced(
     Ok(ctx)
 }
 
+// Create memory tables with nulls
+fn create_join_context_with_nulls() -> Result<ExecutionContext> {
+    let mut ctx = ExecutionContext::new();
+
+    let t1_schema = Arc::new(Schema::new(vec![
+        Field::new("t1_id", DataType::UInt32, true),
+        Field::new("t1_name", DataType::Utf8, true),
+    ]));
+    let t1_data = RecordBatch::try_new(
+        t1_schema.clone(),
+        vec![
+            Arc::new(UInt32Array::from(vec![11, 22, 33, 44, 77, 88, 99])),
+            Arc::new(StringArray::from(vec![
+                Some("a"),
+                Some("b"),
+                Some("c"),
+                Some("d"),
+                Some("e"),
+                None,
+                None,
+            ])),
+        ],
+    )?;
+    let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?;
+    ctx.register_table("t1", Arc::new(t1_table))?;
+
+    let t2_schema = Arc::new(Schema::new(vec![
+        Field::new("t2_id", DataType::UInt32, true),
+        Field::new("t2_name", DataType::Utf8, true),
+    ]));
+    let t2_data = RecordBatch::try_new(
+        t2_schema.clone(),
+        vec![
+            Arc::new(UInt32Array::from(vec![11, 22, 44, 55, 99])),
+            Arc::new(StringArray::from(vec![
+                Some("z"),
+                None,
+                Some("x"),
+                Some("w"),
+                Some("u"),
+            ])),
+        ],
+    )?;
+    let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?;
+    ctx.register_table("t2", Arc::new(t2_table))?;
+
+    Ok(ctx)
+}
+
 fn get_tpch_table_schema(table: &str) -> Schema {
     match table {
         "customer" => Schema::new(vec![

Reply via email to