This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 89e71ef5fd [Optimization] Infer predicate under all JoinTypes (#13081)
89e71ef5fd is described below

commit 89e71ef5fd058278f1a0cc659dccf1757def3a62
Author: JasonLi <[email protected]>
AuthorDate: Tue Oct 29 19:31:44 2024 +0800

    [Optimization] Infer predicate under all JoinTypes (#13081)
    
    * optimize infer join predicate
    
    * pass clippy
    
    * chores: remove unnecessary debug code
---
 datafusion/optimizer/src/push_down_filter.rs | 393 +++++++++++++++++++++++----
 datafusion/optimizer/src/utils.rs            | 171 +++++++++++-
 2 files changed, 508 insertions(+), 56 deletions(-)

diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index f8e614a0aa..a0262d7d95 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -36,7 +36,7 @@ use datafusion_expr::{
 };
 
 use crate::optimizer::ApplyOrder;
-use crate::utils::has_all_column_refs;
+use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
 use crate::{OptimizerConfig, OptimizerRule};
 
 /// Optimizer rule for pushing (moving) filter expressions down in a plan so
@@ -558,10 +558,6 @@ fn infer_join_predicates(
     predicates: &[Expr],
     on_filters: &[Expr],
 ) -> Result<Vec<Expr>> {
-    if join.join_type != JoinType::Inner {
-        return Ok(vec![]);
-    }
-
     // Only allow both side key is column.
     let join_col_keys = join
         .on
@@ -573,55 +569,176 @@ fn infer_join_predicates(
         })
         .collect::<Vec<_>>();
 
-    // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down
-    // For inner joins, duplicate filters for joined columns so filters can be pushed down
-    // to both sides. Take the following query as an example:
-    //
-    // ```sql
-    // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
-    // ```
-    //
-    // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
-    // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
-    //
-    // Join clauses with `Using` constraints also take advantage of this logic to make sure
-    // predicates reference the shared join columns are pushed to both sides.
-    // This logic should also been applied to conditions in JOIN ON clause
-    predicates
-        .iter()
-        .chain(on_filters.iter())
-        .filter_map(|predicate| {
-            let mut join_cols_to_replace = HashMap::new();
-
-            let columns = predicate.column_refs();
-
-            for &col in columns.iter() {
-                for (l, r) in join_col_keys.iter() {
-                    if col == *l {
-                        join_cols_to_replace.insert(col, *r);
-                        break;
-                    } else if col == *r {
-                        join_cols_to_replace.insert(col, *l);
-                        break;
-                    }
-                }
-            }
+    let join_type = join.join_type;
 
-            if join_cols_to_replace.is_empty() {
-                return None;
-            }
+    let mut inferred_predicates = InferredPredicates::new(join_type);
 
-            let join_side_predicate =
-                match replace_col(predicate.clone(), &join_cols_to_replace) {
-                    Ok(p) => p,
-                    Err(e) => {
-                        return Some(Err(e));
-                    }
-                };
+    infer_join_predicates_from_predicates(
+        &join_col_keys,
+        predicates,
+        &mut inferred_predicates,
+    )?;
 
-            Some(Ok(join_side_predicate))
-        })
-        .collect::<Result<Vec<_>>>()
+    infer_join_predicates_from_on_filters(
+        &join_col_keys,
+        join_type,
+        on_filters,
+        &mut inferred_predicates,
+    )?;
+
+    Ok(inferred_predicates.predicates)
+}
+
+/// Inferred predicates collector.
+/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly
+/// filter out NULL, otherwise ignore it. e.g.
+/// ```text
+/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL;
+/// ```
+/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to
+/// the left side, resulting in the wrong result.
+struct InferredPredicates {
+    predicates: Vec<Expr>,
+    is_inner_join: bool,
+}
+
+impl InferredPredicates {
+    fn new(join_type: JoinType) -> Self {
+        Self {
+            predicates: vec![],
+            is_inner_join: matches!(join_type, JoinType::Inner),
+        }
+    }
+
+    fn try_build_predicate(
+        &mut self,
+        predicate: Expr,
+        replace_map: &HashMap<&Column, &Column>,
+    ) -> Result<()> {
+        if self.is_inner_join
+            || matches!(
+                is_restrict_null_predicate(
+                    predicate.clone(),
+                    replace_map.keys().cloned()
+                ),
+                Ok(true)
+            )
+        {
+            self.predicates.push(replace_col(predicate, replace_map)?);
+        }
+
+        Ok(())
+    }
+}
+
+/// Infer predicates from the pushed down predicates.
+///
+/// Parameters
+/// * `join_col_keys` column pairs from the join ON clause
+///
+/// * `predicates` the pushed down predicates
+///
+/// * `inferred_predicates` the inferred results
+///
+fn infer_join_predicates_from_predicates(
+    join_col_keys: &[(&Column, &Column)],
+    predicates: &[Expr],
+    inferred_predicates: &mut InferredPredicates,
+) -> Result<()> {
+    infer_join_predicates_impl::<true, true>(
+        join_col_keys,
+        predicates,
+        inferred_predicates,
+    )
+}
+
+/// Infer predicates from the join filter.
+///
+/// Parameters
+/// * `join_col_keys` column pairs from the join ON clause
+///
+/// * `join_type` the JoinType of Join
+///
+/// * `on_filters` filters from the join ON clause that have not already been
+///   identified as join predicates
+///
+/// * `inferred_predicates` the inferred results
+///
+fn infer_join_predicates_from_on_filters(
+    join_col_keys: &[(&Column, &Column)],
+    join_type: JoinType,
+    on_filters: &[Expr],
+    inferred_predicates: &mut InferredPredicates,
+) -> Result<()> {
+    match join_type {
+        JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
+        JoinType::Inner => infer_join_predicates_impl::<true, true>(
+            join_col_keys,
+            on_filters,
+            inferred_predicates,
+        ),
+        JoinType::Left | JoinType::LeftSemi => infer_join_predicates_impl::<true, false>(
+            join_col_keys,
+            on_filters,
+            inferred_predicates,
+        ),
+        JoinType::Right | JoinType::RightSemi => {
+            infer_join_predicates_impl::<false, true>(
+                join_col_keys,
+                on_filters,
+                inferred_predicates,
+            )
+        }
+    }
+}
+
+/// Infer predicates from the given predicates.
+///
+/// Parameters
+/// * `join_col_keys` column pairs from the join ON clause
+///
+/// * `input_predicates` the given predicates. It can be the pushed down predicates,
+///   or it can be the filters of the Join
+///
+/// * `inferred_predicates` the inferred results
+///
+/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can
+///   be inferred from the left table related predicate
+///
+/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can
+///   be inferred from the right table related predicate
+///
+fn infer_join_predicates_impl<
+    const ENABLE_LEFT_TO_RIGHT: bool,
+    const ENABLE_RIGHT_TO_LEFT: bool,
+>(
+    join_col_keys: &[(&Column, &Column)],
+    input_predicates: &[Expr],
+    inferred_predicates: &mut InferredPredicates,
+) -> Result<()> {
+    for predicate in input_predicates {
+        let mut join_cols_to_replace = HashMap::new();
+
+        for &col in &predicate.column_refs() {
+            for (l, r) in join_col_keys.iter() {
+                if ENABLE_LEFT_TO_RIGHT && col == *l {
+                    join_cols_to_replace.insert(col, *r);
+                    break;
+                }
+                if ENABLE_RIGHT_TO_LEFT && col == *r {
+                    join_cols_to_replace.insert(col, *l);
+                    break;
+                }
+            }
+        }
+        if join_cols_to_replace.is_empty() {
+            continue;
+        }
+
+        inferred_predicates
+            .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
+    }
+    Ok(())
 }
 
 impl OptimizerRule for PushDownFilter {
@@ -1992,7 +2109,7 @@ mod tests {
         let expected = "\
         Filter: test2.a <= Int64(1)\
         \n  Left Join: Using test.a = test2.a\
-        \n    TableScan: test\
+        \n    TableScan: test, full_filters=[test.a <= Int64(1)]\
         \n    Projection: test2.a\
         \n      TableScan: test2";
         assert_optimized_plan_eq(plan, expected)
@@ -2032,7 +2149,7 @@ mod tests {
         \n  Right Join: Using test.a = test2.a\
         \n    TableScan: test\
         \n    Projection: test2.a\
-        \n      TableScan: test2";
+        \n      TableScan: test2, full_filters=[test2.a <= Int64(1)]";
         assert_optimized_plan_eq(plan, expected)
     }
 
@@ -2814,6 +2931,46 @@ Projection: a, b
         assert_optimized_plan_eq(optimized_plan, expected)
     }
 
+    #[test]
+    fn left_semi_join() -> Result<()> {
+        let left = test_table_scan_with_name("test1")?;
+        let right_table_scan = test_table_scan_with_name("test2")?;
+        let right = LogicalPlanBuilder::from(right_table_scan)
+            .project(vec![col("a"), col("b")])?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(left)
+            .join(
+                right,
+                JoinType::LeftSemi,
+                (
+                    vec![Column::from_qualified_name("test1.a")],
+                    vec![Column::from_qualified_name("test2.a")],
+                ),
+                None,
+            )?
+            .filter(col("test2.a").lt_eq(lit(1i64)))?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{plan}"),
+            "Filter: test2.a <= Int64(1)\
+            \n  LeftSemi Join: test1.a = test2.a\
+            \n    TableScan: test1\
+            \n    Projection: test2.a, test2.b\
+            \n      TableScan: test2"
+        );
+
+        // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side.
+        let expected = "\
+        Filter: test2.a <= Int64(1)\
+        \n  LeftSemi Join: test1.a = test2.a\
+        \n    TableScan: test1, full_filters=[test1.a <= Int64(1)]\
+        \n    Projection: test2.a, test2.b\
+        \n      TableScan: test2";
+        assert_optimized_plan_eq(plan, expected)
+    }
+
     #[test]
     fn left_semi_join_with_filters() -> Result<()> {
         let left = test_table_scan_with_name("test1")?;
@@ -2855,6 +3012,46 @@ Projection: a, b
         assert_optimized_plan_eq(plan, expected)
     }
 
+    #[test]
+    fn right_semi_join() -> Result<()> {
+        let left = test_table_scan_with_name("test1")?;
+        let right_table_scan = test_table_scan_with_name("test2")?;
+        let right = LogicalPlanBuilder::from(right_table_scan)
+            .project(vec![col("a"), col("b")])?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(left)
+            .join(
+                right,
+                JoinType::RightSemi,
+                (
+                    vec![Column::from_qualified_name("test1.a")],
+                    vec![Column::from_qualified_name("test2.a")],
+                ),
+                None,
+            )?
+            .filter(col("test1.a").lt_eq(lit(1i64)))?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{plan}"),
+            "Filter: test1.a <= Int64(1)\
+            \n  RightSemi Join: test1.a = test2.a\
+            \n    TableScan: test1\
+            \n    Projection: test2.a, test2.b\
+            \n      TableScan: test2",
+        );
+
+        // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side.
+        let expected = "\
+        Filter: test1.a <= Int64(1)\
+        \n  RightSemi Join: test1.a = test2.a\
+        \n    TableScan: test1\
+        \n    Projection: test2.a, test2.b\
+        \n      TableScan: test2, full_filters=[test2.a <= Int64(1)]";
+        assert_optimized_plan_eq(plan, expected)
+    }
+
     #[test]
     fn right_semi_join_with_filters() -> Result<()> {
         let left = test_table_scan_with_name("test1")?;
@@ -2896,6 +3093,51 @@ Projection: a, b
         assert_optimized_plan_eq(plan, expected)
     }
 
