This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 87169f06ab Stop copying LogicalPlan and Exprs in `PushDownFilter` 
(#10444)
87169f06ab is described below

commit 87169f06ab590f20bd03b1be504a2119ddca6d68
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu May 16 20:08:39 2024 -0400

    Stop copying LogicalPlan and Exprs in `PushDownFilter` (#10444)
---
 datafusion/expr/src/logical_plan/plan.rs     |  10 +
 datafusion/optimizer/src/push_down_filter.rs | 623 ++++++++++++++-------------
 2 files changed, 343 insertions(+), 290 deletions(-)

diff --git a/datafusion/expr/src/logical_plan/plan.rs 
b/datafusion/expr/src/logical_plan/plan.rs
index ddf075c2c2..4872e5acda 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -2407,6 +2407,16 @@ pub enum Distinct {
     On(DistinctOn),
 }
 
+impl Distinct {
+    /// Returns a reference to the node's input
+    pub fn input(&self) -> &Arc<LogicalPlan> {
+        match self {
+            Distinct::All(input) => input,
+            Distinct::On(DistinctOn { input, .. }) => input,
+        }
+    }
+}
+
 /// Removes duplicate rows from the input
 #[derive(Clone, PartialEq, Eq, Hash)]
 pub struct DistinctOn {
diff --git a/datafusion/optimizer/src/push_down_filter.rs 
b/datafusion/optimizer/src/push_down_filter.rs
index 57b38bd0d0..b684b54903 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -14,6 +14,7 @@
 
 //! [`PushDownFilter`] applies filters as early as possible
 
+use indexmap::IndexSet;
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 
@@ -23,10 +24,9 @@ use datafusion_common::tree_node::{
     Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
 };
 use datafusion_common::{
-    internal_err, plan_datafusion_err, qualified_name, Column, DFSchema, 
DFSchemaRef,
+    internal_err, plan_err, qualified_name, Column, DFSchema, DFSchemaRef,
     JoinConstraint, Result,
 };
-use datafusion_expr::expr::Alias;
 use datafusion_expr::expr_rewriter::replace_col;
 use datafusion_expr::logical_plan::tree_node::unwrap_arc;
 use datafusion_expr::logical_plan::{
@@ -131,7 +131,8 @@ use crate::{OptimizerConfig, OptimizerRule};
 #[derive(Default)]
 pub struct PushDownFilter {}
 
-/// For a given JOIN logical plan, determine whether each side of the join is 
preserved.
+/// For a given JOIN type, determine whether each side of the join is 
preserved.
+///
 /// We say a join side is preserved if the join returns all or a subset of the 
rows from
 /// the relevant side, such that each row of the output table directly maps to 
a row of
 /// the preserved input table. If a table is not preserved, it can provide 
extra null rows.
@@ -150,44 +151,33 @@ pub struct PushDownFilter {}
 /// non-preserved side it can be more tricky.
 ///
 /// Returns a tuple of booleans - (left_preserved, right_preserved).
-fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
-    match plan {
-        LogicalPlan::Join(Join { join_type, .. }) => match join_type {
-            JoinType::Inner => Ok((true, true)),
-            JoinType::Left => Ok((true, false)),
-            JoinType::Right => Ok((false, true)),
-            JoinType::Full => Ok((false, false)),
-            // No columns from the right side of the join can be referenced in 
output
-            // predicates for semi/anti joins, so whether we specify t/f 
doesn't matter.
-            JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)),
-            // No columns from the left side of the join can be referenced in 
output
-            // predicates for semi/anti joins, so whether we specify t/f 
doesn't matter.
-            JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)),
-        },
-        LogicalPlan::CrossJoin(_) => Ok((true, true)),
-        _ => internal_err!("lr_is_preserved only valid for JOIN nodes"),
+fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
+    match join_type {
+        JoinType::Inner => Ok((true, true)),
+        JoinType::Left => Ok((true, false)),
+        JoinType::Right => Ok((false, true)),
+        JoinType::Full => Ok((false, false)),
+        // No columns from the right side of the join can be referenced in 
output
+        // predicates for semi/anti joins, so whether we specify t/f doesn't 
matter.
+        JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)),
+        // No columns from the left side of the join can be referenced in 
output
+        // predicates for semi/anti joins, so whether we specify t/f doesn't 
matter.
+        JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)),
     }
 }
 
 /// For a given JOIN logical plan, determine whether each side of the join is 
