peter-toth commented on code in PR #11683:
URL: https://github.com/apache/datafusion/pull/11683#discussion_r1693953169


##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -454,136 +435,169 @@ impl CommonSubexprEliminate {
             group_expr,
             aggr_expr,
             input,
-            schema: orig_schema,
+            schema,
             ..
         } = aggregate;
-        // track transformed information
-        let mut transformed = false;
-
-        let name_perserver = NamePreserver::new_for_projection();
-        let saved_names = aggr_expr
-            .iter()
-            .map(|expr| name_perserver.save(expr))
-            .collect::<Result<Vec<_>>>()?;
-
-        let mut expr_stats = ExprStats::new();
-        // rewrite inputs
-        let (group_found_common, group_arrays) =
-            self.to_arrays(&group_expr, &mut expr_stats, ExprMask::Normal)?;
-        let (aggr_found_common, aggr_arrays) =
-            self.to_arrays(&aggr_expr, &mut expr_stats, ExprMask::Normal)?;
-        let (new_aggr_expr, new_group_expr, new_input) =
-            if group_found_common || aggr_found_common {
-                // rewrite both group exprs and aggr_expr
-                let rewritten = self.rewrite_expr(
-                    // Must clone as Identifiers use references to original 
expressions so
-                    // we have to keep the original expressions intact.
-                    vec![group_expr.clone(), aggr_expr.clone()],
-                    vec![group_arrays, aggr_arrays],
-                    unwrap_arc(input),
-                    &expr_stats,
-                    config,
-                )?;
-                assert!(rewritten.transformed);
-                transformed |= rewritten.transformed;
-                let (mut new_expr, new_input) = rewritten.data;
-
-                // note the reversed pop order.
-                let new_aggr_expr = pop_expr(&mut new_expr)?;
-                let new_group_expr = pop_expr(&mut new_expr)?;
-
-                (new_aggr_expr, new_group_expr, Arc::new(new_input))
-            } else {
-                (aggr_expr, group_expr, input)
-            };
-
-        // create potential projection on top
-        let mut expr_stats = ExprStats::new();
-        let (aggr_found_common, aggr_arrays) = self.to_arrays(
-            &new_aggr_expr,
-            &mut expr_stats,
-            ExprMask::NormalAndAggregates,
-        )?;
-        if aggr_found_common {
-            let mut common_exprs = CommonExprs::new();
-            let mut rewritten_exprs = self.rewrite_exprs_list(
-                // Must clone as Identifiers use references to original 
expressions so we
-                // have to keep the original expressions intact.
-                vec![new_aggr_expr.clone()],
-                vec![aggr_arrays],
-                &expr_stats,
-                &mut common_exprs,
-                &config.alias_generator(),
-            )?;
-            assert!(rewritten_exprs.transformed);
-            let rewritten = pop_expr(&mut rewritten_exprs.data)?;
-
-            assert!(!common_exprs.is_empty());
-            let mut agg_exprs = common_exprs
-                .into_values()
-                .map(|(expr, expr_alias)| expr.alias(expr_alias))
-                .collect::<Vec<_>>();
-
-            let new_input_schema = Arc::clone(new_input.schema());
-            let mut proj_exprs = vec![];
-            for expr in &new_group_expr {
-                extract_expressions(expr, &new_input_schema, &mut proj_exprs)?
-            }
-            for (expr_rewritten, expr_orig) in 
rewritten.into_iter().zip(new_aggr_expr) {
-                if expr_rewritten == expr_orig {
-                    if let Expr::Alias(Alias { expr, name, .. }) = 
expr_rewritten {
-                        agg_exprs.push(expr.alias(&name));
-                        proj_exprs.push(Expr::Column(Column::from_name(name)));
-                    } else {
-                        let expr_alias = 
config.alias_generator().next(CSE_PREFIX);
-                        let (qualifier, field) =
-                            expr_rewritten.to_field(&new_input_schema)?;
-                        let out_name = qualified_name(qualifier.as_ref(), 
field.name());
-
-                        agg_exprs.push(expr_rewritten.alias(&expr_alias));
-                        proj_exprs.push(
-                            
Expr::Column(Column::from_name(expr_alias)).alias(out_name),
-                        );
+        let input = unwrap_arc(input);
+        // Extract common sub-expressions from the aggregate and grouping 
expressions.
+        self.find_common_exprs(vec![group_expr, aggr_expr], config, 
ExprMask::Normal)?
+            .map_data(|(mut new_expr_list, common)| {
+                let new_aggr_expr = new_expr_list.pop().unwrap();
+                let new_group_expr = new_expr_list.pop().unwrap();
+
+                match common {
+                    // If there are common sub-expressions, then insert a 
projection node
+                    // with the common expressions between the new aggregate 
node and the
+                    // original input.
+                    Some((common_exprs, mut expr_list)) => {
+                        build_common_expr_project_plan(input, 
common_exprs).map(
+                            |new_input| {
+                                let aggr_expr = expr_list.pop().unwrap();
+
+                                (
+                                    new_aggr_expr,
+                                    new_group_expr,
+                                    new_input,
+                                    Some(aggr_expr),
+                                )
+                            },
+                        )
                     }
-                } else {
-                    proj_exprs.push(expr_rewritten);
-                }
-            }
 
-            let agg = LogicalPlan::Aggregate(Aggregate::try_new(
-                new_input,
-                new_group_expr,
-                agg_exprs,
-            )?);
-
-            Projection::try_new(proj_exprs, Arc::new(agg))
-                .map(LogicalPlan::Projection)
-                .map(Transformed::yes)
-        } else {
-            // TODO: How exactly can the name or the schema change in this 
case?
-            //  In theory `new_aggr_expr` and `new_group_expr` are either the 
original expressions or they were crafted via `rewrite_expr()`, that keeps the 
original expression names.
-            //  If this is really needed can we have UT for it?
-            // Alias aggregation expressions if they have changed
-            let new_aggr_expr = new_aggr_expr
-                .into_iter()
-                .zip(saved_names.into_iter())
-                .map(|(new_expr, saved_name)| saved_name.restore(new_expr))
-                .collect::<Result<Vec<Expr>>>()?;
-            // Since group_expr may have changed, schema may also. Use try_new 
method.
-            let new_agg = if transformed {
-                Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)?
-            } else {
-                Aggregate::try_new_with_schema(
-                    new_input,
-                    new_group_expr,
-                    new_aggr_expr,
-                    orig_schema,
+                    None => Ok((new_aggr_expr, new_group_expr, input, None)),
+                }
+            })?
+            // Recurse into the new input. this is similar to top-down 
optimizer rule's
+            // logic.
+            .transform_data(|(new_aggr_expr, new_group_expr, new_input, 
aggr_expr)| {
+                self.rewrite(new_input, config)?.map_data(|new_input| {
+                    Ok((
+                        new_aggr_expr,
+                        new_group_expr,
+                        aggr_expr,
+                        Arc::new(new_input),
+                    ))
+                })
+            })?
+            // Try extracting common aggregate expressions and rebuild the 
aggregate node.
+            .transform_data(|(new_aggr_expr, new_group_expr, aggr_expr, 
new_input)| {
+                // Extract common aggregate sub-expressions from the aggregate 
expressions.
+                self.find_common_exprs(
+                    vec![new_aggr_expr],
+                    config,
+                    ExprMask::NormalAndAggregates,
                 )?
-            };
-            let new_agg = LogicalPlan::Aggregate(new_agg);
-
-            Ok(Transformed::new_transformed(new_agg, transformed))
-        }
+                .map_data(|(mut new_aggr_list, common)| {
+                    let rewritten_aggr_expr = new_aggr_list.pop().unwrap();
+
+                    match common {
+                        // If there are common aggregate sub-expressions, then 
insert a
+                        // projection above the new rebuilt aggregate node.
+                        Some((common_aggr_exprs, mut aggr_list)) => {
+                            let new_aggr_expr = aggr_list.pop().unwrap();
+
+                            let mut agg_exprs = common_aggr_exprs

Review Comment:
   This part is essentially unchanged from the previous implementation; compare with the removed code here: 
https://github.com/apache/datafusion/pull/11683/files#diff-351499880963d6a383c92e156e75019cd9ce33107724a9635853d7d4cd1898d0L522



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to