+    #[test]
+    fn left_anti_join() -> Result<()> {
+        let table_scan = test_table_scan_with_name("test1")?;
+        let left = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), col("b")])?
+            .build()?;
+        let right_table_scan = test_table_scan_with_name("test2")?;
+        let right = LogicalPlanBuilder::from(right_table_scan)
+            .project(vec![col("a"), col("b")])?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(left)
+            .join(
+                right,
+                JoinType::LeftAnti,
+                (
+                    vec![Column::from_qualified_name("test1.a")],
+                    vec![Column::from_qualified_name("test2.a")],
+                ),
+                None,
+            )?
+            .filter(col("test2.a").gt(lit(2u32)))?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{plan}"),
+            "Filter: test2.a > UInt32(2)\
+            \n  LeftAnti Join: test1.a = test2.a\
+            \n    Projection: test1.a, test1.b\
+            \n      TableScan: test1\
+            \n    Projection: test2.a, test2.b\
+            \n      TableScan: test2",
+        );
+
+        // For left anti, filter of the right side filter can be pushed down.
+        let expected = "\
+        Filter: test2.a > UInt32(2)\
+        \n  LeftAnti Join: test1.a = test2.a\
+        \n    Projection: test1.a, test1.b\
+        \n      TableScan: test1, full_filters=[test1.a > UInt32(2)]\
+        \n    Projection: test2.a, test2.b\
+        \n      TableScan: test2";
+        assert_optimized_plan_eq(plan, expected)
+    }
+
     #[test]
     fn left_anti_join_with_filters() -> Result<()> {
         let table_scan = test_table_scan_with_name("test1")?;
@@ -2942,6 +3184,51 @@ Projection: a, b
         assert_optimized_plan_eq(plan, expected)
     }
 