preserved
 /// in terms on join filtering.
-///
 /// Predicates from join filter can only be pushed to preserved join side.
-fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
-    match plan {
-        LogicalPlan::Join(Join { join_type, .. }) => match join_type {
-            JoinType::Inner => Ok((true, true)),
-            JoinType::Left => Ok((false, true)),
-            JoinType::Right => Ok((true, false)),
-            JoinType::Full => Ok((false, false)),
-            JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)),
-            JoinType::LeftAnti => Ok((false, true)),
-            JoinType::RightAnti => Ok((true, false)),
-        },
-        LogicalPlan::CrossJoin(_) => {
-            internal_err!("on_lr_is_preserved cannot be applied to CROSSJOIN 
nodes")
-        }
-        _ => internal_err!("on_lr_is_preserved only valid for JOIN nodes"),
+fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
+    match join_type {
+        JoinType::Inner => Ok((true, true)),
+        JoinType::Left => Ok((false, true)),
+        JoinType::Right => Ok((true, false)),
+        JoinType::Full => Ok((false, false)),
+        JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)),
+        JoinType::LeftAnti => Ok((false, true)),
+        JoinType::RightAnti => Ok((true, false)),
     }
 }
 
@@ -400,23 +390,20 @@ fn extract_or_clause(expr: &Expr, schema_columns: 
&HashSet<Column>) -> Option<Ex
 /// push down join/cross-join
 fn push_down_all_join(
     predicates: Vec<Expr>,
-    infer_predicates: Vec<Expr>,
-    join_plan: &LogicalPlan,
-    left: &LogicalPlan,
-    right: &LogicalPlan,
+    inferred_join_predicates: Vec<Expr>,
+    mut join: Join,
     on_filter: Vec<Expr>,
-    is_inner_join: bool,
 ) -> Result<Transformed<LogicalPlan>> {
-    let on_filter_empty = on_filter.is_empty();
+    let is_inner_join = join.join_type == JoinType::Inner;
     // Get pushable predicates from current optimizer state
-    let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?;
+    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type)?;
 
     // The predicates can be divided to three categories:
     // 1) can push through join to its children(left or right)
     // 2) can be converted to join conditions if the join type is Inner
     // 3) should be kept as filter conditions
-    let left_schema = left.schema();
-    let right_schema = right.schema();
+    let left_schema = join.left.schema();
+    let right_schema = join.right.schema();
     let mut left_push = vec![];
     let mut right_push = vec![];
     let mut keep_predicates = vec![];
@@ -438,7 +425,7 @@ fn push_down_all_join(
     }
 
     // For infer predicates, if they can not push through join, just drop them
