This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new adf0bfc757 Stop copying LogicalPlan and Exprs in `EliminateCrossJoin`
(4% faster planning) (#10431)
adf0bfc757 is described below
commit adf0bfc757d2f9ba48c45d368578d07806858b89
Author: Andrew Lamb <[email protected]>
AuthorDate: Mon May 13 14:00:35 2024 -0400
Stop copying LogicalPlan and Exprs in `EliminateCrossJoin` (4% faster
planning) (#10431)
* Stop copying LogicalPlan and Exprs in `EliminateCrossJoin`
* Clarify when can_flatten_join_inputs runs
* Use a single `map`
---
datafusion/optimizer/src/eliminate_cross_join.rs | 298 +++++++++++++++--------
datafusion/optimizer/src/join_key_set.rs | 73 +++++-
2 files changed, 254 insertions(+), 117 deletions(-)
diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs
b/datafusion/optimizer/src/eliminate_cross_join.rs
index 923be75748..9d871c50ad 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -18,11 +18,13 @@
//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join
predicates are available.
use std::sync::Arc;
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use crate::{OptimizerConfig, OptimizerRule};
use crate::join_key_set::JoinKeySet;
-use datafusion_common::{plan_err, Result};
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::{internal_err, Result};
use datafusion_expr::expr::{BinaryExpr, Expr};
+use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::{
CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
};
@@ -39,65 +41,109 @@ impl EliminateCrossJoin {
}
}
-/// Attempt to reorder join to eliminate cross joins to inner joins.
-/// for queries:
-/// 'select ... from a, b where a.x = b.y and b.xx = 100;'
-/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and
b.xx = 200);'
-/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
-/// or (a.x = b.y and b.xx = 200 and a.z=c.z);'
-/// 'select ... from a, b where a.x > b.y'
+/// Eliminate cross joins by rewriting them to inner joins when possible.
+///
+/// # Example
+/// The initial plan for this query:
+/// ```sql
+/// select ... from a, b where a.x = b.y and b.xx = 100;
+/// ```
+///
+/// Looks like this:
+/// ```text
+/// Filter(a.x = b.y AND b.xx = 100)
+/// CrossJoin
+/// TableScan a
+/// TableScan b
+/// ```
+///
+/// After the rule is applied, the plan will look like this:
+/// ```text
+/// Filter(b.xx = 100)
+/// InnerJoin(a.x = b.y)
+/// TableScan a
+/// TableScan b
+/// ```
+///
+/// # Other Examples
+/// * 'select ... from a, b where a.x = b.y and b.xx = 100;'
+/// * 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and
b.xx = 200);'
+/// * 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
+///   or (a.x = b.y and b.xx = 200 and a.z=c.z);'
+/// * 'select ... from a, b where a.x > b.y'
+///
/// For above queries, the join predicate is available in filters and they are
moved to
/// join nodes appropriately
+///
/// This fix helps to improve the performance of TPCH Q19. issue#78
impl OptimizerRule for EliminateCrossJoin {
fn try_optimize(
&self,
- plan: &LogicalPlan,
- config: &dyn OptimizerConfig,
+ _plan: &LogicalPlan,
+ _config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
+ internal_err!("Should have called EliminateCrossJoin::rewrite")
+ }
+
+ fn supports_rewrite(&self) -> bool {
+ true
+ }
+
+ fn rewrite(
+ &self,
+ plan: LogicalPlan,
+ config: &dyn OptimizerConfig,
+ ) -> Result<Transformed<LogicalPlan>> {
+ let plan_schema = plan.schema().clone();
let mut possible_join_keys = JoinKeySet::new();
let mut all_inputs: Vec<LogicalPlan> = vec![];
- let parent_predicate = match plan {
- LogicalPlan::Filter(filter) => {
- let input = filter.input.as_ref();
- match input {
- LogicalPlan::Join(Join {
- join_type: JoinType::Inner,
- ..
- })
- | LogicalPlan::CrossJoin(_) => {
- if !try_flatten_join_inputs(
- input,
- &mut possible_join_keys,
- &mut all_inputs,
- )? {
- return Ok(None);
- }
- extract_possible_join_keys(
- &filter.predicate,
- &mut possible_join_keys,
- );
- Some(&filter.predicate)
- }
- _ => {
- return utils::optimize_children(self, plan, config);
- }
- }
+
+ let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
+ // if input isn't a join that can potentially be rewritten
+ // avoid unwrapping the input
+ let rewriteable = matches!(
+ filter.input.as_ref(),
+ LogicalPlan::Join(Join {
+ join_type: JoinType::Inner,
+ ..
+ }) | LogicalPlan::CrossJoin(_)
+ );
+
+ if !rewriteable {
+ // recursively try to rewrite children
+ return rewrite_children(self, LogicalPlan::Filter(filter),
config);
}
+
+ if !can_flatten_join_inputs(&filter.input) {
+ return Ok(Transformed::no(LogicalPlan::Filter(filter)));
+ }
+
+ let Filter {
+ input, predicate, ..
+ } = filter;
+ flatten_join_inputs(
+ unwrap_arc(input),
+ &mut possible_join_keys,
+ &mut all_inputs,
+ )?;
+
+ extract_possible_join_keys(&predicate, &mut possible_join_keys);
+ Some(predicate)
+ } else if matches!(
+ plan,
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
..
- }) => {
- if !try_flatten_join_inputs(
- plan,
- &mut possible_join_keys,
- &mut all_inputs,
- )? {
- return Ok(None);
- }
- None
+ })
+ ) {
+ if !can_flatten_join_inputs(&plan) {
+ return Ok(Transformed::no(plan));
}
- _ => return utils::optimize_children(self, plan, config),
+ flatten_join_inputs(plan, &mut possible_join_keys, &mut
all_inputs)?;
+ None
+ } else {
+ // recursively try to rewrite children
+ return rewrite_children(self, plan, config);
};
// Join keys are handled locally:
@@ -105,36 +151,36 @@ impl OptimizerRule for EliminateCrossJoin {
let mut left = all_inputs.remove(0);
while !all_inputs.is_empty() {
left = find_inner_join(
- &left,
+ left,
&mut all_inputs,
&possible_join_keys,
&mut all_join_keys,
)?;
}
- left = utils::optimize_children(self, &left, config)?.unwrap_or(left);
+ left = rewrite_children(self, left, config)?.data;
- if plan.schema() != left.schema() {
+ if &plan_schema != left.schema() {
left = LogicalPlan::Projection(Projection::new_from_schema(
Arc::new(left),
- plan.schema().clone(),
+ plan_schema.clone(),
));
}
let Some(predicate) = parent_predicate else {
- return Ok(Some(left));
+ return Ok(Transformed::yes(left));
};
// If there are no join keys then do nothing:
if all_join_keys.is_empty() {
- Filter::try_new(predicate.clone(), Arc::new(left))
- .map(|f| Some(LogicalPlan::Filter(f)))
+ Filter::try_new(predicate, Arc::new(left))
+ .map(|filter| Transformed::yes(LogicalPlan::Filter(filter)))
} else {
// Remove join expressions from filter:
- match remove_join_expressions(predicate.clone(), &all_join_keys) {
+ match remove_join_expressions(predicate, &all_join_keys) {
Some(filter_expr) => Filter::try_new(filter_expr,
Arc::new(left))
- .map(|f| Some(LogicalPlan::Filter(f))),
- _ => Ok(Some(left)),
+ .map(|filter|
Transformed::yes(LogicalPlan::Filter(filter))),
+ _ => Ok(Transformed::yes(left)),
}
}
}
@@ -144,49 +190,89 @@ impl OptimizerRule for EliminateCrossJoin {
}
}
+fn rewrite_children(
+ optimizer: &impl OptimizerRule,
+ plan: LogicalPlan,
+ config: &dyn OptimizerConfig,
+) -> Result<Transformed<LogicalPlan>> {
+ let transformed_plan = plan.map_children(|input| optimizer.rewrite(input,
config))?;
+
+ // recompute schema if the plan was transformed
+ if transformed_plan.transformed {
+ transformed_plan.map_data(|plan| plan.recompute_schema())
+ } else {
+ Ok(transformed_plan)
+ }
+}
+
/// Recursively accumulate possible_join_keys and inputs from inner joins
/// (including cross joins).
///
-/// Returns a boolean indicating whether the flattening was successful.
-fn try_flatten_join_inputs(
- plan: &LogicalPlan,
+/// Assumes can_flatten_join_inputs has returned true and thus the plan can be
+/// flattened. Adds all leaf inputs to `all_inputs` and join_keys to
+/// possible_join_keys
+fn flatten_join_inputs(
+ plan: LogicalPlan,
possible_join_keys: &mut JoinKeySet,
all_inputs: &mut Vec<LogicalPlan>,
-) -> Result<bool> {
- let children = match plan {
+) -> Result<()> {
+ match plan {
LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
+ // checked in can_flatten_join_inputs
if join.filter.is_some() {
- // The filter of inner join will lost, skip this rule.
- // issue: https://github.com/apache/datafusion/issues/4844
- return Ok(false);
+ return internal_err!(
+ "should not have filter in inner join in
flatten_join_inputs"
+ );
}
- possible_join_keys.insert_all(join.on.iter());
- vec![&join.left, &join.right]
+ possible_join_keys.insert_all_owned(join.on);
+ flatten_join_inputs(unwrap_arc(join.left), possible_join_keys,
all_inputs)?;
+ flatten_join_inputs(unwrap_arc(join.right), possible_join_keys,
all_inputs)?;
}
LogicalPlan::CrossJoin(join) => {
- vec![&join.left, &join.right]
+ flatten_join_inputs(unwrap_arc(join.left), possible_join_keys,
all_inputs)?;
+ flatten_join_inputs(unwrap_arc(join.right), possible_join_keys,
all_inputs)?;
}
_ => {
- return plan_err!("flatten_join_inputs just can call
join/cross_join");
+ all_inputs.push(plan);
}
};
+ Ok(())
+}
- for child in children.iter() {
- let child = child.as_ref();
+/// Returns true if the plan is a Join or Cross join that could be flattened with
+/// `flatten_join_inputs`
+///
+/// Must stay in sync with `flatten_join_inputs`
+fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
+ // can only flatten inner / cross joins
+ match plan {
+ LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
+ // The filter of inner join will be lost, skip this rule.
+ // issue: https://github.com/apache/datafusion/issues/4844
+ if join.filter.is_some() {
+ return false;
+ }
+ }
+ LogicalPlan::CrossJoin(_) => {}
+ _ => return false,
+ };
+
+ for child in plan.inputs() {
match child {
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
..
})
| LogicalPlan::CrossJoin(_) => {
- if !try_flatten_join_inputs(child, possible_join_keys,
all_inputs)? {
- return Ok(false);
+ if !can_flatten_join_inputs(child) {
+ return false;
}
}
- _ => all_inputs.push(child.clone()),
+ // the child is not a join/cross join
+ _ => (),
}
}
- Ok(true)
+ true
}
/// Finds the next to join with the left input plan,
@@ -202,7 +288,7 @@ fn try_flatten_join_inputs(
/// 1. Removes the first plan from `rights`
/// 2. Returns `left_input CROSS JOIN right`.
fn find_inner_join(
- left_input: &LogicalPlan,
+ left_input: LogicalPlan,
rights: &mut Vec<LogicalPlan>,
possible_join_keys: &JoinKeySet,
all_join_keys: &mut JoinKeySet,
@@ -237,7 +323,7 @@ fn find_inner_join(
)?);
return Ok(LogicalPlan::Join(Join {
- left: Arc::new(left_input.clone()),
+ left: Arc::new(left_input),
right: Arc::new(right_input),
join_type: JoinType::Inner,
join_constraint: JoinConstraint::On,
@@ -259,7 +345,7 @@ fn find_inner_join(
)?);
Ok(LogicalPlan::CrossJoin(CrossJoin {
- left: Arc::new(left_input.clone()),
+ left: Arc::new(left_input),
right: Arc::new(right),
schema: join_schema,
}))
@@ -341,12 +427,12 @@ mod tests {
Operator::{And, Or},
};
- fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: Vec<&str>) {
+ fn assert_optimized_plan_eq(plan: LogicalPlan, expected: Vec<&str>) {
+ let starting_schema = plan.schema().clone();
let rule = EliminateCrossJoin::new();
- let optimized_plan = rule
- .try_optimize(plan, &OptimizerContext::new())
- .unwrap()
- .expect("failed to optimize plan");
+ let transformed_plan = rule.rewrite(plan,
&OptimizerContext::new()).unwrap();
+ assert!(transformed_plan.transformed, "failed to optimize plan");
+ let optimized_plan = transformed_plan.data;
let formatted = optimized_plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -355,13 +441,13 @@ mod tests {
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);
- assert_eq!(plan.schema(), optimized_plan.schema())
+ assert_eq!(&starting_schema, optimized_plan.schema())
}
- fn assert_optimization_rule_fails(plan: &LogicalPlan) {
+ fn assert_optimization_rule_fails(plan: LogicalPlan) {
let rule = EliminateCrossJoin::new();
- let optimized_plan = rule.try_optimize(plan,
&OptimizerContext::new()).unwrap();
- assert!(optimized_plan.is_none());
+ let transformed_plan = rule.rewrite(plan,
&OptimizerContext::new()).unwrap();
+ assert!(!transformed_plan.transformed)
}
#[test]
@@ -386,7 +472,7 @@ mod tests {
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -414,7 +500,7 @@ mod tests {
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -441,7 +527,7 @@ mod tests {
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -471,7 +557,7 @@ mod tests {
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -501,7 +587,7 @@ mod tests {
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -527,7 +613,7 @@ mod tests {
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -551,7 +637,7 @@ mod tests {
.filter(col("t1.a").gt(lit(15u32)))?
.build()?;
- assert_optimization_rule_fails(&plan);
+ assert_optimization_rule_fails(plan);
Ok(())
}
@@ -598,7 +684,7 @@ mod tests {
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -675,7 +761,7 @@ mod tests {
" TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -750,7 +836,7 @@ mod tests {
" TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -825,7 +911,7 @@ mod tests {
" TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -904,7 +990,7 @@ mod tests {
" TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -987,7 +1073,7 @@ mod tests {
" TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -1074,7 +1160,7 @@ mod tests {
" TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -1100,7 +1186,7 @@ mod tests {
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -1128,7 +1214,7 @@ mod tests {
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -1156,7 +1242,7 @@ mod tests {
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -1184,7 +1270,7 @@ mod tests {
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
@@ -1224,7 +1310,7 @@ mod tests {
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
- assert_optimized_plan_eq(&plan, expected);
+ assert_optimized_plan_eq(plan, expected);
Ok(())
}
diff --git a/datafusion/optimizer/src/join_key_set.rs
b/datafusion/optimizer/src/join_key_set.rs
index c47afa012c..cd8ed382f0 100644
--- a/datafusion/optimizer/src/join_key_set.rs
+++ b/datafusion/optimizer/src/join_key_set.rs
@@ -66,20 +66,46 @@ impl JoinKeySet {
}
}
+ /// Same as [`Self::insert`] but avoids cloning expressions if they
+ /// are owned
+ pub fn insert_owned(&mut self, left: Expr, right: Expr) -> bool {
+ if self.contains(&left, &right) {
+ false
+ } else {
+ self.inner.insert((left, right));
+ true
+ }
+ }
+
/// Inserts potentially many join keys into the set, copying only when
necessary
///
/// returns true if any of the pairs were inserted
pub fn insert_all<'a>(
&mut self,
- iter: impl Iterator<Item = &'a (Expr, Expr)>,
+ iter: impl IntoIterator<Item = &'a (Expr, Expr)>,
) -> bool {
let mut inserted = false;
- for (left, right) in iter {
+ for (left, right) in iter.into_iter() {
inserted |= self.insert(left, right);
}
inserted
}
+ /// Same as [`Self::insert_all`] but avoids cloning expressions if they are
+ /// already owned
+ ///
+ /// returns true if any of the pairs were inserted
+ pub fn insert_all_owned(
+ &mut self,
+ iter: impl IntoIterator<Item = (Expr, Expr)>,
+ ) -> bool {
+ let mut inserted = false;
+ for (left, right) in iter.into_iter() {
+ inserted |= self.insert_owned(left, right);
+ }
+ inserted
+ }
+
/// Inserts any join keys that are common to both `s1` and `s2` into self
pub fn insert_intersection(&mut self, s1: JoinKeySet, s2: JoinKeySet) {
// note can't use inner.intersection as we need to consider both (l, r)
@@ -156,6 +182,15 @@ mod test {
assert_eq!(set.len(), 2);
}
+ #[test]
+ fn test_insert_owned() {
+ let mut set = JoinKeySet::new();
+ assert!(set.insert_owned(col("a"), col("b")));
+ assert!(set.contains(&col("a"), &col("b")));
+ assert!(set.contains(&col("b"), &col("a")));
+ assert!(!set.contains(&col("a"), &col("c")));
+ }
+
#[test]
fn test_contains() {
let mut set = JoinKeySet::new();
@@ -217,18 +252,34 @@ mod test {
}
#[test]
- fn test_insert_many() {
+ fn test_insert_all() {
let mut set = JoinKeySet::new();
// insert (a=b), (b=c), (b=a)
- set.insert_all(
- vec![
- &(col("a"), col("b")),
- &(col("b"), col("c")),
- &(col("b"), col("a")),
- ]
- .into_iter(),
- );
+ set.insert_all(vec![
+ &(col("a"), col("b")),
+ &(col("b"), col("c")),
+ &(col("b"), col("a")),
+ ]);
+ assert_eq!(set.len(), 2);
+ assert!(set.contains(&col("a"), &col("b")));
+ assert!(set.contains(&col("b"), &col("c")));
+ assert!(set.contains(&col("b"), &col("a")));
+
+ // should not contain (a=c)
+ assert!(!set.contains(&col("a"), &col("c")));
+ }
+
+ #[test]
+ fn test_insert_all_owned() {
+ let mut set = JoinKeySet::new();
+
+ // insert (a=b), (b=c), (b=a)
+ set.insert_all_owned(vec![
+ (col("a"), col("b")),
+ (col("b"), col("c")),
+ (col("b"), col("a")),
+ ]);
assert_eq!(set.len(), 2);
assert!(set.contains(&col("a"), &col("b")));
assert!(set.contains(&col("b"), &col("c")));
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]