+    #[test]
+    fn right_anti_join() -> Result<()> {
+        let table_scan = test_table_scan_with_name("test1")?;
+        let left = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), col("b")])?
+            .build()?;
+        let right_table_scan = test_table_scan_with_name("test2")?;
+        let right = LogicalPlanBuilder::from(right_table_scan)
+            .project(vec![col("a"), col("b")])?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(left)
+            .join(
+                right,
+                JoinType::RightAnti,
+                (
+                    vec![Column::from_qualified_name("test1.a")],
+                    vec![Column::from_qualified_name("test2.a")],
+                ),
+                None,
+            )?
+            .filter(col("test1.a").gt(lit(2u32)))?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{plan}"),
+            "Filter: test1.a > UInt32(2)\
+             \n  RightAnti Join: test1.a = test2.a\
+             \n    Projection: test1.a, test1.b\
+             \n      TableScan: test1\
+             \n    Projection: test2.a, test2.b\
+             \n      TableScan: test2",
+        );
+
+        // For right anti, filter of the left side can be pushed down.
+        let expected = "\
+        Filter: test1.a > UInt32(2)\
+        \n  RightAnti Join: test1.a = test2.a\
+        \n    Projection: test1.a, test1.b\
+        \n      TableScan: test1\
+        \n    Projection: test2.a, test2.b\
+        \n      TableScan: test2, full_filters=[test2.a > UInt32(2)]";
+        assert_optimized_plan_eq(plan, expected)
+    }
+
     #[test]
     fn right_anti_join_with_filters() -> Result<()> {
         let table_scan = test_table_scan_with_name("test1")?;
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs
index 6972c16c0d..9f325bc01b 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -21,11 +21,18 @@ use std::collections::{BTreeSet, HashMap, HashSet};
 
 use crate::{OptimizerConfig, OptimizerRule};
 
-use datafusion_common::{Column, DFSchema, Result};
+use crate::analyzer::type_coercion::TypeCoercionRewriter;
+use arrow::array::{new_null_array, Array, RecordBatch};
+use arrow::datatypes::{DataType, Field, Schema};
+use datafusion_common::cast::as_boolean_array;
+use datafusion_common::tree_node::{TransformedResult, TreeNode};
+use datafusion_common::{Column, DFSchema, Result, ScalarValue};
+use datafusion_expr::execution_props::ExecutionProps;
 use datafusion_expr::expr_rewriter::replace_col;
-use datafusion_expr::{logical_plan::LogicalPlan, Expr};
-
+use datafusion_expr::{logical_plan::LogicalPlan, ColumnarValue, Expr};
+use datafusion_physical_expr::create_physical_expr;
 use log::{debug, trace};
+use std::sync::Arc;
 
 /// Re-export of `NamesPreserver` for backwards compatibility,
 /// as it was initially placed here and then moved elsewhere.
@@ -117,3 +124,161 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) {
     debug!("{description}:\n{}\n", plan.display_indent());
     trace!("{description}::\n{}\n", plan.display_indent_schema());
 }