-    for predicate in infer_predicates {
+    for predicate in inferred_join_predicates {
         if left_preserved && can_pushdown_join_predicate(&predicate, 
left_schema)? {
             left_push.push(predicate);
         } else if right_preserved
@@ -449,7 +436,7 @@ fn push_down_all_join(
     }
 
     if !on_filter.is_empty() {
-        let (on_left_preserved, on_right_preserved) = 
on_lr_is_preserved(join_plan)?;
+        let (on_left_preserved, on_right_preserved) = 
on_lr_is_preserved(join.join_type)?;
         for on in on_filter {
             if on_left_preserved && can_pushdown_join_predicate(&on, 
left_schema)? {
                 left_push.push(on)
@@ -474,46 +461,29 @@ fn push_down_all_join(
         right_push.extend(extract_or_clauses_for_join(&join_conditions, 
right_schema));
     }
 
-    let left = match conjunction(left_push) {
-        Some(predicate) => {
-            LogicalPlan::Filter(Filter::try_new(predicate, 
Arc::new(left.clone()))?)
-        }
-        None => left.clone(),
-    };
-    let right = match conjunction(right_push) {
-        Some(predicate) => {
-            LogicalPlan::Filter(Filter::try_new(predicate, 
Arc::new(right.clone()))?)
-        }
-        None => right.clone(),
-    };
-    // Create a new Join with the new `left` and `right`
-    //
-    // expressions() output for Join is a vector consisting of
-    //   1. join keys - columns mentioned in ON clause
-    //   2. optional predicate - in case join filter is not empty,
-    //      it always will be the last element, otherwise result
-    //      vector will contain only join keys (without additional
-    //      element representing filter).
-    let mut exprs = join_plan.expressions();
-    if !on_filter_empty {
-        exprs.pop();
-    }
-    exprs.extend(join_conditions.into_iter().reduce(Expr::and));
-    let plan = join_plan.with_new_exprs(exprs, vec![left, right])?;
-
-    // wrap the join on the filter whose predicates must be kept
-    match conjunction(keep_predicates) {
-        Some(predicate) => {
-            let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?;
-            Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan)))
-        }
-        None => Ok(Transformed::no(plan)),
+    if let Some(predicate) = conjunction(left_push) {
+        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, 
join.left)?));
     }
