This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 87169f06ab Stop copying LogicalPlan and Exprs in `PushDownFilter`
(#10444)
87169f06ab is described below
commit 87169f06ab590f20bd03b1be504a2119ddca6d68
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu May 16 20:08:39 2024 -0400
Stop copying LogicalPlan and Exprs in `PushDownFilter` (#10444)
---
datafusion/expr/src/logical_plan/plan.rs | 10 +
datafusion/optimizer/src/push_down_filter.rs | 623 ++++++++++++++-------------
2 files changed, 343 insertions(+), 290 deletions(-)
diff --git a/datafusion/expr/src/logical_plan/plan.rs
b/datafusion/expr/src/logical_plan/plan.rs
index ddf075c2c2..4872e5acda 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -2407,6 +2407,16 @@ pub enum Distinct {
On(DistinctOn),
}
+impl Distinct {
+ /// return a reference to the node's input
+ pub fn input(&self) -> &Arc<LogicalPlan> {
+ match self {
+ Distinct::All(input) => input,
+ Distinct::On(DistinctOn { input, .. }) => input,
+ }
+ }
+}
+
/// Removes duplicate rows from the input
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct DistinctOn {
diff --git a/datafusion/optimizer/src/push_down_filter.rs
b/datafusion/optimizer/src/push_down_filter.rs
index 57b38bd0d0..b684b54903 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -14,6 +14,7 @@
//! [`PushDownFilter`] applies filters as early as possible
+use indexmap::IndexSet;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
@@ -23,10 +24,9 @@ use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{
- internal_err, plan_datafusion_err, qualified_name, Column, DFSchema,
DFSchemaRef,
+ internal_err, plan_err, qualified_name, Column, DFSchema, DFSchemaRef,
JoinConstraint, Result,
};
-use datafusion_expr::expr::Alias;
use datafusion_expr::expr_rewriter::replace_col;
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::{
@@ -131,7 +131,8 @@ use crate::{OptimizerConfig, OptimizerRule};
#[derive(Default)]
pub struct PushDownFilter {}
-/// For a given JOIN logical plan, determine whether each side of the join is
preserved.
+/// For a given JOIN type, determine whether each side of the join is
preserved.
+///
/// We say a join side is preserved if the join returns all or a subset of the
rows from
/// the relevant side, such that each row of the output table directly maps to
a row of
/// the preserved input table. If a table is not preserved, it can provide
extra null rows.
@@ -150,44 +151,33 @@ pub struct PushDownFilter {}
/// non-preserved side it can be more tricky.
///
/// Returns a tuple of booleans - (left_preserved, right_preserved).
-fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
- match plan {
- LogicalPlan::Join(Join { join_type, .. }) => match join_type {
- JoinType::Inner => Ok((true, true)),
- JoinType::Left => Ok((true, false)),
- JoinType::Right => Ok((false, true)),
- JoinType::Full => Ok((false, false)),
- // No columns from the right side of the join can be referenced in
output
- // predicates for semi/anti joins, so whether we specify t/f
doesn't matter.
- JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)),
- // No columns from the left side of the join can be referenced in
output
- // predicates for semi/anti joins, so whether we specify t/f
doesn't matter.
- JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)),
- },
- LogicalPlan::CrossJoin(_) => Ok((true, true)),
- _ => internal_err!("lr_is_preserved only valid for JOIN nodes"),
+fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
+ match join_type {
+ JoinType::Inner => Ok((true, true)),
+ JoinType::Left => Ok((true, false)),
+ JoinType::Right => Ok((false, true)),
+ JoinType::Full => Ok((false, false)),
+ // No columns from the right side of the join can be referenced in
output
+ // predicates for semi/anti joins, so whether we specify t/f doesn't
matter.
+ JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)),
+ // No columns from the left side of the join can be referenced in
output
+ // predicates for semi/anti joins, so whether we specify t/f doesn't
matter.
+ JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)),
}
}
/// For a given JOIN logical plan, determine whether each side of the join is
preserved
/// in terms on join filtering.
-///
/// Predicates from join filter can only be pushed to preserved join side.
-fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
- match plan {
- LogicalPlan::Join(Join { join_type, .. }) => match join_type {
- JoinType::Inner => Ok((true, true)),
- JoinType::Left => Ok((false, true)),
- JoinType::Right => Ok((true, false)),
- JoinType::Full => Ok((false, false)),
- JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)),
- JoinType::LeftAnti => Ok((false, true)),
- JoinType::RightAnti => Ok((true, false)),
- },
- LogicalPlan::CrossJoin(_) => {
- internal_err!("on_lr_is_preserved cannot be applied to CROSSJOIN
nodes")
- }
- _ => internal_err!("on_lr_is_preserved only valid for JOIN nodes"),
+fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
+ match join_type {
+ JoinType::Inner => Ok((true, true)),
+ JoinType::Left => Ok((false, true)),
+ JoinType::Right => Ok((true, false)),
+ JoinType::Full => Ok((false, false)),
+ JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)),
+ JoinType::LeftAnti => Ok((false, true)),
+ JoinType::RightAnti => Ok((true, false)),
}
}
@@ -400,23 +390,20 @@ fn extract_or_clause(expr: &Expr, schema_columns:
&HashSet<Column>) -> Option<Ex
/// push down join/cross-join
fn push_down_all_join(
predicates: Vec<Expr>,
- infer_predicates: Vec<Expr>,
- join_plan: &LogicalPlan,
- left: &LogicalPlan,
- right: &LogicalPlan,
+ inferred_join_predicates: Vec<Expr>,
+ mut join: Join,
on_filter: Vec<Expr>,
- is_inner_join: bool,
) -> Result<Transformed<LogicalPlan>> {
- let on_filter_empty = on_filter.is_empty();
+ let is_inner_join = join.join_type == JoinType::Inner;
// Get pushable predicates from current optimizer state
- let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?;
+ let (left_preserved, right_preserved) = lr_is_preserved(join.join_type)?;
// The predicates can be divided to three categories:
// 1) can push through join to its children(left or right)
// 2) can be converted to join conditions if the join type is Inner
// 3) should be kept as filter conditions
- let left_schema = left.schema();
- let right_schema = right.schema();
+ let left_schema = join.left.schema();
+ let right_schema = join.right.schema();
let mut left_push = vec![];
let mut right_push = vec![];
let mut keep_predicates = vec![];
@@ -438,7 +425,7 @@ fn push_down_all_join(
}
// For infer predicates, if they can not push through join, just drop them
- for predicate in infer_predicates {
+ for predicate in inferred_join_predicates {
if left_preserved && can_pushdown_join_predicate(&predicate,
left_schema)? {
left_push.push(predicate);
} else if right_preserved
@@ -449,7 +436,7 @@ fn push_down_all_join(
}
if !on_filter.is_empty() {
- let (on_left_preserved, on_right_preserved) =
on_lr_is_preserved(join_plan)?;
+ let (on_left_preserved, on_right_preserved) =
on_lr_is_preserved(join.join_type)?;
for on in on_filter {
if on_left_preserved && can_pushdown_join_predicate(&on,
left_schema)? {
left_push.push(on)
@@ -474,46 +461,29 @@ fn push_down_all_join(
right_push.extend(extract_or_clauses_for_join(&join_conditions,
right_schema));
}
- let left = match conjunction(left_push) {
- Some(predicate) => {
- LogicalPlan::Filter(Filter::try_new(predicate,
Arc::new(left.clone()))?)
- }
- None => left.clone(),
- };
- let right = match conjunction(right_push) {
- Some(predicate) => {
- LogicalPlan::Filter(Filter::try_new(predicate,
Arc::new(right.clone()))?)
- }
- None => right.clone(),
- };
- // Create a new Join with the new `left` and `right`
- //
- // expressions() output for Join is a vector consisting of
- // 1. join keys - columns mentioned in ON clause
- // 2. optional predicate - in case join filter is not empty,
- // it always will be the last element, otherwise result
- // vector will contain only join keys (without additional
- // element representing filter).
- let mut exprs = join_plan.expressions();
- if !on_filter_empty {
- exprs.pop();
- }
- exprs.extend(join_conditions.into_iter().reduce(Expr::and));
- let plan = join_plan.with_new_exprs(exprs, vec![left, right])?;
-
- // wrap the join on the filter whose predicates must be kept
- match conjunction(keep_predicates) {
- Some(predicate) => {
- let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?;
- Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan)))
- }
- None => Ok(Transformed::no(plan)),
+ if let Some(predicate) = conjunction(left_push) {
+ join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate,
join.left)?));
}
+ if let Some(predicate) = conjunction(right_push) {
+ join.right =
+ Arc::new(LogicalPlan::Filter(Filter::try_new(predicate,
join.right)?));
+ }
+
+ // Add any new join conditions as the non join predicates
+ join.filter = conjunction(join_conditions);
+
+ // wrap the join on the filter whose predicates must be kept, if any
+ let plan = LogicalPlan::Join(join);
+ let plan = if let Some(predicate) = conjunction(keep_predicates) {
+ LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
+ } else {
+ plan
+ };
+ Ok(Transformed::yes(plan))
}
fn push_down_join(
- plan: &LogicalPlan,
- join: &Join,
+ join: Join,
parent_predicate: Option<&Expr>,
) -> Result<Transformed<LogicalPlan>> {
// Split the parent predicate into individual conjunctive parts.
@@ -526,93 +496,102 @@ fn push_down_join(
.as_ref()
.map_or_else(Vec::new, |filter|
split_conjunction_owned(filter.clone()));
- let mut is_inner_join = false;
- let infer_predicates = if join.join_type == JoinType::Inner {
- is_inner_join = true;
-
- // Only allow both side key is column.
- let join_col_keys = join
- .on
- .iter()
- .filter_map(|(l, r)| {
- let left_col = l.try_as_col().cloned()?;
- let right_col = r.try_as_col().cloned()?;
- Some((left_col, right_col))
- })
- .collect::<Vec<_>>();
-
- // TODO refine the logic, introduce EquivalenceProperties to logical
plan and infer additional filters to push down
- // For inner joins, duplicate filters for joined columns so filters
can be pushed down
- // to both sides. Take the following query as an example:
- //
- // ```sql
- // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
- // ```
- //
- // `t1.id > 1` predicate needs to be pushed down to t1 table scan,
while
- // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
- //
- // Join clauses with `Using` constraints also take advantage of this
logic to make sure
- // predicates reference the shared join columns are pushed to both
sides.
- // This logic should also been applied to conditions in JOIN ON clause
- predicates
- .iter()
- .chain(on_filters.iter())
- .filter_map(|predicate| {
- let mut join_cols_to_replace = HashMap::new();
-
- let columns = match predicate.to_columns() {
- Ok(columns) => columns,
- Err(e) => return Some(Err(e)),
- };
+ // Are there any new join predicates that can be inferred from the filter
expressions?
+ let inferred_join_predicates =
+ infer_join_predicates(&join, &predicates, &on_filters)?;
- for col in columns.iter() {
- for (l, r) in join_col_keys.iter() {
- if col == l {
- join_cols_to_replace.insert(col, r);
- break;
- } else if col == r {
- join_cols_to_replace.insert(col, l);
- break;
- }
- }
- }
+ if on_filters.is_empty()
+ && predicates.is_empty()
+ && inferred_join_predicates.is_empty()
+ {
+ return Ok(Transformed::no(LogicalPlan::Join(join)));
+ }
- if join_cols_to_replace.is_empty() {
- return None;
- }
+ push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
+}
- let join_side_predicate =
- match replace_col(predicate.clone(),
&join_cols_to_replace) {
- Ok(p) => p,
- Err(e) => {
- return Some(Err(e));
- }
- };
+/// Extracts any equi-join predicates from the given filter expressions.
+///
+/// Parameters
+/// * `join` the join in question
+///
+/// * `predicates` the pushed down filter expression
+///
+/// * `on_filters` filters from the join ON clause that have not already been
+/// identified as join predicates
+///
+fn infer_join_predicates(
+ join: &Join,
+ predicates: &[Expr],
+ on_filters: &[Expr],
+) -> Result<Vec<Expr>> {
+ if join.join_type != JoinType::Inner {
+ return Ok(vec![]);
+ }
- Some(Ok(join_side_predicate))
- })
- .collect::<Result<Vec<_>>>()?
- } else {
- vec![]
- };
+ // Only consider join keys where both sides are plain columns.
+ let join_col_keys = join
+ .on
+ .iter()
+ .filter_map(|(l, r)| {
+ let left_col = l.try_as_col()?;
+ let right_col = r.try_as_col()?;
+ Some((left_col, right_col))
+ })
+ .collect::<Vec<_>>();
- if on_filters.is_empty() && predicates.is_empty() &&
infer_predicates.is_empty() {
- return Ok(Transformed::no(plan.clone()));
- }
+ // TODO refine the logic, introduce EquivalenceProperties to logical plan
and infer additional filters to push down
+ // For inner joins, duplicate filters for joined columns so filters can be
pushed down
+ // to both sides. Take the following query as an example:
+ //
+ // ```sql
+ // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
+ // ```
+ //
+ // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
+ // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
+ //
+ // Join clauses with `Using` constraints also take advantage of this logic
to make sure
+ // predicates reference the shared join columns are pushed to both sides.
+ // This logic should also be applied to conditions in the JOIN ON clause
+ predicates
+ .iter()
+ .chain(on_filters.iter())
+ .filter_map(|predicate| {
+ let mut join_cols_to_replace = HashMap::new();
+
+ let columns = match predicate.to_columns() {
+ Ok(columns) => columns,
+ Err(e) => return Some(Err(e)),
+ };
+
+ for col in columns.iter() {
+ for (l, r) in join_col_keys.iter() {
+ if col == *l {
+ join_cols_to_replace.insert(col, *r);
+ break;
+ } else if col == *r {
+ join_cols_to_replace.insert(col, *l);
+ break;
+ }
+ }
+ }
- match push_down_all_join(
- predicates,
- infer_predicates,
- plan,
- &join.left,
- &join.right,
- on_filters,
- is_inner_join,
- ) {
- Ok(plan) => Ok(Transformed::yes(plan.data)),
- Err(e) => Err(e),
- }
+ if join_cols_to_replace.is_empty() {
+ return None;
+ }
+
+ let join_side_predicate =
+ match replace_col(predicate.clone(), &join_cols_to_replace) {
+ Ok(p) => p,
+ Err(e) => {
+ return Some(Err(e));
+ }
+ };
+
+ Some(Ok(join_side_predicate))
+ })
+ .collect::<Result<Vec<_>>>()
}
impl OptimizerRule for PushDownFilter {
@@ -641,46 +620,57 @@ impl OptimizerRule for PushDownFilter {
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
- let filter = match plan {
- LogicalPlan::Filter(ref filter) => filter,
- LogicalPlan::Join(ref join) => return push_down_join(&plan, join,
None),
- _ => return Ok(Transformed::no(plan)),
+ if let LogicalPlan::Join(join) = plan {
+ return push_down_join(join, None);
+ };
+
+ let plan_schema = plan.schema().clone();
+
+ let LogicalPlan::Filter(mut filter) = plan else {
+ return Ok(Transformed::no(plan));
};
- let child_plan = filter.input.as_ref();
- let new_plan = match child_plan {
- LogicalPlan::Filter(ref child_filter) => {
- let parents_predicates = split_conjunction(&filter.predicate);
- let set: HashSet<&&Expr> = parents_predicates.iter().collect();
+ match unwrap_arc(filter.input) {
+ LogicalPlan::Filter(child_filter) => {
+ let parents_predicates =
split_conjunction_owned(filter.predicate);
+ // remove duplicated filters
+ let child_predicates =
split_conjunction_owned(child_filter.predicate);
let new_predicates = parents_predicates
- .iter()
- .chain(
- split_conjunction(&child_filter.predicate)
- .iter()
- .filter(|e| !set.contains(e)),
- )
- .map(|e| (*e).clone())
+ .into_iter()
+ .chain(child_predicates)
+ // use IndexSet to remove dupes while preserving predicate
order
+ .collect::<IndexSet<_>>()
+ .into_iter()
.collect::<Vec<_>>();
- let new_predicate = conjunction(new_predicates).ok_or_else(|| {
- plan_datafusion_err!("at least one expression exists")
- })?;
+
+ let Some(new_predicate) = conjunction(new_predicates) else {
+ return plan_err!("at least one expression exists");
+ };
let new_filter = LogicalPlan::Filter(Filter::try_new(
new_predicate,
- child_filter.input.clone(),
+ child_filter.input,
)?);
- self.rewrite(new_filter, _config)?.data
+ self.rewrite(new_filter, _config)
}
- LogicalPlan::Repartition(_)
- | LogicalPlan::Distinct(_)
- | LogicalPlan::Sort(_) => {
- let new_filter = plan.with_new_exprs(
- plan.expressions(),
- vec![child_plan.inputs()[0].clone()],
- )?;
- child_plan.with_new_exprs(child_plan.expressions(),
vec![new_filter])?
+ LogicalPlan::Repartition(repartition) => {
+ let new_filter =
+ Filter::try_new(filter.predicate,
repartition.input.clone())
+ .map(LogicalPlan::Filter)?;
+ insert_below(LogicalPlan::Repartition(repartition), new_filter)
}
- LogicalPlan::SubqueryAlias(ref subquery_alias) => {
+ LogicalPlan::Distinct(distinct) => {
+ let new_filter =
+ Filter::try_new(filter.predicate, distinct.input().clone())
+ .map(LogicalPlan::Filter)?;
+ insert_below(LogicalPlan::Distinct(distinct), new_filter)
+ }
+ LogicalPlan::Sort(sort) => {
+ let new_filter = Filter::try_new(filter.predicate,
sort.input.clone())
+ .map(LogicalPlan::Filter)?;
+ insert_below(LogicalPlan::Sort(sort), new_filter)
+ }
+ LogicalPlan::SubqueryAlias(subquery_alias) => {
let mut replace_map = HashMap::new();
for (i, (qualifier, field)) in
subquery_alias.input.schema().iter().enumerate()
@@ -692,15 +682,15 @@ impl OptimizerRule for PushDownFilter {
Expr::Column(Column::new(qualifier.cloned(),
field.name())),
);
}
- let new_predicate =
- replace_cols_by_name(filter.predicate.clone(),
&replace_map)?;
+ let new_predicate = replace_cols_by_name(filter.predicate,
&replace_map)?;
+
let new_filter = LogicalPlan::Filter(Filter::try_new(
new_predicate,
subquery_alias.input.clone(),
)?);
- child_plan.with_new_exprs(child_plan.expressions(),
vec![new_filter])?
+ insert_below(LogicalPlan::SubqueryAlias(subquery_alias),
new_filter)
}
- LogicalPlan::Projection(ref projection) => {
+ LogicalPlan::Projection(projection) => {
// A projection is filter-commutable if it do not contain
volatile predicates or contain volatile
// predicates that are not used in the filter. However, we
should re-writes all predicate expressions.
// collect projection.
@@ -711,10 +701,7 @@ impl OptimizerRule for PushDownFilter {
.enumerate()
.map(|(i, (qualifier, field))| {
// strip alias, as they should not be part of
filters
- let expr = match &projection.expr[i] {
- Expr::Alias(Alias { expr, .. }) =>
expr.as_ref().clone(),
- expr => expr.clone(),
- };
+ let expr = projection.expr[i].clone().unalias();
(qualified_name(qualifier, field.name()), expr)
})
@@ -741,23 +728,24 @@ impl OptimizerRule for PushDownFilter {
)?);
match conjunction(keep_predicates) {
- None => child_plan.with_new_exprs(
- child_plan.expressions(),
- vec![new_filter],
- )?,
- Some(keep_predicate) => {
- let child_plan = child_plan.with_new_exprs(
- child_plan.expressions(),
- vec![new_filter],
- )?;
- LogicalPlan::Filter(Filter::try_new(
- keep_predicate,
- Arc::new(child_plan),
- )?)
- }
+ None => insert_below(
+ LogicalPlan::Projection(projection),
+ new_filter,
+ ),
+ Some(keep_predicate) => insert_below(
+ LogicalPlan::Projection(projection),
+ new_filter,
+ )?
+ .map_data(|child_plan| {
+ Filter::try_new(keep_predicate,
Arc::new(child_plan))
+ .map(LogicalPlan::Filter)
+ }),
}
}
- None => return Ok(Transformed::no(plan)),
+ None => {
+ filter.input =
Arc::new(LogicalPlan::Projection(projection));
+ Ok(Transformed::no(LogicalPlan::Filter(filter)))
+ }
}
}
LogicalPlan::Union(ref union) => {
@@ -780,12 +768,12 @@ impl OptimizerRule for PushDownFilter {
input.clone(),
)?)))
}
- LogicalPlan::Union(Union {
+ Ok(Transformed::yes(LogicalPlan::Union(Union {
inputs,
- schema: plan.schema().clone(),
- })
+ schema: plan_schema.clone(),
+ })))
}
- LogicalPlan::Aggregate(ref agg) => {
+ LogicalPlan::Aggregate(agg) => {
// We can push down Predicate which in groupby_expr.
let group_expr_columns = agg
.group_expr
@@ -818,49 +806,33 @@ impl OptimizerRule for PushDownFilter {
.map(|expr| replace_cols_by_name(expr.clone(),
&replace_map))
.collect::<Result<Vec<_>>>()?;
- let child = match conjunction(replaced_push_predicates) {
- Some(predicate) => LogicalPlan::Filter(Filter::try_new(
- predicate,
- agg.input.clone(),
- )?),
- None => (*agg.input).clone(),
- };
- let new_agg = filter
- .input
- .with_new_exprs(filter.input.expressions(), vec![child])?;
- match conjunction(keep_predicates) {
- Some(predicate) => LogicalPlan::Filter(Filter::try_new(
- predicate,
- Arc::new(new_agg),
- )?),
- None => new_agg,
- }
- }
- LogicalPlan::Join(ref join) => {
- push_down_join(
- &unwrap_arc(filter.clone().input),
- join,
- Some(&filter.predicate),
- )?
- .data
+ let agg_input = agg.input.clone();
+ Transformed::yes(LogicalPlan::Aggregate(agg))
+ .transform_data(|new_plan| {
+ // If we have a filter to push, we push it down to the
input of the aggregate
+ if let Some(predicate) =
conjunction(replaced_push_predicates) {
+ let new_filter = make_filter(predicate,
agg_input)?;
+ insert_below(new_plan, new_filter)
+ } else {
+ Ok(Transformed::no(new_plan))
+ }
+ })?
+ .map_data(|child_plan| {
+ // if there are any remaining predicates we can't
push, add them
+ // back as a filter
+ if let Some(predicate) = conjunction(keep_predicates) {
+ make_filter(predicate, Arc::new(child_plan))
+ } else {
+ Ok(child_plan)
+ }
+ })
}
- LogicalPlan::CrossJoin(ref cross_join) => {
+ LogicalPlan::Join(join) => push_down_join(join,
Some(&filter.predicate)),
+ LogicalPlan::CrossJoin(cross_join) => {
let predicates =
split_conjunction_owned(filter.predicate.clone());
- let join =
convert_cross_join_to_inner_join(cross_join.clone())?;
- let join_plan = LogicalPlan::Join(join);
- let inputs = join_plan.inputs();
- let left = inputs[0];
- let right = inputs[1];
- let plan = push_down_all_join(
- predicates,
- vec![],
- &join_plan,
- left,
- right,
- vec![],
- true,
- )?;
- convert_to_cross_join_if_beneficial(plan.data)?
+ let join = convert_cross_join_to_inner_join(cross_join)?;
+ let plan = push_down_all_join(predicates, vec![], join,
vec![])?;
+ convert_to_cross_join_if_beneficial(plan.data)
}
LogicalPlan::TableScan(ref scan) => {
let filter_predicates = split_conjunction(&filter.predicate);
@@ -901,25 +873,47 @@ impl OptimizerRule for PushDownFilter {
fetch: scan.fetch,
});
- match conjunction(new_predicate) {
- Some(predicate) => LogicalPlan::Filter(Filter::try_new(
- predicate,
- Arc::new(new_scan),
- )?),
- None => new_scan,
- }
+ Transformed::yes(new_scan).transform_data(|new_scan| {
+ if let Some(predicate) = conjunction(new_predicate) {
+ make_filter(predicate,
Arc::new(new_scan)).map(Transformed::yes)
+ } else {
+ Ok(Transformed::no(new_scan))
+ }
+ })
}
- LogicalPlan::Extension(ref extension_plan) => {
+ LogicalPlan::Extension(extension_plan) => {
let prevent_cols =
extension_plan.node.prevent_predicate_push_down_columns();
- let predicates =
split_conjunction_owned(filter.predicate.clone());
+ // determine if we can push any predicates down past the
extension node
+
+ // each element is true for push, false to keep
+ let predicate_push_or_keep =
split_conjunction(&filter.predicate)
+ .iter()
+ .map(|expr| {
+ let cols = expr.to_columns()?;
+ if cols.iter().any(|c| prevent_cols.contains(&c.name))
{
+ Ok(false) // No push (keep)
+ } else {
+ Ok(true) // push
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+ // all predicates are kept, no changes needed
+ if predicate_push_or_keep.iter().all(|&x| !x) {
+ filter.input =
Arc::new(LogicalPlan::Extension(extension_plan));
+ return Ok(Transformed::no(LogicalPlan::Filter(filter)));
+ }
+
+ // going to push some predicates down, so split the predicates
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
- for expr in predicates {
- let cols = expr.to_columns()?;
- if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
+ for (push, expr) in predicate_push_or_keep
+ .into_iter()
+ .zip(split_conjunction_owned(filter.predicate).into_iter())
+ {
+ if !push {
keep_predicates.push(expr);
} else {
push_predicates.push(expr);
@@ -941,22 +935,65 @@ impl OptimizerRule for PushDownFilter {
None =>
extension_plan.node.inputs().into_iter().cloned().collect(),
};
// extension with new inputs.
+ let child_plan = LogicalPlan::Extension(extension_plan);
let new_extension =
child_plan.with_new_exprs(child_plan.expressions(),
new_children)?;
- match conjunction(keep_predicates) {
+ let new_plan = match conjunction(keep_predicates) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(new_extension),
)?),
None => new_extension,
- }
+ };
+ Ok(Transformed::yes(new_plan))
}
- _ => return Ok(Transformed::no(plan)),
- };
+ child => {
+ filter.input = Arc::new(child);
+ Ok(Transformed::no(LogicalPlan::Filter(filter)))
+ }
+ }
+ }
+}
+
+/// Creates a new LogicalPlan::Filter node.
+pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) ->
Result<LogicalPlan> {
+ Filter::try_new(predicate, input).map(LogicalPlan::Filter)
+}
- Ok(Transformed::yes(new_plan))
+/// Replace the existing child of the single input node with `new_child`.
+///
+/// Starting:
+/// ```text
+/// plan
+/// child
+/// ```
+///
+/// Ending:
+/// ```text
+/// plan
+/// new_child
+/// ```
+fn insert_below(
+ plan: LogicalPlan,
+ new_child: LogicalPlan,
+) -> Result<Transformed<LogicalPlan>> {
+ let mut new_child = Some(new_child);
+ let transformed_plan = plan.map_children(|_child| {
+ if let Some(new_child) = new_child.take() {
+ Ok(Transformed::yes(new_child))
+ } else {
+ // already took the new child
+ internal_err!("node had more than one input")
+ }
+ })?;
+
+ // make sure we did the actual replacement
+ if new_child.is_some() {
+ return internal_err!("node had no inputs");
}
+
+ Ok(transformed_plan)
}
impl PushDownFilter {
@@ -985,21 +1022,27 @@ fn convert_cross_join_to_inner_join(cross_join:
CrossJoin) -> Result<Join> {
/// Converts the given inner join with an empty equality predicate and an
/// empty filter condition to a cross join.
-fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) ->
Result<LogicalPlan> {
- if let LogicalPlan::Join(join) = &plan {
+fn convert_to_cross_join_if_beneficial(
+ plan: LogicalPlan,
+) -> Result<Transformed<LogicalPlan>> {
+ match plan {
// Can be converted back to cross join
- if join.on.is_empty() && join.filter.is_none() {
- return LogicalPlanBuilder::from(join.left.as_ref().clone())
- .cross_join(join.right.as_ref().clone())?
- .build();
+ LogicalPlan::Join(join) if join.on.is_empty() && join.filter.is_none()
=> {
+ LogicalPlanBuilder::from(unwrap_arc(join.left))
+ .cross_join(unwrap_arc(join.right))?
+ .build()
+ .map(Transformed::yes)
}
- } else if let LogicalPlan::Filter(filter) = &plan {
- let new_input =
-
convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?;
- return Filter::try_new(filter.predicate.clone(), Arc::new(new_input))
- .map(LogicalPlan::Filter);
+ LogicalPlan::Filter(filter) =>
convert_to_cross_join_if_beneficial(unwrap_arc(
+ filter.input,
+ ))?
+ .transform_data(|child_plan| {
+ Filter::try_new(filter.predicate, Arc::new(child_plan))
+ .map(LogicalPlan::Filter)
+ .map(Transformed::yes)
+ }),
+ plan => Ok(Transformed::no(plan)),
}
- Ok(plan)
}
/// replaces columns by its name on the projection.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]