+
+/// Determine whether a predicate can restrict NULLs. e.g.
+/// `c0 > 8` return true;
+/// `c0 IS NULL` return false.
+pub fn is_restrict_null_predicate<'a>(
+    predicate: Expr,
+    join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
+) -> Result<bool> {
+    if matches!(predicate, Expr::Column(_)) {
+        return Ok(true);
+    }
+
+    static DUMMY_COL_NAME: &str = "?";
+    let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]);
+    let input_schema = DFSchema::try_from(schema.clone())?;
+    let column = new_null_array(&DataType::Null, 1);
+    let input_batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
+    let execution_props = ExecutionProps::default();
+    let null_column = Column::from_name(DUMMY_COL_NAME);
+
+    let join_cols_to_replace = join_cols_of_predicate
+        .into_iter()
+        .map(|column| (column, &null_column))
+        .collect::<HashMap<_, _>>();
+
+    let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
+    let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
+    let phys_expr =
+        create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?;
+
+    let result_type = phys_expr.data_type(&schema)?;
+    if !matches!(&result_type, DataType::Boolean) {
+        return Ok(false);
+    }
+
+    // If result is single `true`, return false;
+    // If result is single `NULL` or `false`, return true;
+    Ok(match phys_expr.evaluate(&input_batch)? {
+        ColumnarValue::Array(array) => {
+            if array.len() == 1 {
+                let boolean_array = as_boolean_array(&array)?;
+                boolean_array.is_null(0) || !boolean_array.value(0)
+            } else {
+                false
+            }
+        }
+        ColumnarValue::Scalar(scalar) => matches!(
+            scalar,
+            ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
+        ),
+    })
+}
+
+fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
+    let mut expr_rewrite = TypeCoercionRewriter { schema };
+    expr.rewrite(&mut expr_rewrite).data()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion_expr::{binary_expr, case, col, in_list, is_null, lit, Operator};
+
+    #[test]
+    fn expr_is_restrict_null_predicate() -> Result<()> {
+        let test_cases = vec![
+            // a
+            (col("a"), true),
+            // a IS NULL
+            (is_null(col("a")), false),
+            // a IS NOT NULL
+            (Expr::IsNotNull(Box::new(col("a"))), true),
+            // a = NULL
+            (
+                binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)),
+                true,
+            ),
+            // a > 8
+            (binary_expr(col("a"), Operator::Gt, lit(8i64)), true),
+            // a <= 8
+            (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true),
+            // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END
+            (
+                case(col("a"))
+                    .when(lit(1i64), lit(true))
+                    .when(lit(0i64), lit(false))
+                    .otherwise(lit(ScalarValue::Null))?,
+                true,
+            ),
+            // CASE a WHEN 1 THEN true ELSE false END
+            (
+                case(col("a"))
+                    .when(lit(1i64), lit(true))
+                    .otherwise(lit(false))?,
+                true,
+            ),
+            // CASE a WHEN 0 THEN false ELSE true END
+            (
+                case(col("a"))
+                    .when(lit(0i64), lit(false))
+                    .otherwise(lit(true))?,
+                false,
+            ),
+            // (CASE a WHEN 0 THEN false ELSE true END) OR false
+            (
+                binary_expr(
+                    case(col("a"))
+                        .when(lit(0i64), lit(false))
+                        .otherwise(lit(true))?,
+                    Operator::Or,
+                    lit(false),
+                ),
+                false,
+            ),
+            // (CASE a WHEN 0 THEN true ELSE false END) OR false
+            (
+                binary_expr(
+                    case(col("a"))
+                        .when(lit(0i64), lit(true))
+                        .otherwise(lit(false))?,
+                    Operator::Or,
+                    lit(false),
+                ),
+                true,
+            ),
+            // a IN (1, 2, 3)
+            (
+                in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], false),
+                true,
+            ),
+            // a NOT IN (1, 2, 3)
+            (
+                in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], true),
+                true,
+            ),
+            // a IN (NULL)
+            (
+                in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], false),
+                true,
+            ),
+            // a NOT IN (NULL)
+            (
+                in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], true),
+                true,
+            ),
+        ];
+
+        let column_a = Column::from_name("a");
+        for (predicate, expected) in test_cases {
+            let join_cols_of_predicate = std::iter::once(&column_a);
+            let actual =
+                is_restrict_null_predicate(predicate.clone(), join_cols_of_predicate)?;
+            assert_eq!(actual, expected, "{}", predicate);
+        }
+
+        Ok(())
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to