+    if let Some(predicate) = conjunction(right_push) {
+        join.right =
+            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, 
join.right)?));
+    }
+
+    // Add any new join conditions as the non join predicates
+    join.filter = conjunction(join_conditions);
+
+    // wrap the join on the filter whose predicates must be kept, if any
+    let plan = LogicalPlan::Join(join);
+    let plan = if let Some(predicate) = conjunction(keep_predicates) {
+        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
+    } else {
+        plan
+    };
+    Ok(Transformed::yes(plan))
 }
 
 fn push_down_join(
-    plan: &LogicalPlan,
-    join: &Join,
+    join: Join,
     parent_predicate: Option<&Expr>,
 ) -> Result<Transformed<LogicalPlan>> {
     // Split the parent predicate into individual conjunctive parts.
@@ -526,93 +496,102 @@ fn push_down_join(
         .as_ref()
         .map_or_else(Vec::new, |filter| 
split_conjunction_owned(filter.clone()));
 
-    let mut is_inner_join = false;
-    let infer_predicates = if join.join_type == JoinType::Inner {
-        is_inner_join = true;
-
-        // Only allow both side key is column.
-        let join_col_keys = join
-            .on
-            .iter()
-            .filter_map(|(l, r)| {
-                let left_col = l.try_as_col().cloned()?;
-                let right_col = r.try_as_col().cloned()?;
-                Some((left_col, right_col))
-            })
-            .collect::<Vec<_>>();
-
-        // TODO refine the logic, introduce EquivalenceProperties to logical 
plan and infer additional filters to push down
-        // For inner joins, duplicate filters for joined columns so filters 
can be pushed down
-        // to both sides. Take the following query as an example:
-        //
-        // ```sql
-        // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
-        // ```
-        //
-        // `t1.id > 1` predicate needs to be pushed down to t1 table scan, 
while
-        // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
-        //
-        // Join clauses with `Using` constraints also take advantage of this 
logic to make sure
-        // predicates reference the shared join columns are pushed to both 
sides.
-        // This logic should also been applied to conditions in JOIN ON clause
-        predicates
-            .iter()
-            .chain(on_filters.iter())
-            .filter_map(|predicate| {
-                let mut join_cols_to_replace = HashMap::new();
-
-                let columns = match predicate.to_columns() {
-                    Ok(columns) => columns,
-                    Err(e) => return Some(Err(e)),
-                };
+    // Are there any new join predicates that can be inferred from the filter 
expressions?
+    let inferred_join_predicates =
+        infer_join_predicates(&join, &predicates, &on_filters)?;
 
-                for col in columns.iter() {
-                    for (l, r) in join_col_keys.iter() {
-                        if col == l {
-                            join_cols_to_replace.insert(col, r);
-                            break;
-                        } else if col == r {
-                            join_cols_to_replace.insert(col, l);
-                            break;
-                        }
-                    }
-                }
+    if on_filters.is_empty()
+        && predicates.is_empty()
+        && inferred_join_predicates.is_empty()
+    {
+        return Ok(Transformed::no(LogicalPlan::Join(join)));
+    }
 
-                if join_cols_to_replace.is_empty() {
-                    return None;
-                }
+    push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
+}
 
-                let join_side_predicate =
-                    match replace_col(predicate.clone(), 
&join_cols_to_replace) {
-                        Ok(p) => p,
-                        Err(e) => {
-                            return Some(Err(e));
-                        }
-                    };
+/// Extracts any equi-join predicates from the given filter expressions.
+///
+/// Parameters
+/// * `join` the join in question
+///
+/// * `predicates` the pushed down filter expression
+///
+/// * `on_filters` filters from the join ON clause that have not already been
+/// identified as join predicates
+///
+fn infer_join_predicates(
+    join: &Join,
+    predicates: &[Expr],
+    on_filters: &[Expr],
+) -> Result<Vec<Expr>> {
+    if join.join_type != JoinType::Inner {
+        return Ok(vec![]);
+    }
 
-                Some(Ok(join_side_predicate))
-            })
-            .collect::<Result<Vec<_>>>()?
-    } else {
-        vec![]
-    };
+    // Only consider join keys where both sides are plain columns.
+    let join_col_keys = join
+        .on
+        .iter()
+        .filter_map(|(l, r)| {
+            let left_col = l.try_as_col()?;
+            let right_col = r.try_as_col()?;
+            Some((left_col, right_col))
+        })
+        .collect::<Vec<_>>();
 
-    if on_filters.is_empty() && predicates.is_empty() && 
infer_predicates.is_empty() {
-        return Ok(Transformed::no(plan.clone()));
-    }
+    // TODO refine the logic, introduce EquivalenceProperties to logical plan 
and infer additional filters to push down
+    // For inner joins, duplicate filters for joined columns so filters can be 
pushed down
+    // to both sides. Take the following query as an example:
+    //
+    // ```sql
+    // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
+    // ```
+    //
+    // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
+    // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
+    //
+    // Join clauses with `Using` constraints also take advantage of this logic 
to make sure
+    // predicates referencing the shared join columns are pushed to both sides.
+    // This logic should also be applied to conditions in the JOIN ON clause.
+    predicates
+        .iter()
+        .chain(on_filters.iter())
+        .filter_map(|predicate| {
+            let mut join_cols_to_replace = HashMap::new();
+
+            let columns = match predicate.to_columns() {
+                Ok(columns) => columns,
+                Err(e) => return Some(Err(e)),
+            };
+
+            for col in columns.iter() {
+                for (l, r) in join_col_keys.iter() {
+                    if col == *l {
+                        join_cols_to_replace.insert(col, *r);
+                        break;
+                    } else if col == *r {
+                        join_cols_to_replace.insert(col, *l);
+                        break;
+                    }
+                }
+            }
 
-    match push_down_all_join(
-        predicates,
-        infer_predicates,
-        plan,
-        &join.left,
-        &join.right,
-        on_filters,
-        is_inner_join,
-    ) {
-        Ok(plan) => Ok(Transformed::yes(plan.data)),
-        Err(e) => Err(e),
-    }
+            if join_cols_to_replace.is_empty() {
+                return None;
+            }
+
+            let join_side_predicate =
+                match replace_col(predicate.clone(), &join_cols_to_replace) {
+                    Ok(p) => p,
+                    Err(e) => {
+                        return Some(Err(e));
+                    }
+                };
+
+            Some(Ok(join_side_predicate))
+        })
+        .collect::<Result<Vec<_>>>()
 }
 
 impl OptimizerRule for PushDownFilter {
@@ -641,46 +620,57 @@ impl OptimizerRule for PushDownFilter {
         plan: LogicalPlan,
         _config: &dyn OptimizerConfig,
     ) -> Result<Transformed<LogicalPlan>> {
-        let filter = match plan {
-            LogicalPlan::Filter(ref filter) => filter,
-            LogicalPlan::Join(ref join) => return push_down_join(&plan, join, 
None),
-            _ => return Ok(Transformed::no(plan)),
+        if let LogicalPlan::Join(join) = plan {
+            return push_down_join(join, None);
+        };
+
+        let plan_schema = plan.schema().clone();
+
+        let LogicalPlan::Filter(mut filter) = plan else {
+            return Ok(Transformed::no(plan));
         };
 
-        let child_plan = filter.input.as_ref();
-        let new_plan = match child_plan {
-            LogicalPlan::Filter(ref child_filter) => {
-                let parents_predicates = split_conjunction(&filter.predicate);
-                let set: HashSet<&&Expr> = parents_predicates.iter().collect();
+        match unwrap_arc(filter.input) {
+            LogicalPlan::Filter(child_filter) => {
+                let parents_predicates = 
split_conjunction_owned(filter.predicate);
 
+                // remove duplicated filters
+                let child_predicates = 
split_conjunction_owned(child_filter.predicate);
                 let new_predicates = parents_predicates
-                    .iter()
-                    .chain(
-                        split_conjunction(&child_filter.predicate)
-                            .iter()
-                            .filter(|e| !set.contains(e)),
-                    )
-                    .map(|e| (*e).clone())
+                    .into_iter()
+                    .chain(child_predicates)
+                    // use IndexSet to remove dupes while preserving predicate 
order
+                    .collect::<IndexSet<_>>()
+                    .into_iter()
                     .collect::<Vec<_>>();
-                let new_predicate = conjunction(new_predicates).ok_or_else(|| {
-                    plan_datafusion_err!("at least one expression exists")
-                })?;
+
+                let Some(new_predicate) = conjunction(new_predicates) else {
+                    return plan_err!("at least one expression exists");
+                };
                 let new_filter = LogicalPlan::Filter(Filter::try_new(
                     new_predicate,
-                    child_filter.input.clone(),
+                    child_filter.input,
                 )?);
-                self.rewrite(new_filter, _config)?.data
+                self.rewrite(new_filter, _config)
             }
-            LogicalPlan::Repartition(_)
-            | LogicalPlan::Distinct(_)
-            | LogicalPlan::Sort(_) => {
-                let new_filter = plan.with_new_exprs(
-                    plan.expressions(),
-                    vec![child_plan.inputs()[0].clone()],
-                )?;
-                child_plan.with_new_exprs(child_plan.expressions(), 
vec![new_filter])?
+            LogicalPlan::Repartition(repartition) => {
+                let new_filter =
+                    Filter::try_new(filter.predicate, 
repartition.input.clone())
+                        .map(LogicalPlan::Filter)?;
+                insert_below(LogicalPlan::Repartition(repartition), new_filter)
             }
-            LogicalPlan::SubqueryAlias(ref subquery_alias) => {
+            LogicalPlan::Distinct(distinct) => {
+                let new_filter =
+                    Filter::try_new(filter.predicate, distinct.input().clone())
+                        .map(LogicalPlan::Filter)?;
+                insert_below(LogicalPlan::Distinct(distinct), new_filter)
+            }
+            LogicalPlan::Sort(sort) => {
+                let new_filter = Filter::try_new(filter.predicate, 
sort.input.clone())
+                    .map(LogicalPlan::Filter)?;
+                insert_below(LogicalPlan::Sort(sort), new_filter)
+            }
+            LogicalPlan::SubqueryAlias(subquery_alias) => {
                 let mut replace_map = HashMap::new();
                 for (i, (qualifier, field)) in
                     subquery_alias.input.schema().iter().enumerate()
@@ -692,15 +682,15 @@ impl OptimizerRule for PushDownFilter {
                         Expr::Column(Column::new(qualifier.cloned(), 
field.name())),
                     );
                 }
-                let new_predicate =
-                    replace_cols_by_name(filter.predicate.clone(), 
&replace_map)?;
+                let new_predicate = replace_cols_by_name(filter.predicate, 
&replace_map)?;
+
                 let new_filter = LogicalPlan::Filter(Filter::try_new(
                     new_predicate,
                     subquery_alias.input.clone(),
                 )?);
-                child_plan.with_new_exprs(child_plan.expressions(), 
vec![new_filter])?
+                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), 
new_filter)
             }
-            LogicalPlan::Projection(ref projection) => {
+            LogicalPlan::Projection(projection) => {
                 // A projection is filter-commutable if it do not contain 
volatile predicates or contain volatile
                 // predicates that are not used in the filter. However, we 
should re-writes all predicate expressions.
                 // collect projection.
@@ -711,10 +701,7 @@ impl OptimizerRule for PushDownFilter {
                         .enumerate()
                         .map(|(i, (qualifier, field))| {
                             // strip alias, as they should not be part of 
filters
-                            let expr = match &projection.expr[i] {
-                                Expr::Alias(Alias { expr, .. }) => 
expr.as_ref().clone(),
-                                expr => expr.clone(),
-                            };
+                            let expr = projection.expr[i].clone().unalias();
 
                             (qualified_name(qualifier, field.name()), expr)
                         })
@@ -741,23 +728,24 @@ impl OptimizerRule for PushDownFilter {
                         )?);
 
                         match conjunction(keep_predicates) {
-                            None => child_plan.with_new_exprs(
-                                child_plan.expressions(),
-                                vec![new_filter],
-                            )?,
-                            Some(keep_predicate) => {
-                                let child_plan = child_plan.with_new_exprs(
-                                    child_plan.expressions(),
-                                    vec![new_filter],
-                                )?;
-                                LogicalPlan::Filter(Filter::try_new(
-                                    keep_predicate,
-                                    Arc::new(child_plan),
-                                )?)
-                            }
+                            None => insert_below(
+                                LogicalPlan::Projection(projection),
+                                new_filter,
+                            ),
+                            Some(keep_predicate) => insert_below(
+                                LogicalPlan::Projection(projection),
+                                new_filter,
+                            )?
+                            .map_data(|child_plan| {
+                                Filter::try_new(keep_predicate, 
Arc::new(child_plan))
+                                    .map(LogicalPlan::Filter)
+                            }),
                         }
                     }
-                    None => return Ok(Transformed::no(plan)),
+                    None => {
+                        filter.input = 
Arc::new(LogicalPlan::Projection(projection));
+                        Ok(Transformed::no(LogicalPlan::Filter(filter)))
+                    }
                 }
             }
             LogicalPlan::Union(ref union) => {
@@ -780,12 +768,12 @@ impl OptimizerRule for PushDownFilter {
                         input.clone(),
                     )?)))
                 }
-                LogicalPlan::Union(Union {
+                Ok(Transformed::yes(LogicalPlan::Union(Union {
                     inputs,
-                    schema: plan.schema().clone(),
-                })
+                    schema: plan_schema.clone(),
+                })))
             }
-            LogicalPlan::Aggregate(ref agg) => {
+            LogicalPlan::Aggregate(agg) => {
                 // We can push down Predicate which in groupby_expr.
                 let group_expr_columns = agg
                     .group_expr
@@ -818,49 +806,33 @@ impl OptimizerRule for PushDownFilter {
                     .map(|expr| replace_cols_by_name(expr.clone(), 
&replace_map))
                     .collect::<Result<Vec<_>>>()?;
 
-                let child = match conjunction(replaced_push_predicates) {
-                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
-                        predicate,
-                        agg.input.clone(),
-                    )?),
-                    None => (*agg.input).clone(),
-                };
-                let new_agg = filter
-                    .input
-                    .with_new_exprs(filter.input.expressions(), vec![child])?;
-                match conjunction(keep_predicates) {
-                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
-                        predicate,
-                        Arc::new(new_agg),
-                    )?),
-                    None => new_agg,
-                }
-            }
-            LogicalPlan::Join(ref join) => {
-                push_down_join(
-                    &unwrap_arc(filter.clone().input),
-                    join,
-                    Some(&filter.predicate),
-                )?
-                .data
+                let agg_input = agg.input.clone();
+                Transformed::yes(LogicalPlan::Aggregate(agg))
+                    .transform_data(|new_plan| {
+                        // If we have a filter to push, we push it down to the 
input of the aggregate
+                        if let Some(predicate) = 
conjunction(replaced_push_predicates) {
+                            let new_filter = make_filter(predicate, 
agg_input)?;
+                            insert_below(new_plan, new_filter)
+                        } else {
+                            Ok(Transformed::no(new_plan))
+                        }
+                    })?
+                    .map_data(|child_plan| {
+                        // if there are any remaining predicates we can't 
push, add them
+                        // back as a filter
+                        if let Some(predicate) = conjunction(keep_predicates) {
+                            make_filter(predicate, Arc::new(child_plan))
+                        } else {
+                            Ok(child_plan)
+                        }
+                    })
             }
