comphead commented on code in PR #10431: URL: https://github.com/apache/datafusion/pull/10431#discussion_r1598634008
########## datafusion/optimizer/src/eliminate_cross_join.rs: ########## @@ -39,102 +41,147 @@ impl EliminateCrossJoin { } } -/// Attempt to reorder join to eliminate cross joins to inner joins. -/// for queries: -/// 'select ... from a, b where a.x = b.y and b.xx = 100;' -/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' -/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) -/// or (a.x = b.y and b.xx = 200 and a.z=c.z);' -/// 'select ... from a, b where a.x > b.y' +/// Eliminate cross joins by rewriting them to inner joins when possible. +/// +/// # Example +/// The initial plan for this query: +/// ```sql +/// select ... from a, b where a.x = b.y and b.xx = 100; +/// ``` +/// +/// Looks like this: +/// ```text +/// Filter(a.x = b.y AND b.xx = 100) +/// CrossJoin +/// TableScan a +/// TableScan b +/// ``` +/// +/// After the rule is applied, the plan will look like this: +/// ```text +/// Filter(b.xx = 100) +/// InnerJoin(a.x = b.y) +/// TableScan a +/// TableScan b +/// ``` +/// +/// # Other Examples +/// * 'select ... from a, b where a.x = b.y and b.xx = 100;' +/// * 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' +/// * 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) +/// * or (a.x = b.y and b.xx = 200 and a.z=c.z);' +/// * 'select ... from a, b where a.x > b.y' +/// /// For above queries, the join predicate is available in filters and they are moved to /// join nodes appropriately +/// /// This fix helps to improve the performance of TPCH Q19. 
issue#78 impl OptimizerRule for EliminateCrossJoin { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result<Option<LogicalPlan>> { + internal_err!("Should have called EliminateCrossJoin::rewrite") + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result<Transformed<LogicalPlan>> { + let plan_schema = plan.schema().clone(); let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec<LogicalPlan> = vec![]; - let parent_predicate = match plan { - LogicalPlan::Filter(filter) => { - let input = filter.input.as_ref(); - match input { - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) - | LogicalPlan::CrossJoin(_) => { - if !try_flatten_join_inputs( - input, - &mut possible_join_keys, - &mut all_inputs, - )? { - return Ok(None); - } - extract_possible_join_keys( - &filter.predicate, - &mut possible_join_keys, - ); - Some(&filter.predicate) - } - _ => { - return utils::optimize_children(self, plan, config); - } - } + + let can_flatten_inputs = can_flatten_join_inputs(&plan); + let parent_predicate = if let LogicalPlan::Filter(filter) = plan { + // if input isn't a join that can potentially be rewritten + // avoid unwrapping the input + let rewriteable = matches!( + filter.input.as_ref(), + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) | LogicalPlan::CrossJoin(_) + ); + + if !rewriteable { + return rewrite_children(self, LogicalPlan::Filter(filter), config); } + + if !can_flatten_join_inputs(&filter.input) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + + let Filter { + input, predicate, .. 
+                } = filter; + flatten_join_inputs( + unwrap_arc(input), + &mut possible_join_keys, + &mut all_inputs, + )?; + + extract_possible_join_keys(&predicate, &mut possible_join_keys); + Some(predicate) + } else if matches!( + plan, LogicalPlan::Join(Join { join_type: JoinType::Inner, .. - }) => { - if !try_flatten_join_inputs( - plan, - &mut possible_join_keys, - &mut all_inputs, - )? { - return Ok(None); - } - None + }) + ) { + if !can_flatten_inputs { + return Ok(Transformed::no(plan)); } - _ => return utils::optimize_children(self, plan, config), + flatten_join_inputs(plan, &mut possible_join_keys, &mut all_inputs)?; + None + } else { + return rewrite_children(self, plan, config); }; // Join keys are handled locally: let mut all_join_keys = JoinKeySet::new(); let mut left = all_inputs.remove(0); while !all_inputs.is_empty() { left = find_inner_join( - &left, + left, &mut all_inputs, &possible_join_keys, &mut all_join_keys, )?; } - left = utils::optimize_children(self, &left, config)?.unwrap_or(left); + left = rewrite_children(self, left, config)?.data; - if plan.schema() != left.schema() { + if &plan_schema != left.schema() { left = LogicalPlan::Projection(Projection::new_from_schema( Arc::new(left), - plan.schema().clone(), + plan_schema.clone(), )); } let Some(predicate) = parent_predicate else { - return Ok(Some(left)); + return Ok(Transformed::yes(left)); }; // If there are no join keys then do nothing: if all_join_keys.is_empty() { - Filter::try_new(predicate.clone(), Arc::new(left)) - .map(|f| Some(LogicalPlan::Filter(f))) + Filter::try_new(predicate, Arc::new(left)) + .map(LogicalPlan::Filter) Review Comment: can we do this in a single `map` call? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org