This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 89e71ef5fd [Optimization] Infer predicate under all JoinTypes (#13081)
89e71ef5fd is described below
commit 89e71ef5fd058278f1a0cc659dccf1757def3a62
Author: JasonLi <[email protected]>
AuthorDate: Tue Oct 29 19:31:44 2024 +0800
[Optimization] Infer predicate under all JoinTypes (#13081)
* optimize infer join predicate
* pass clippy
* chores: remove unnecessary debug code
---
datafusion/optimizer/src/push_down_filter.rs | 393 +++++++++++++++++++++++----
datafusion/optimizer/src/utils.rs | 171 +++++++++++-
2 files changed, 508 insertions(+), 56 deletions(-)
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index f8e614a0aa..a0262d7d95 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -36,7 +36,7 @@ use datafusion_expr::{
};
use crate::optimizer::ApplyOrder;
-use crate::utils::has_all_column_refs;
+use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
use crate::{OptimizerConfig, OptimizerRule};
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
@@ -558,10 +558,6 @@ fn infer_join_predicates(
predicates: &[Expr],
on_filters: &[Expr],
) -> Result<Vec<Expr>> {
- if join.join_type != JoinType::Inner {
- return Ok(vec![]);
- }
-
// Only allow both side key is column.
let join_col_keys = join
.on
@@ -573,55 +569,176 @@ fn infer_join_predicates(
})
.collect::<Vec<_>>();
- // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down
- // For inner joins, duplicate filters for joined columns so filters can be pushed down
- // to both sides. Take the following query as an example:
- //
- // ```sql
- // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
- // ```
- //
- // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
- // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
- //
- // Join clauses with `Using` constraints also take advantage of this logic to make sure
- // predicates reference the shared join columns are pushed to both sides.
- // This logic should also been applied to conditions in JOIN ON clause
- predicates
- .iter()
- .chain(on_filters.iter())
- .filter_map(|predicate| {
- let mut join_cols_to_replace = HashMap::new();
-
- let columns = predicate.column_refs();
-
- for &col in columns.iter() {
- for (l, r) in join_col_keys.iter() {
- if col == *l {
- join_cols_to_replace.insert(col, *r);
- break;
- } else if col == *r {
- join_cols_to_replace.insert(col, *l);
- break;
- }
- }
- }
+ let join_type = join.join_type;
- if join_cols_to_replace.is_empty() {
- return None;
- }
+ let mut inferred_predicates = InferredPredicates::new(join_type);
- let join_side_predicate =
- match replace_col(predicate.clone(), &join_cols_to_replace) {
- Ok(p) => p,
- Err(e) => {
- return Some(Err(e));
- }
- };
+ infer_join_predicates_from_predicates(
+ &join_col_keys,
+ predicates,
+ &mut inferred_predicates,
+ )?;
- Some(Ok(join_side_predicate))
- })
- .collect::<Result<Vec<_>>>()
+ infer_join_predicates_from_on_filters(
+ &join_col_keys,
+ join_type,
+ on_filters,
+ &mut inferred_predicates,
+ )?;
+
+ Ok(inferred_predicates.predicates)
+}
+
+/// Inferred predicates collector.
+/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly
+/// filter out NULL, otherwise ignore it. e.g.
+/// ```text
+/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL;
+/// ```
+/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to
+/// the left side, resulting in the wrong result.
+struct InferredPredicates {
+ predicates: Vec<Expr>,
+ is_inner_join: bool,
+}
+
+impl InferredPredicates {
+ fn new(join_type: JoinType) -> Self {
+ Self {
+ predicates: vec![],
+ is_inner_join: matches!(join_type, JoinType::Inner),
+ }
+ }
+
+ fn try_build_predicate(
+ &mut self,
+ predicate: Expr,
+ replace_map: &HashMap<&Column, &Column>,
+ ) -> Result<()> {
+ if self.is_inner_join
+ || matches!(
+ is_restrict_null_predicate(
+ predicate.clone(),
+ replace_map.keys().cloned()
+ ),
+ Ok(true)
+ )
+ {
+ self.predicates.push(replace_col(predicate, replace_map)?);
+ }
+
+ Ok(())
+ }
+}
+
+/// Infer predicates from the pushed down predicates.
+///
+/// Parameters
+/// * `join_col_keys` column pairs from the join ON clause
+///
+/// * `predicates` the pushed down predicates
+///
+/// * `inferred_predicates` the inferred results
+///
+fn infer_join_predicates_from_predicates(
+ join_col_keys: &[(&Column, &Column)],
+ predicates: &[Expr],
+ inferred_predicates: &mut InferredPredicates,
+) -> Result<()> {
+ infer_join_predicates_impl::<true, true>(
+ join_col_keys,
+ predicates,
+ inferred_predicates,
+ )
+}
+
+/// Infer predicates from the join filter.
+///
+/// Parameters
+/// * `join_col_keys` column pairs from the join ON clause
+///
+/// * `join_type` the JoinType of Join
+///
+/// * `on_filters` filters from the join ON clause that have not already been
+/// identified as join predicates
+///
+/// * `inferred_predicates` the inferred results
+///
+fn infer_join_predicates_from_on_filters(
+ join_col_keys: &[(&Column, &Column)],
+ join_type: JoinType,
+ on_filters: &[Expr],
+ inferred_predicates: &mut InferredPredicates,
+) -> Result<()> {
+ match join_type {
+ JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
+ JoinType::Inner => infer_join_predicates_impl::<true, true>(
+ join_col_keys,
+ on_filters,
+ inferred_predicates,
+ ),
+ JoinType::Left | JoinType::LeftSemi => infer_join_predicates_impl::<true, false>(
+ join_col_keys,
+ on_filters,
+ inferred_predicates,
+ ),
+ JoinType::Right | JoinType::RightSemi => {
+ infer_join_predicates_impl::<false, true>(
+ join_col_keys,
+ on_filters,
+ inferred_predicates,
+ )
+ }
+ }
+}
+
+/// Infer predicates from the given predicates.
+///
+/// Parameters
+/// * `join_col_keys` column pairs from the join ON clause
+///
+/// * `input_predicates` the given predicates. It can be the pushed down predicates,
+/// or it can be the filters of the Join
+///
+/// * `inferred_predicates` the inferred results
+///
+/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can
+/// be inferred from the left table related predicate
+///
+/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can
+/// be inferred from the right table related predicate
+///
+fn infer_join_predicates_impl<
+ const ENABLE_LEFT_TO_RIGHT: bool,
+ const ENABLE_RIGHT_TO_LEFT: bool,
+>(
+ join_col_keys: &[(&Column, &Column)],
+ input_predicates: &[Expr],
+ inferred_predicates: &mut InferredPredicates,
+) -> Result<()> {
+ for predicate in input_predicates {
+ let mut join_cols_to_replace = HashMap::new();
+
+ for &col in &predicate.column_refs() {
+ for (l, r) in join_col_keys.iter() {
+ if ENABLE_LEFT_TO_RIGHT && col == *l {
+ join_cols_to_replace.insert(col, *r);
+ break;
+ }
+ if ENABLE_RIGHT_TO_LEFT && col == *r {
+ join_cols_to_replace.insert(col, *l);
+ break;
+ }
+ }
+ }
+ if join_cols_to_replace.is_empty() {
+ continue;
+ }
+
+ inferred_predicates
+ .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
+ }
+ Ok(())
}
impl OptimizerRule for PushDownFilter {
@@ -1992,7 +2109,7 @@ mod tests {
let expected = "\
Filter: test2.a <= Int64(1)\
\n Left Join: Using test.a = test2.a\
- \n TableScan: test\
+ \n TableScan: test, full_filters=[test.a <= Int64(1)]\
\n Projection: test2.a\
\n TableScan: test2";
assert_optimized_plan_eq(plan, expected)
@@ -2032,7 +2149,7 @@ mod tests {
\n Right Join: Using test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
- \n TableScan: test2";
+ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]";
assert_optimized_plan_eq(plan, expected)
}
@@ -2814,6 +2931,46 @@ Projection: a, b
assert_optimized_plan_eq(optimized_plan, expected)
}
+ #[test]
+ fn left_semi_join() -> Result<()> {
+ let left = test_table_scan_with_name("test1")?;
+ let right_table_scan = test_table_scan_with_name("test2")?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
+ .project(vec![col("a"), col("b")])?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(left)
+ .join(
+ right,
+ JoinType::LeftSemi,
+ (
+ vec![Column::from_qualified_name("test1.a")],
+ vec![Column::from_qualified_name("test2.a")],
+ ),
+ None,
+ )?
+ .filter(col("test2.a").lt_eq(lit(1i64)))?
+ .build()?;
+
+ // not part of the test, just good to know:
+ assert_eq!(
+ format!("{plan}"),
+ "Filter: test2.a <= Int64(1)\
+ \n LeftSemi Join: test1.a = test2.a\
+ \n TableScan: test1\
+ \n Projection: test2.a, test2.b\
+ \n TableScan: test2"
+ );
+
+ // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side.
+ let expected = "\
+ Filter: test2.a <= Int64(1)\
+ \n LeftSemi Join: test1.a = test2.a\
+ \n TableScan: test1, full_filters=[test1.a <= Int64(1)]\
+ \n Projection: test2.a, test2.b\
+ \n TableScan: test2";
+ assert_optimized_plan_eq(plan, expected)
+ }
+
#[test]
fn left_semi_join_with_filters() -> Result<()> {
let left = test_table_scan_with_name("test1")?;
@@ -2855,6 +3012,46 @@ Projection: a, b
assert_optimized_plan_eq(plan, expected)
}
+ #[test]
+ fn right_semi_join() -> Result<()> {
+ let left = test_table_scan_with_name("test1")?;
+ let right_table_scan = test_table_scan_with_name("test2")?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
+ .project(vec![col("a"), col("b")])?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(left)
+ .join(
+ right,
+ JoinType::RightSemi,
+ (
+ vec![Column::from_qualified_name("test1.a")],
+ vec![Column::from_qualified_name("test2.a")],
+ ),
+ None,
+ )?
+ .filter(col("test1.a").lt_eq(lit(1i64)))?
+ .build()?;
+
+ // not part of the test, just good to know:
+ assert_eq!(
+ format!("{plan}"),
+ "Filter: test1.a <= Int64(1)\
+ \n RightSemi Join: test1.a = test2.a\
+ \n TableScan: test1\
+ \n Projection: test2.a, test2.b\
+ \n TableScan: test2",
+ );
+
+ // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side.
+ let expected = "\
+ Filter: test1.a <= Int64(1)\
+ \n RightSemi Join: test1.a = test2.a\
+ \n TableScan: test1\
+ \n Projection: test2.a, test2.b\
+ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]";
+ assert_optimized_plan_eq(plan, expected)
+ }
+
#[test]
fn right_semi_join_with_filters() -> Result<()> {
let left = test_table_scan_with_name("test1")?;
@@ -2896,6 +3093,51 @@ Projection: a, b
assert_optimized_plan_eq(plan, expected)
}
+ #[test]
+ fn left_anti_join() -> Result<()> {
+ let table_scan = test_table_scan_with_name("test1")?;
+ let left = LogicalPlanBuilder::from(table_scan)
+ .project(vec![col("a"), col("b")])?
+ .build()?;
+ let right_table_scan = test_table_scan_with_name("test2")?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
+ .project(vec![col("a"), col("b")])?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(left)
+ .join(
+ right,
+ JoinType::LeftAnti,
+ (
+ vec![Column::from_qualified_name("test1.a")],
+ vec![Column::from_qualified_name("test2.a")],
+ ),
+ None,
+ )?
+ .filter(col("test2.a").gt(lit(2u32)))?
+ .build()?;
+
+ // not part of the test, just good to know:
+ assert_eq!(
+ format!("{plan}"),
+ "Filter: test2.a > UInt32(2)\
+ \n LeftAnti Join: test1.a = test2.a\
+ \n Projection: test1.a, test1.b\
+ \n TableScan: test1\
+ \n Projection: test2.a, test2.b\
+ \n TableScan: test2",
+ );
+
+ // For left anti, filter of the right side filter can be pushed down.
+ let expected = "\
+ Filter: test2.a > UInt32(2)\
+ \n LeftAnti Join: test1.a = test2.a\
+ \n Projection: test1.a, test1.b\
+ \n TableScan: test1, full_filters=[test1.a > UInt32(2)]\
+ \n Projection: test2.a, test2.b\
+ \n TableScan: test2";
+ assert_optimized_plan_eq(plan, expected)
+ }
+
#[test]
fn left_anti_join_with_filters() -> Result<()> {
let table_scan = test_table_scan_with_name("test1")?;
@@ -2942,6 +3184,51 @@ Projection: a, b
assert_optimized_plan_eq(plan, expected)
}
+ #[test]
+ fn right_anti_join() -> Result<()> {
+ let table_scan = test_table_scan_with_name("test1")?;
+ let left = LogicalPlanBuilder::from(table_scan)
+ .project(vec![col("a"), col("b")])?
+ .build()?;
+ let right_table_scan = test_table_scan_with_name("test2")?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
+ .project(vec![col("a"), col("b")])?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(left)
+ .join(
+ right,
+ JoinType::RightAnti,
+ (
+ vec![Column::from_qualified_name("test1.a")],
+ vec![Column::from_qualified_name("test2.a")],
+ ),
+ None,
+ )?
+ .filter(col("test1.a").gt(lit(2u32)))?
+ .build()?;
+
+ // not part of the test, just good to know:
+ assert_eq!(
+ format!("{plan}"),
+ "Filter: test1.a > UInt32(2)\
+ \n RightAnti Join: test1.a = test2.a\
+ \n Projection: test1.a, test1.b\
+ \n TableScan: test1\
+ \n Projection: test2.a, test2.b\
+ \n TableScan: test2",
+ );
+
+ // For right anti, filter of the left side can be pushed down.
+ let expected = "\
+ Filter: test1.a > UInt32(2)\
+ \n RightAnti Join: test1.a = test2.a\
+ \n Projection: test1.a, test1.b\
+ \n TableScan: test1\
+ \n Projection: test2.a, test2.b\
+ \n TableScan: test2, full_filters=[test2.a > UInt32(2)]";
+ assert_optimized_plan_eq(plan, expected)
+ }
+
#[test]
fn right_anti_join_with_filters() -> Result<()> {
let table_scan = test_table_scan_with_name("test1")?;
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs
index 6972c16c0d..9f325bc01b 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -21,11 +21,18 @@ use std::collections::{BTreeSet, HashMap, HashSet};
use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::{Column, DFSchema, Result};
+use crate::analyzer::type_coercion::TypeCoercionRewriter;
+use arrow::array::{new_null_array, Array, RecordBatch};
+use arrow::datatypes::{DataType, Field, Schema};
+use datafusion_common::cast::as_boolean_array;
+use datafusion_common::tree_node::{TransformedResult, TreeNode};
+use datafusion_common::{Column, DFSchema, Result, ScalarValue};
+use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr_rewriter::replace_col;
-use datafusion_expr::{logical_plan::LogicalPlan, Expr};
-
+use datafusion_expr::{logical_plan::LogicalPlan, ColumnarValue, Expr};
+use datafusion_physical_expr::create_physical_expr;
use log::{debug, trace};
+use std::sync::Arc;
/// Re-export of `NamesPreserver` for backwards compatibility,
/// as it was initially placed here and then moved elsewhere.
@@ -117,3 +124,161 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) {
debug!("{description}:\n{}\n", plan.display_indent());
trace!("{description}::\n{}\n", plan.display_indent_schema());
}
+
+/// Determine whether a predicate can restrict NULLs. e.g.
+/// `c0 > 8` return true;
+/// `c0 IS NULL` return false.
+pub fn is_restrict_null_predicate<'a>(
+ predicate: Expr,
+ join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
+) -> Result<bool> {
+ if matches!(predicate, Expr::Column(_)) {
+ return Ok(true);
+ }
+
+ static DUMMY_COL_NAME: &str = "?";
+ let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]);
+ let input_schema = DFSchema::try_from(schema.clone())?;
+ let column = new_null_array(&DataType::Null, 1);
+ let input_batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
+ let execution_props = ExecutionProps::default();
+ let null_column = Column::from_name(DUMMY_COL_NAME);
+
+ let join_cols_to_replace = join_cols_of_predicate
+ .into_iter()
+ .map(|column| (column, &null_column))
+ .collect::<HashMap<_, _>>();
+
+ let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
+ let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
+ let phys_expr =
+ create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?;
+
+ let result_type = phys_expr.data_type(&schema)?;
+ if !matches!(&result_type, DataType::Boolean) {
+ return Ok(false);
+ }
+
+ // If result is single `true`, return false;
+ // If result is single `NULL` or `false`, return true;
+ Ok(match phys_expr.evaluate(&input_batch)? {
+ ColumnarValue::Array(array) => {
+ if array.len() == 1 {
+ let boolean_array = as_boolean_array(&array)?;
+ boolean_array.is_null(0) || !boolean_array.value(0)
+ } else {
+ false
+ }
+ }
+ ColumnarValue::Scalar(scalar) => matches!(
+ scalar,
+ ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
+ ),
+ })
+}
+
+fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
+ let mut expr_rewrite = TypeCoercionRewriter { schema };
+ expr.rewrite(&mut expr_rewrite).data()
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use datafusion_expr::{binary_expr, case, col, in_list, is_null, lit, Operator};
+
+ #[test]
+ fn expr_is_restrict_null_predicate() -> Result<()> {
+ let test_cases = vec![
+ // a
+ (col("a"), true),
+ // a IS NULL
+ (is_null(col("a")), false),
+ // a IS NOT NULL
+ (Expr::IsNotNull(Box::new(col("a"))), true),
+ // a = NULL
+ (
+ binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)),
+ true,
+ ),
+ // a > 8
+ (binary_expr(col("a"), Operator::Gt, lit(8i64)), true),
+ // a <= 8
+ (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true),
+ // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END
+ (
+ case(col("a"))
+ .when(lit(1i64), lit(true))
+ .when(lit(0i64), lit(false))
+ .otherwise(lit(ScalarValue::Null))?,
+ true,
+ ),
+ // CASE a WHEN 1 THEN true ELSE false END
+ (
+ case(col("a"))
+ .when(lit(1i64), lit(true))
+ .otherwise(lit(false))?,
+ true,
+ ),
+ // CASE a WHEN 0 THEN false ELSE true END
+ (
+ case(col("a"))
+ .when(lit(0i64), lit(false))
+ .otherwise(lit(true))?,
+ false,
+ ),
+ // (CASE a WHEN 0 THEN false ELSE true END) OR false
+ (
+ binary_expr(
+ case(col("a"))
+ .when(lit(0i64), lit(false))
+ .otherwise(lit(true))?,
+ Operator::Or,
+ lit(false),
+ ),
+ false,
+ ),
+ // (CASE a WHEN 0 THEN true ELSE false END) OR false
+ (
+ binary_expr(
+ case(col("a"))
+ .when(lit(0i64), lit(true))
+ .otherwise(lit(false))?,
+ Operator::Or,
+ lit(false),
+ ),
+ true,
+ ),
+ // a IN (1, 2, 3)
+ (
+ in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], false),
+ true,
+ ),
+ // a NOT IN (1, 2, 3)
+ (
+ in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], true),
+ true,
+ ),
+ // a IN (NULL)
+ (
+ in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], false),
+ true,
+ ),
+ // a NOT IN (NULL)
+ (
+ in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], true),
+ true,
+ ),
+ ];
+
+ let column_a = Column::from_name("a");
+ for (predicate, expected) in test_cases {
+ let join_cols_of_predicate = std::iter::once(&column_a);
+ let actual =
+ is_restrict_null_predicate(predicate.clone(), join_cols_of_predicate)?;
+ assert_eq!(actual, expected, "{}", predicate);
+ }
+
+ Ok(())
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]