-            LogicalPlan::CrossJoin(ref cross_join) => {
+            LogicalPlan::Join(join) => push_down_join(join, 
Some(&filter.predicate)),
+            LogicalPlan::CrossJoin(cross_join) => {
                 let predicates = 
split_conjunction_owned(filter.predicate.clone());
-                let join = 
convert_cross_join_to_inner_join(cross_join.clone())?;
-                let join_plan = LogicalPlan::Join(join);
-                let inputs = join_plan.inputs();
-                let left = inputs[0];
-                let right = inputs[1];
-                let plan = push_down_all_join(
-                    predicates,
-                    vec![],
-                    &join_plan,
-                    left,
-                    right,
-                    vec![],
-                    true,
-                )?;
-                convert_to_cross_join_if_beneficial(plan.data)?
+                let join = convert_cross_join_to_inner_join(cross_join)?;
+                let plan = push_down_all_join(predicates, vec![], join, 
vec![])?;
+                convert_to_cross_join_if_beneficial(plan.data)
             }
             LogicalPlan::TableScan(ref scan) => {
                 let filter_predicates = split_conjunction(&filter.predicate);
@@ -901,25 +873,47 @@ impl OptimizerRule for PushDownFilter {
                     fetch: scan.fetch,
                 });
 
-                match conjunction(new_predicate) {
-                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
-                        predicate,
-                        Arc::new(new_scan),
-                    )?),
-                    None => new_scan,
-                }
+                Transformed::yes(new_scan).transform_data(|new_scan| {
+                    if let Some(predicate) = conjunction(new_predicate) {
+                        make_filter(predicate, 
Arc::new(new_scan)).map(Transformed::yes)
+                    } else {
+                        Ok(Transformed::no(new_scan))
+                    }
+                })
             }
