peter-toth commented on code in PR #11683: URL: https://github.com/apache/datafusion/pull/11683#discussion_r1693953169
########## datafusion/optimizer/src/common_subexpr_eliminate.rs: ########## @@ -454,136 +435,169 @@ impl CommonSubexprEliminate { group_expr, aggr_expr, input, - schema: orig_schema, + schema, .. } = aggregate; - // track transformed information - let mut transformed = false; - - let name_perserver = NamePreserver::new_for_projection(); - let saved_names = aggr_expr - .iter() - .map(|expr| name_perserver.save(expr)) - .collect::<Result<Vec<_>>>()?; - - let mut expr_stats = ExprStats::new(); - // rewrite inputs - let (group_found_common, group_arrays) = - self.to_arrays(&group_expr, &mut expr_stats, ExprMask::Normal)?; - let (aggr_found_common, aggr_arrays) = - self.to_arrays(&aggr_expr, &mut expr_stats, ExprMask::Normal)?; - let (new_aggr_expr, new_group_expr, new_input) = - if group_found_common || aggr_found_common { - // rewrite both group exprs and aggr_expr - let rewritten = self.rewrite_expr( - // Must clone as Identifiers use references to original expressions so - // we have to keep the original expressions intact. - vec![group_expr.clone(), aggr_expr.clone()], - vec![group_arrays, aggr_arrays], - unwrap_arc(input), - &expr_stats, - config, - )?; - assert!(rewritten.transformed); - transformed |= rewritten.transformed; - let (mut new_expr, new_input) = rewritten.data; - - // note the reversed pop order. - let new_aggr_expr = pop_expr(&mut new_expr)?; - let new_group_expr = pop_expr(&mut new_expr)?; - - (new_aggr_expr, new_group_expr, Arc::new(new_input)) - } else { - (aggr_expr, group_expr, input) - }; - - // create potential projection on top - let mut expr_stats = ExprStats::new(); - let (aggr_found_common, aggr_arrays) = self.to_arrays( - &new_aggr_expr, - &mut expr_stats, - ExprMask::NormalAndAggregates, - )?; - if aggr_found_common { - let mut common_exprs = CommonExprs::new(); - let mut rewritten_exprs = self.rewrite_exprs_list( - // Must clone as Identifiers use references to original expressions so we - // have to keep the original expressions intact. - vec![new_aggr_expr.clone()], - vec![aggr_arrays], - &expr_stats, - &mut common_exprs, - &config.alias_generator(), - )?; - assert!(rewritten_exprs.transformed); - let rewritten = pop_expr(&mut rewritten_exprs.data)?; - - assert!(!common_exprs.is_empty()); - let mut agg_exprs = common_exprs - .into_values() - .map(|(expr, expr_alias)| expr.alias(expr_alias)) - .collect::<Vec<_>>(); - - let new_input_schema = Arc::clone(new_input.schema()); - let mut proj_exprs = vec![]; - for expr in &new_group_expr { - extract_expressions(expr, &new_input_schema, &mut proj_exprs)? - } - for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) { - if expr_rewritten == expr_orig { - if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten { - agg_exprs.push(expr.alias(&name)); - proj_exprs.push(Expr::Column(Column::from_name(name))); - } else { - let expr_alias = config.alias_generator().next(CSE_PREFIX); - let (qualifier, field) = - expr_rewritten.to_field(&new_input_schema)?; - let out_name = qualified_name(qualifier.as_ref(), field.name()); - - agg_exprs.push(expr_rewritten.alias(&expr_alias)); - proj_exprs.push( - Expr::Column(Column::from_name(expr_alias)).alias(out_name), - ); + let input = unwrap_arc(input); + // Extract common sub-expressions from the aggregate and grouping expressions. + self.find_common_exprs(vec![group_expr, aggr_expr], config, ExprMask::Normal)? + .map_data(|(mut new_expr_list, common)| { + let new_aggr_expr = new_expr_list.pop().unwrap(); + let new_group_expr = new_expr_list.pop().unwrap(); + + match common { + // If there are common sub-expressions, then insert a projection node + // with the common expressions between the new aggregate node and the + // original input. + Some((common_exprs, mut expr_list)) => { + build_common_expr_project_plan(input, common_exprs).map( + |new_input| { + let aggr_expr = expr_list.pop().unwrap(); + + ( + new_aggr_expr, + new_group_expr, + new_input, + Some(aggr_expr), + ) + }, + ) } - } else { - proj_exprs.push(expr_rewritten); - } - } - let agg = LogicalPlan::Aggregate(Aggregate::try_new( - new_input, - new_group_expr, - agg_exprs, - )?); - - Projection::try_new(proj_exprs, Arc::new(agg)) - .map(LogicalPlan::Projection) - .map(Transformed::yes) - } else { - // TODO: How exactly can the name or the schema change in this case? - // In theory `new_aggr_expr` and `new_group_expr` are either the original expressions or they were crafted via `rewrite_expr()`, that keeps the original expression names. - // If this is really needed can we have UT for it? - // Alias aggregation expressions if they have changed - let new_aggr_expr = new_aggr_expr - .into_iter() - .zip(saved_names.into_iter()) - .map(|(new_expr, saved_name)| saved_name.restore(new_expr)) - .collect::<Result<Vec<Expr>>>()?; - // Since group_expr may have changed, schema may also. Use try_new method. - let new_agg = if transformed { - Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)? - } else { - Aggregate::try_new_with_schema( - new_input, - new_group_expr, - new_aggr_expr, - orig_schema, + None => Ok((new_aggr_expr, new_group_expr, input, None)), + } + })? + // Recurse into the new input. this is similar to top-down optimizer rule's + // logic. + .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| { + self.rewrite(new_input, config)?.map_data(|new_input| { + Ok(( + new_aggr_expr, + new_group_expr, + aggr_expr, + Arc::new(new_input), + )) + }) + })? + // Try extracting common aggregate expressions and rebuild the aggregate node. + .transform_data(|(new_aggr_expr, new_group_expr, aggr_expr, new_input)| { + // Extract common aggregate sub-expressions from the aggregate expressions. + self.find_common_exprs( + vec![new_aggr_expr], + config, + ExprMask::NormalAndAggregates, )? - }; - let new_agg = LogicalPlan::Aggregate(new_agg); - - Ok(Transformed::new_transformed(new_agg, transformed)) - } + .map_data(|(mut new_aggr_list, common)| { + let rewritten_aggr_expr = new_aggr_list.pop().unwrap(); + + match common { + // If there are common aggregate sub-expressions, then insert a + // projection above the new rebuilt aggregate node. + Some((common_aggr_exprs, mut aggr_list)) => { + let new_aggr_expr = aggr_list.pop().unwrap(); + + let mut agg_exprs = common_aggr_exprs Review Comment: This part is basically the same as it was: https://github.com/apache/datafusion/pull/11683/files#diff-351499880963d6a383c92e156e75019cd9ce33107724a9635853d7d4cd1898d0L522 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org