This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new c8de85e66 Add optimizer pass to reduce `left`/`right`/`full` joins to 
`inner` join if possible (#2750)
c8de85e66 is described below

commit c8de85e6623d70b6920a980b52c44644bfffc9c0
Author: AssHero <[email protected]>
AuthorDate: Tue Jun 28 07:39:24 2022 +0800

    Add optimizer pass to reduce `left`/`right`/`full` joins to `inner` join if 
possible (#2750)
    
    * try to reduce left/right/full join to inner join
    
    * split the test cases
    
    * Implementing reduce outer join as optimization rule
    
    * minor improvements
    
    * Update explain plan format
    
    * Update more plans
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/execution/context.rs      |   2 +
 datafusion/core/tests/sql/joins.rs            | 348 ++++++++++++++++++
 datafusion/optimizer/src/lib.rs               |   1 +
 datafusion/optimizer/src/reduce_outer_join.rs | 484 ++++++++++++++++++++++++++
 4 files changed, 835 insertions(+)

diff --git a/datafusion/core/src/execution/context.rs 
b/datafusion/core/src/execution/context.rs
index 16cd1adc2..1eb34c710 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -71,6 +71,7 @@ use crate::optimizer::filter_push_down::FilterPushDown;
 use crate::optimizer::limit_push_down::LimitPushDown;
 use crate::optimizer::optimizer::{OptimizerConfig, OptimizerRule};
 use crate::optimizer::projection_push_down::ProjectionPushDown;
+use crate::optimizer::reduce_outer_join::ReduceOuterJoin;
 use crate::optimizer::simplify_expressions::SimplifyExpressions;
 use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
 use crate::optimizer::subquery_filter_to_join::SubqueryFilterToJoin;
@@ -1243,6 +1244,7 @@ impl SessionState {
         if config.config_options.get_bool(OPT_FILTER_NULL_JOIN_KEYS) {
             rules.push(Arc::new(FilterNullJoinKeys::default()));
         }
+        rules.push(Arc::new(ReduceOuterJoin::new()));
         rules.push(Arc::new(FilterPushDown::new()));
         rules.push(Arc::new(LimitPushDown::new()));
         rules.push(Arc::new(SingleDistinctToGroupBy::new()));
diff --git a/datafusion/core/tests/sql/joins.rs 
b/datafusion/core/tests/sql/joins.rs
index e4876e638..d7d7a0cc6 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1375,3 +1375,351 @@ async fn hash_join_with_dictionary() -> Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn reduce_left_join_1() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id")?;
+
+    // reduce to inner join
+    let sql = "select * from t1 left join t2 on t1.t1_id = t2.t2_id where 
t2.t2_id < 100";
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .expect(&msg);
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id, 
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "    Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
+        "      Filter: #t1.t1_id < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N, 
t1_int:UInt32;N]",
+        "        TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      Filter: #t2.t2_id < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
+        "        TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+    ];
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+    let expected = vec![
+        "+-------+---------+--------+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+        "+-------+---------+--------+-------+---------+--------+",
+        "| 11    | a       | 1      | 11    | z       | 3      |",
+        "| 22    | b       | 2      | 22    | y       | 1      |",
+        "| 44    | d       | 4      | 44    | x       | 3      |",
+        "+-------+---------+--------+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn reduce_left_join_2() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id")?;
+
+    // reduce to inner join
+    let sql = "select * from t1 left join t2 on t1.t1_id = t2.t2_id where 
t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')";
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .expect(&msg);
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id, 
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "    Filter: #t2.t2_int < Int64(10) OR #t1.t1_int > Int64(2) AND 
#t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "      Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
+        "        TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "        TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+    ];
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+    let expected = vec![
+        "+-------+---------+--------+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+        "+-------+---------+--------+-------+---------+--------+",
+        "| 11    | a       | 1      | 11    | z       | 3      |",
+        "| 22    | b       | 2      | 22    | y       | 1      |",
+        "| 44    | d       | 4      | 44    | x       | 3      |",
+        "+-------+---------+--------+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn reduce_left_join_3() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id")?;
+
+    // reduce subquery to inner join
+    let sql = "select * from (select t1.* from t1 left join t2 on t1.t1_id = 
t2.t2_id where t2.t2_int < 3) t3 left join t2 on t3.t1_int = t2.t2_int where 
t3.t1_id < 100";
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .expect(&msg);
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #t3.t1_id, #t3.t1_name, #t3.t1_int, #t2.t2_id, 
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "    Left Join: #t3.t1_int = #t2.t2_int [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
+        "      Projection: #t3.t1_id, #t3.t1_name, #t3.t1_int, alias=t3 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "        Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, alias=t3 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "          Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
+        "            Filter: #t1.t1_id < Int64(100) [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+        "              TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "            Filter: #t2.t2_int < Int64(3) AND #t2.t2_id < Int64(100) 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "              TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "      TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+    ];
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+    let expected = vec![
+        "+-------+---------+--------+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+        "+-------+---------+--------+-------+---------+--------+",
+        "| 22    | b       | 2      |       |         |        |",
+        "+-------+---------+--------+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn reduce_right_join_1() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id")?;
+
+    // reduce to inner join
+    let sql = "select * from t1 right join t2 on t1.t1_id = t2.t2_id where 
t1.t1_int is not null";
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .expect(&msg);
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id, 
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "    Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
+        "      Filter: #t1.t1_int IS NOT NULL [t1_id:UInt32;N, t1_name:Utf8;N, 
t1_int:UInt32;N]",
+        "        TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+    ];
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+    let expected = vec![
+        "+-------+---------+--------+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+        "+-------+---------+--------+-------+---------+--------+",
+        "| 11    | a       | 1      | 11    | z       | 3      |",
+        "| 22    | b       | 2      | 22    | y       | 1      |",
+        "| 44    | d       | 4      | 44    | x       | 3      |",
+        "+-------+---------+--------+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn reduce_right_join_2() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id")?;
+
+    // reduce to inner join
+    let sql = "select * from t1 right join t2 on t1.t1_id = t2.t2_id where 
not(t1.t1_int = t2.t2_int)";
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .expect(&msg);
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id, 
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "    Filter: NOT #t1.t1_int = #t2.t2_int [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
+        "      Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
+        "        TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "        TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+    ];
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+    let expected = vec![
+        "+-------+---------+--------+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+        "+-------+---------+--------+-------+---------+--------+",
+        "| 11    | a       | 1      | 11    | z       | 3      |",
+        "| 22    | b       | 2      | 22    | y       | 1      |",
+        "| 44    | d       | 4      | 44    | x       | 3      |",
+        "+-------+---------+--------+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn reduce_full_join_to_right_join() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id")?;
+
+    // reduce to right join
+    let sql = "select * from t1 full join t2 on t1.t1_id = t2.t2_id where 
t2.t2_name is not null";
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .expect(&msg);
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id, 
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "    Right Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
+        "      TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      Filter: #t2.t2_name IS NOT NULL [t2_id:UInt32;N, 
t2_name:Utf8;N, t2_int:UInt32;N]",
+        "        TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+    ];
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+    let expected = vec![
+        "+-------+---------+--------+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+        "+-------+---------+--------+-------+---------+--------+",
+        "|       |         |        | 55    | w       | 3      |",
+        "| 11    | a       | 1      | 11    | z       | 3      |",
+        "| 22    | b       | 2      | 22    | y       | 1      |",
+        "| 44    | d       | 4      | 44    | x       | 3      |",
+        "+-------+---------+--------+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn reduce_full_join_to_left_join() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id")?;
+
+    // reduce to left join
+    let sql =
+        "select * from t1 full join t2 on t1.t1_id = t2.t2_id where t1.t1_name 
!= 'b'";
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .expect(&msg);
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id, 
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "    Left Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, 
t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "      Filter: #t1.t1_name != Utf8(\"b\") [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+        "        TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+    ];
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+    let expected = vec![
+        "+-------+---------+--------+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+        "+-------+---------+--------+-------+---------+--------+",
+        "| 11    | a       | 1      | 11    | z       | 3      |",
+        "| 33    | c       | 3      |       |         |        |",
+        "| 44    | d       | 4      | 44    | x       | 3      |",
+        "+-------+---------+--------+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+    Ok(())
+}
+
+#[tokio::test]
+async fn reduce_full_join_to_inner_join() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id")?;
+
+    // reduce to inner join
+    let sql = "select * from t1 full join t2 on t1.t1_id = t2.t2_id where 
t1.t1_name != 'b' and t2.t2_name = 'x'";
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx
+        .create_logical_plan(&("explain ".to_owned() + sql))
+        .expect(&msg);
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id, 
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "    Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
+        "      Filter: #t1.t1_name != Utf8(\"b\") [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+        "        TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      Filter: #t2.t2_name = Utf8(\"x\") [t2_id:UInt32;N, 
t2_name:Utf8;N, t2_int:UInt32;N]",
+        "        TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+    ];
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+    let expected = vec![
+        "+-------+---------+--------+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+        "+-------+---------+--------+-------+---------+--------+",
+        "| 44    | d       | 4      | 44    | x       | 3      |",
+        "+-------+---------+--------+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index c01ea3c63..a6b7cfcbb 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -24,6 +24,7 @@ pub mod filter_push_down;
 pub mod limit_push_down;
 pub mod optimizer;
 pub mod projection_push_down;
+pub mod reduce_outer_join;
 pub mod simplify_expressions;
 pub mod single_distinct_to_groupby;
 pub mod subquery_filter_to_join;
diff --git a/datafusion/optimizer/src/reduce_outer_join.rs 
b/datafusion/optimizer/src/reduce_outer_join.rs
new file mode 100644
index 000000000..fa6c95075
--- /dev/null
+++ b/datafusion/optimizer/src/reduce_outer_join.rs
@@ -0,0 +1,484 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Optimizer rule to reduce left/right/full join to inner join if possible.
+use crate::{OptimizerConfig, OptimizerRule};
+use datafusion_common::{Column, DFSchema, Result};
+use datafusion_expr::{
+    logical_plan::{Filter, Join, JoinType, LogicalPlan, Projection},
+    utils::from_plan,
+};
+use datafusion_expr::{Expr, Operator};
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+#[derive(Default)]
+pub struct ReduceOuterJoin;
+
+impl ReduceOuterJoin {
+    #[allow(missing_docs)]
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl OptimizerRule for ReduceOuterJoin {
+    fn optimize(
+        &self,
+        plan: &LogicalPlan,
+        optimizer_config: &OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        let mut nonnullable_cols: Vec<Column> = vec![];
+
+        reduce_outer_join(self, plan, &mut nonnullable_cols, optimizer_config)
+    }
+
+    fn name(&self) -> &str {
+        "reduce_outer_join"
+    }
+}
+
+/// Attempt to reduce outer joins to inner joins.
+/// For the query: select ... from a left join b on ... where b.xx = 100;
+/// if b.xx is null, b.xx = 100 returns false and those null rows are filtered out.
+/// Therefore, there is no need to produce null rows for output, and we can use
+/// an inner join instead of a left join.
+///
+/// Generally, an outer join can be reduced to an inner join if the quals from
+/// the where clause return false when any input is null and the columns of
+/// those quals come from the nullable side of the outer join.
+fn reduce_outer_join(
+    _optimizer: &ReduceOuterJoin,
+    plan: &LogicalPlan,
+    nonnullable_cols: &mut Vec<Column>,
+    _optimizer_config: &OptimizerConfig,
+) -> Result<LogicalPlan> {
+    match plan {
+        LogicalPlan::Filter(Filter { input, predicate }) => match &**input {
+            LogicalPlan::Join(join) => {
+                extract_nonnullable_columns(
+                    predicate,
+                    nonnullable_cols,
+                    join.left.schema(),
+                    join.right.schema(),
+                    true,
+                )?;
+                Ok(LogicalPlan::Filter(Filter {
+                    predicate: predicate.clone(),
+                    input: Arc::new(reduce_outer_join(
+                        _optimizer,
+                        input,
+                        nonnullable_cols,
+                        _optimizer_config,
+                    )?),
+                }))
+            }
+            _ => Ok(LogicalPlan::Filter(Filter {
+                predicate: predicate.clone(),
+                input: Arc::new(reduce_outer_join(
+                    _optimizer,
+                    input,
+                    nonnullable_cols,
+                    _optimizer_config,
+                )?),
+            })),
+        },
+        LogicalPlan::Join(join) => {
+            let mut new_join_type = join.join_type;
+
+            if join.join_type == JoinType::Left
+                || join.join_type == JoinType::Right
+                || join.join_type == JoinType::Full
+            {
+                let mut left_nonnullable = false;
+                let mut right_nonnullable = false;
+                for col in nonnullable_cols.iter() {
+                    if join.left.schema().field_from_column(col).is_ok() {
+                        left_nonnullable = true;
+                    }
+                    if join.right.schema().field_from_column(col).is_ok() {
+                        right_nonnullable = true;
+                    }
+                }
+
+                match join.join_type {
+                    JoinType::Left => {
+                        if right_nonnullable {
+                            new_join_type = JoinType::Inner;
+                        }
+                    }
+                    JoinType::Right => {
+                        if left_nonnullable {
+                            new_join_type = JoinType::Inner;
+                        }
+                    }
+                    JoinType::Full => {
+                        if left_nonnullable && right_nonnullable {
+                            new_join_type = JoinType::Inner;
+                        } else if left_nonnullable {
+                            new_join_type = JoinType::Left;
+                        } else if right_nonnullable {
+                            new_join_type = JoinType::Right;
+                        }
+                    }
+                    _ => {}
+                };
+            }
+
+            let left_plan = reduce_outer_join(
+                _optimizer,
+                &join.left,
+                &mut nonnullable_cols.clone(),
+                _optimizer_config,
+            )?;
+            let right_plan = reduce_outer_join(
+                _optimizer,
+                &join.right,
+                &mut nonnullable_cols.clone(),
+                _optimizer_config,
+            )?;
+
+            Ok(LogicalPlan::Join(Join {
+                left: Arc::new(left_plan),
+                right: Arc::new(right_plan),
+                join_type: new_join_type,
+                join_constraint: join.join_constraint,
+                on: join.on.clone(),
+                filter: join.filter.clone(),
+                schema: join.schema.clone(),
+                null_equals_null: join.null_equals_null,
+            }))
+        }
+        LogicalPlan::Projection(Projection {
+            input,
+            expr,
+            schema,
+            alias: _,
+        }) => {
+            let projection = schema
+                .fields()
+                .iter()
+                .enumerate()
+                .map(|(i, field)| {
+                    // strip alias, as they should not be part of filters
+                    let expr = match &expr[i] {
+                        Expr::Alias(expr, _) => expr.as_ref().clone(),
+                        expr => expr.clone(),
+                    };
+
+                    (field.qualified_name(), expr)
+                })
+                .collect::<HashMap<_, _>>();
+
+            // re-write all Columns based on this projection
+            for col in nonnullable_cols.iter_mut() {
+                if let Some(Expr::Column(column)) = 
projection.get(&col.flat_name()) {
+                    *col = column.clone();
+                }
+            }
+
+            // optimize inner
+            let new_input = reduce_outer_join(
+                _optimizer,
+                input,
+                nonnullable_cols,
+                _optimizer_config,
+            )?;
+
+            from_plan(plan, expr, &[new_input])
+        }
+        _ => {
+            let expr = plan.expressions();
+
+            // apply the optimization to all inputs of the plan
+            let inputs = plan.inputs();
+            let new_inputs = inputs
+                .iter()
+                .map(|plan| {
+                    reduce_outer_join(
+                        _optimizer,
+                        plan,
+                        nonnullable_cols,
+                        _optimizer_config,
+                    )
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            from_plan(plan, &expr, &new_inputs)
+        }
+    }
+}
+
+/// Recursively traverses expr; if expr returns false when
+/// any inputs are null, treats columns of both sides as nonnullable columns.
+///
+/// For and/or exprs, extracts from all sub exprs and merges the columns.
+/// For an or expr, if one of the sub exprs returns true, discards all columns
+/// from the or expr.
+/// For IS NOT NULL/NOT exprs, which always return false for NULL input,
+///     extracts columns from these exprs.
+/// For all other exprs, falls through.
+fn extract_nonnullable_columns(
+    expr: &Expr,
+    nonnullable_cols: &mut Vec<Column>,
+    left_schema: &Arc<DFSchema>,
+    right_schema: &Arc<DFSchema>,
+    top_level: bool,
+) -> Result<()> {
+    match expr {
+        Expr::Column(col) => {
+            nonnullable_cols.push(col.clone());
+            Ok(())
+        }
+        Expr::BinaryExpr { left, op, right } => match op {
+            // If one of the inputs is null for these operators, the result should be false.
+            Operator::Eq
+            | Operator::NotEq
+            | Operator::Lt
+            | Operator::LtEq
+            | Operator::Gt
+            | Operator::GtEq => {
+                extract_nonnullable_columns(
+                    left,
+                    nonnullable_cols,
+                    left_schema,
+                    right_schema,
+                    false,
+                )?;
+                extract_nonnullable_columns(
+                    right,
+                    nonnullable_cols,
+                    left_schema,
+                    right_schema,
+                    false,
+                )
+            }
+            Operator::And | Operator::Or => {
+                // treat And as Or if it is not at the top level, such as
+                // not (c1 < 10 and c2 > 100)
+                if top_level && *op == Operator::And {
+                    extract_nonnullable_columns(
+                        left,
+                        nonnullable_cols,
+                        left_schema,
+                        right_schema,
+                        top_level,
+                    )?;
+                    extract_nonnullable_columns(
+                        right,
+                        nonnullable_cols,
+                        left_schema,
+                        right_schema,
+                        top_level,
+                    )?;
+                    return Ok(());
+                }
+
+                let mut left_nonnullable_cols: Vec<Column> = vec![];
+                let mut right_nonnullable_cols: Vec<Column> = vec![];
+
+                extract_nonnullable_columns(
+                    left,
+                    &mut left_nonnullable_cols,
+                    left_schema,
+                    right_schema,
+                    top_level,
+                )?;
+                extract_nonnullable_columns(
+                    right,
+                    &mut right_nonnullable_cols,
+                    left_schema,
+                    right_schema,
+                    top_level,
+                )?;
+
+                // for query: select *** from a left join b where b.c1 ... or 
b.c2 ...
+                // this can be reduced to inner join.
+                // for query: select *** from a left join b where a.c1 ... or 
b.c2 ...
+                // this can not be reduced.
+                // If columns of relation exist in both sub exprs, any columns 
of this relation
+                // can be added to non nullable columns.
+                if !left_nonnullable_cols.is_empty() && 
!right_nonnullable_cols.is_empty()
+                {
+                    for left_col in &left_nonnullable_cols {
+                        for right_col in &right_nonnullable_cols {
+                            if (left_schema.field_from_column(left_col).is_ok()
+                                && 
left_schema.field_from_column(right_col).is_ok())
+                                || 
(right_schema.field_from_column(left_col).is_ok()
+                                    && 
right_schema.field_from_column(right_col).is_ok())
+                            {
+                                nonnullable_cols.push(left_col.clone());
+                                break;
+                            }
+                        }
+                    }
+                }
+                Ok(())
+            }
+            _ => Ok(()),
+        },
+        Expr::Not(arg) => extract_nonnullable_columns(
+            arg,
+            nonnullable_cols,
+            left_schema,
+            right_schema,
+            false,
+        ),
+        Expr::IsNotNull(arg) => {
+            if !top_level {
+                return Ok(());
+            }
+            extract_nonnullable_columns(
+                arg,
+                nonnullable_cols,
+                left_schema,
+                right_schema,
+                false,
+            )
+        }
+        _ => Ok(()),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test::*;
+    use datafusion_expr::{
+        binary_expr, col, lit,
+        logical_plan::builder::LogicalPlanBuilder,
+        Operator::{And, Or},
+    };
+
+    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { // runs ReduceOuterJoin on `plan` and checks the result
+        let rule = ReduceOuterJoin::new();
+        let optimized_plan = rule
+            .optimize(plan, &OptimizerConfig::new())
+            .expect("failed to optimize plan");
+        let formatted_plan = format!("{:?}", optimized_plan); // plans are compared via their Debug rendering
+        assert_eq!(formatted_plan, expected);
+        assert_eq!(plan.schema(), optimized_plan.schema()); // the rewrite must leave the output schema unchanged
+    }
+
+    #[test]
+    fn reduce_left_with_null() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        // `t2.b IS NULL` filter: the plan is expected to keep the Left Join (not reduced to an inner join)
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                &t2,
+                JoinType::Left,
+                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
+                None,
+            )?
+            .filter(col("t2.b").is_null())?
+            .build()?;
+        let expected = "\
+        Filter: #t2.b IS NULL\
+        \n  Left Join: #t1.a = #t2.a\
+        \n    TableScan: t1\
+        \n    TableScan: t2";
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn reduce_left_with_not_null() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        // `t2.b IS NOT NULL` filter: the Left Join is expected to be reduced to an Inner Join
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                &t2,
+                JoinType::Left,
+                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
+                None,
+            )?
+            .filter(col("t2.b").is_not_null())?
+            .build()?;
+        let expected = "\
+        Filter: #t2.b IS NOT NULL\
+        \n  Inner Join: #t1.a = #t2.a\
+        \n    TableScan: t1\
+        \n    TableScan: t2";
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn reduce_right_with_or() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        // OR of two comparisons on t1 columns: the Right Join is expected to be reduced to an Inner Join
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                &t2,
+                JoinType::Right,
+                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
+                None,
+            )?
+            .filter(binary_expr(
+                col("t1.b").gt(lit(10u32)),
+                Or,
+                col("t1.c").lt(lit(20u32)),
+            ))?
+            .build()?;
+        let expected = "\
+        Filter: #t1.b > UInt32(10) OR #t1.c < UInt32(20)\
+        \n  Inner Join: #t1.a = #t2.a\
+        \n    TableScan: t1\
+        \n    TableScan: t2";
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn reduce_full_with_and() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        // AND of a t1 comparison and a t2 comparison: the Full Join is expected to be reduced to an Inner Join
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                &t2,
+                JoinType::Full,
+                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
+                None,
+            )?
+            .filter(binary_expr(
+                col("t1.b").gt(lit(10u32)),
+                And,
+                col("t2.c").lt(lit(20u32)),
+            ))?
+            .build()?;
+        let expected = "\
+        Filter: #t1.b > UInt32(10) AND #t2.c < UInt32(20)\
+        \n  Inner Join: #t1.a = #t2.a\
+        \n    TableScan: t1\
+        \n    TableScan: t2";
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
+}

Reply via email to