-            LogicalPlan::Extension(ref extension_plan) => {
+            LogicalPlan::Extension(extension_plan) => {
                 let prevent_cols =
                     extension_plan.node.prevent_predicate_push_down_columns();
 
-                let predicates = 
split_conjunction_owned(filter.predicate.clone());
+                // determine if we can push any predicates down past the 
extension node
+
+                // each element is true for push, false to keep
+                let predicate_push_or_keep = 
split_conjunction(&filter.predicate)
+                    .iter()
+                    .map(|expr| {
+                        let cols = expr.to_columns()?;
+                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) 
{
+                            Ok(false) // No push (keep)
+                        } else {
+                            Ok(true) // push
+                        }
+                    })
+                    .collect::<Result<Vec<_>>>()?;
 
+                // all predicates are kept, no changes needed
+                if predicate_push_or_keep.iter().all(|&x| !x) {
+                    filter.input = 
Arc::new(LogicalPlan::Extension(extension_plan));
+                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
+                }
+
+                // going to push some predicates down, so split the predicates
                 let mut keep_predicates = vec![];
                 let mut push_predicates = vec![];
-                for expr in predicates {
-                    let cols = expr.to_columns()?;
-                    if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
+                for (push, expr) in predicate_push_or_keep
+                    .into_iter()
+                    .zip(split_conjunction_owned(filter.predicate).into_iter())
+                {
+                    if !push {
                         keep_predicates.push(expr);
                     } else {
                         push_predicates.push(expr);
@@ -941,22 +935,65 @@ impl OptimizerRule for PushDownFilter {
                     None => extension_plan.node.inputs().into_iter().cloned().collect(),
                 };
                 // extension with new inputs.
+                let child_plan = LogicalPlan::Extension(extension_plan);
                 let new_extension =
                     child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
 
-                match conjunction(keep_predicates) {
+                let new_plan = match conjunction(keep_predicates) {
                     Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                         predicate,
                         Arc::new(new_extension),
                     )?),
                     None => new_extension,
-                }
+                };
+                Ok(Transformed::yes(new_plan))
             }
-            _ => return Ok(Transformed::no(plan)),
-        };
+            child => {
+                filter.input = Arc::new(child);
+                Ok(Transformed::no(LogicalPlan::Filter(filter)))
+            }
+        }
+    }
+}
+
+/// Creates a new LogicalPlan::Filter node.
+pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
+    Filter::try_new(predicate, input).map(LogicalPlan::Filter)
+}
 
