This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 8190cb9721 Optimized push down filter #10291 (#10366)
8190cb9721 is described below
commit 8190cb97216e4f46faccbeddae57f6773587955f
Author: Dmitry Bugakov <[email protected]>
AuthorDate: Fri May 3 18:22:36 2024 +0200
Optimized push down filter #10291 (#10366)
---
datafusion/optimizer/src/push_down_filter.rs | 139 ++++++++++++++++-----------
1 file changed, 81 insertions(+), 58 deletions(-)
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 8462cf86f1..2355ee604e 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -17,8 +17,7 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
-use crate::optimizer::ApplyOrder;
-use crate::{OptimizerConfig, OptimizerRule};
+use itertools::Itertools;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
@@ -29,6 +28,7 @@ use datafusion_common::{
};
use datafusion_expr::expr::Alias;
use datafusion_expr::expr_rewriter::replace_col;
+use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::{
CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union,
};
@@ -38,7 +38,8 @@ use datafusion_expr::{
ScalarFunctionDefinition, TableProviderFilterPushDown,
};
-use itertools::Itertools;
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
@@ -407,7 +408,7 @@ fn push_down_all_join(
right: &LogicalPlan,
on_filter: Vec<Expr>,
is_inner_join: bool,
-) -> Result<LogicalPlan> {
+) -> Result<Transformed<LogicalPlan>> {
let on_filter_empty = on_filter.is_empty();
// Get pushable predicates from current optimizer state
let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?;
@@ -505,9 +506,10 @@ fn push_down_all_join(
// wrap the join on the filter whose predicates must be kept
match conjunction(keep_predicates) {
Some(predicate) => {
- Filter::try_new(predicate, Arc::new(plan)).map(LogicalPlan::Filter)
+ let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?;
+ Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan)))
}
- None => Ok(plan),
+ None => Ok(Transformed::no(plan)),
}
}
@@ -515,31 +517,32 @@ fn push_down_join(
plan: &LogicalPlan,
join: &Join,
parent_predicate: Option<&Expr>,
-) -> Result<Option<LogicalPlan>> {
- let predicates = match parent_predicate {
- Some(parent_predicate) => split_conjunction_owned(parent_predicate.clone()),
- None => vec![],
- };
+) -> Result<Transformed<LogicalPlan>> {
+ // Split the parent predicate into individual conjunctive parts.
+ let predicates = parent_predicate
+ .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
- // Convert JOIN ON predicate to Predicates
+ // Extract conjunctions from the JOIN's ON filter, if present.
let on_filters = join
.filter
.as_ref()
- .map(|e| split_conjunction_owned(e.clone()))
- .unwrap_or_default();
+ .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
let mut is_inner_join = false;
let infer_predicates = if join.join_type == JoinType::Inner {
is_inner_join = true;
+
// Only allow both side key is column.
let join_col_keys = join
.on
.iter()
- .flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) {
- (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)),
- _ => None,
+ .filter_map(|(l, r)| {
+ let left_col = l.try_into_col().ok()?;
+ let right_col = r.try_into_col().ok()?;
+ Some((left_col, right_col))
})
.collect::<Vec<_>>();
+
// TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down
// For inner joins, duplicate filters for joined columns so filters can be pushed down
// to both sides. Take the following query as an example:
@@ -559,6 +562,7 @@ fn push_down_join(
.chain(on_filters.iter())
.filter_map(|predicate| {
let mut join_cols_to_replace = HashMap::new();
+
let columns = match predicate.to_columns() {
Ok(columns) => columns,
Err(e) => return Some(Err(e)),
@@ -596,9 +600,10 @@ fn push_down_join(
};
if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() {
- return Ok(None);
+ return Ok(Transformed::no(plan.clone()));
}
- Ok(Some(push_down_all_join(
+
+ match push_down_all_join(
predicates,
infer_predicates,
plan,
@@ -606,10 +611,21 @@ fn push_down_join(
&join.right,
on_filters,
is_inner_join,
- )?))
+ ) {
+ Ok(plan) => Ok(Transformed::yes(plan.data)),
+ Err(e) => Err(e),
+ }
}
impl OptimizerRule for PushDownFilter {
+ fn try_optimize(
+ &self,
+ _plan: &LogicalPlan,
+ _config: &dyn OptimizerConfig,
+ ) -> Result<Option<LogicalPlan>> {
+ internal_err!("Should have called PushDownFilter::rewrite")
+ }
+
fn name(&self) -> &str {
"push_down_filter"
}
@@ -618,21 +634,24 @@ impl OptimizerRule for PushDownFilter {
Some(ApplyOrder::TopDown)
}
- fn try_optimize(
+ fn supports_rewrite(&self) -> bool {
+ true
+ }
+
+ fn rewrite(
&self,
- plan: &LogicalPlan,
+ plan: LogicalPlan,
_config: &dyn OptimizerConfig,
- ) -> Result<Option<LogicalPlan>> {
+ ) -> Result<Transformed<LogicalPlan>> {
let filter = match plan {
- LogicalPlan::Filter(filter) => filter,
- // we also need to pushdown filter in Join.
- LogicalPlan::Join(join) => return push_down_join(plan, join, None),
- _ => return Ok(None),
+ LogicalPlan::Filter(ref filter) => filter,
+ LogicalPlan::Join(ref join) => return push_down_join(&plan, join, None),
+ _ => return Ok(Transformed::no(plan)),
};
let child_plan = filter.input.as_ref();
let new_plan = match child_plan {
- LogicalPlan::Filter(child_filter) => {
+ LogicalPlan::Filter(ref child_filter) => {
let parents_predicates = split_conjunction(&filter.predicate);
let set: HashSet<&&Expr> = parents_predicates.iter().collect();
@@ -652,20 +671,18 @@ impl OptimizerRule for PushDownFilter {
new_predicate,
child_filter.input.clone(),
)?);
- self.try_optimize(&new_filter, _config)?
- .unwrap_or(new_filter)
+ self.rewrite(new_filter, _config)?.data
}
LogicalPlan::Repartition(_)
| LogicalPlan::Distinct(_)
| LogicalPlan::Sort(_) => {
- // commutable
let new_filter = plan.with_new_exprs(
plan.expressions(),
vec![child_plan.inputs()[0].clone()],
)?;
child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
}
- LogicalPlan::SubqueryAlias(subquery_alias) => {
+ LogicalPlan::SubqueryAlias(ref subquery_alias) => {
let mut replace_map = HashMap::new();
for (i, (qualifier, field)) in
subquery_alias.input.schema().iter().enumerate()
@@ -685,7 +702,7 @@ impl OptimizerRule for PushDownFilter {
)?);
child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
}
- LogicalPlan::Projection(projection) => {
+ LogicalPlan::Projection(ref projection) => {
// A projection is filter-commutable if it do not contain volatile predicates or contain volatile
// predicates that are not used in the filter. However, we should re-writes all predicate expressions.
// collect projection.
@@ -742,10 +759,10 @@ impl OptimizerRule for PushDownFilter {
}
}
}
- None => return Ok(None),
+ None => return Ok(Transformed::no(plan)),
}
}
- LogicalPlan::Union(union) => {
+ LogicalPlan::Union(ref union) => {
let mut inputs = Vec::with_capacity(union.inputs.len());
for input in &union.inputs {
let mut replace_map = HashMap::new();
@@ -770,7 +787,7 @@ impl OptimizerRule for PushDownFilter {
schema: plan.schema().clone(),
})
}
- LogicalPlan::Aggregate(agg) => {
+ LogicalPlan::Aggregate(ref agg) => {
// We can push down Predicate which in groupby_expr.
let group_expr_columns = agg
.group_expr
@@ -821,13 +838,15 @@ impl OptimizerRule for PushDownFilter {
None => new_agg,
}
}
- LogicalPlan::Join(join) => {
- match push_down_join(&filter.input, join, Some(&filter.predicate))? {
- Some(optimized_plan) => optimized_plan,
- None => return Ok(None),
- }
+ LogicalPlan::Join(ref join) => {
+ push_down_join(
+ &unwrap_arc(filter.clone().input),
+ join,
+ Some(&filter.predicate),
+ )?
+ .data
}
- LogicalPlan::CrossJoin(cross_join) => {
+ LogicalPlan::CrossJoin(ref cross_join) => {
let predicates =
split_conjunction_owned(filter.predicate.clone());
let join =
convert_cross_join_to_inner_join(cross_join.clone())?;
let join_plan = LogicalPlan::Join(join);
@@ -843,9 +862,9 @@ impl OptimizerRule for PushDownFilter {
vec![],
true,
)?;
- convert_to_cross_join_if_beneficial(plan)?
+ convert_to_cross_join_if_beneficial(plan.data)?
}
- LogicalPlan::TableScan(scan) => {
+ LogicalPlan::TableScan(ref scan) => {
let filter_predicates = split_conjunction(&filter.predicate);
let results = scan
.source
@@ -892,7 +911,7 @@ impl OptimizerRule for PushDownFilter {
None => new_scan,
}
}
- LogicalPlan::Extension(extension_plan) => {
+ LogicalPlan::Extension(ref extension_plan) => {
let prevent_cols =
extension_plan.node.prevent_predicate_push_down_columns();
@@ -935,9 +954,10 @@ impl OptimizerRule for PushDownFilter {
None => new_extension,
}
}
- _ => return Ok(None),
+ _ => return Ok(Transformed::no(plan)),
};
- Ok(Some(new_plan))
+
+ Ok(Transformed::yes(new_plan))
}
}
@@ -1024,16 +1044,12 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
#[cfg(test)]
mod tests {
- use super::*;
use std::any::Any;
use std::fmt::{Debug, Formatter};
- use crate::optimizer::Optimizer;
- use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
- use crate::test::*;
- use crate::OptimizerContext;
-
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+ use async_trait::async_trait;
+
use datafusion_common::ScalarValue;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::table_scan;
@@ -1043,7 +1059,13 @@ mod tests {
Volatility,
};
- use async_trait::async_trait;
+ use crate::optimizer::Optimizer;
+ use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
+ use crate::test::*;
+ use crate::OptimizerContext;
+
+ use super::*;
+
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
@@ -2298,9 +2320,9 @@ mod tests {
table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
let optimized_plan = PushDownFilter::new()
- .try_optimize(&plan, &OptimizerContext::new())
+ .rewrite(plan, &OptimizerContext::new())
.expect("failed to optimize plan")
- .unwrap();
+ .data;
let expected = "\
Filter: a = Int64(1)\
@@ -2667,8 +2689,9 @@ Projection: a, b
// Originally global state which can help to avoid duplicate Filters been generated and pushed down.
// Now the global state is removed. Need to double confirm that avoid duplicate Filters.
let optimized_plan = PushDownFilter::new()
- .try_optimize(&plan, &OptimizerContext::new())?
- .expect("failed to optimize plan");
+ .rewrite(plan, &OptimizerContext::new())
+ .expect("failed to optimize plan")
+ .data;
assert_optimized_plan_eq(optimized_plan, expected)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]