-        Ok(Transformed::yes(new_plan))
+/// Replace the existing child of the single input node with `new_child`.
+///
+/// Starting:
+/// ```text
+/// plan
+///   child
+/// ```
+///
+/// Ending:
+/// ```text
+/// plan
+///   new_child
+/// ```
+fn insert_below(
+    plan: LogicalPlan,
+    new_child: LogicalPlan,
+) -> Result<Transformed<LogicalPlan>> {
+    let mut new_child = Some(new_child);
+    let transformed_plan = plan.map_children(|_child| {
+        if let Some(new_child) = new_child.take() {
+            Ok(Transformed::yes(new_child))
+        } else {
+            // already took the new child
+            internal_err!("node had more than one input")
+        }
+    })?;
+
+    // make sure we did the actual replacement
+    if new_child.is_some() {
+        return internal_err!("node had no  inputs");
     }
+
+    Ok(transformed_plan)
 }
 
 impl PushDownFilter {
@@ -985,21 +1022,27 @@ fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result<Join> {
 
 /// Converts the given inner join with an empty equality predicate and an
 /// empty filter condition to a cross join.
-fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result<LogicalPlan> {
-    if let LogicalPlan::Join(join) = &plan {
+fn convert_to_cross_join_if_beneficial(
+    plan: LogicalPlan,
+) -> Result<Transformed<LogicalPlan>> {
+    match plan {
         // Can be converted back to cross join
-        if join.on.is_empty() && join.filter.is_none() {
-            return LogicalPlanBuilder::from(join.left.as_ref().clone())
-                .cross_join(join.right.as_ref().clone())?
-                .build();
+        LogicalPlan::Join(join) if join.on.is_empty() && join.filter.is_none() => {
+            LogicalPlanBuilder::from(unwrap_arc(join.left))
+                .cross_join(unwrap_arc(join.right))?
+                .build()
+                .map(Transformed::yes)
         }
-    } else if let LogicalPlan::Filter(filter) = &plan {
-        let new_input =
-            convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?;
-        return Filter::try_new(filter.predicate.clone(), Arc::new(new_input))
-            .map(LogicalPlan::Filter);
+        LogicalPlan::Filter(filter) => convert_to_cross_join_if_beneficial(unwrap_arc(
+            filter.input,
+        ))?
+        .transform_data(|child_plan| {
+            Filter::try_new(filter.predicate, Arc::new(child_plan))
+                .map(LogicalPlan::Filter)
+                .map(Transformed::yes)
+        }),
+        plan => Ok(Transformed::no(plan)),
     }
-    Ok(plan)
 }
 
 /// replaces columns by its name on the projection.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@datafusion.apache.org
For additional commands, e-mail: commits-help@datafusion.apache.org


Reply via email to