alamb commented on code in PR #5067:
URL: https://github.com/apache/arrow-datafusion/pull/5067#discussion_r1089305897


##########
datafusion/sql/tests/integration_test.rs:
##########
@@ -36,6 +36,12 @@ use datafusion_sql::planner::{ContextProvider, 
ParserOptions, SqlToRel};
 
 use rstest::rstest;
 
+#[cfg(test)]

Review Comment:
   this is to enable debug logging in the sql tests



##########
datafusion/expr/src/expr_rewriter/order_by.rs:
##########
@@ -55,55 +55,122 @@ pub fn rewrite_sort_cols_by_aggs(
 
 fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
     match plan {
-        LogicalPlan::Aggregate(Aggregate {
-            input,
-            aggr_expr,
-            group_expr,
-            ..
-        }) => {
-            struct Rewriter<'a> {
-                plan: &'a LogicalPlan,
-                input: &'a LogicalPlan,
-                aggr_expr: &'a Vec<Expr>,
-                distinct_group_exprs: &'a Vec<Expr>,
-            }
+        LogicalPlan::Aggregate(aggregate) => {
+            rewrite_in_terms_of_aggregate(expr, plan, aggregate)
+        }
+        LogicalPlan::Projection(projection) => {
+            rewrite_in_terms_of_projection(expr, projection)
+        }
+        _ => Ok(expr),
+    }
+}
 
-            impl<'a> ExprRewriter for Rewriter<'a> {
-                fn mutate(&mut self, expr: Expr) -> Result<Expr> {
-                    let normalized_expr = normalize_col(expr.clone(), 
self.plan);
-                    if normalized_expr.is_err() {
-                        // The expr is not based on Aggregate plan output. 
Skip it.
-                        return Ok(expr);
-                    }
-                    let normalized_expr = normalized_expr?;
-                    if let Some(found_agg) = self
-                        .aggr_expr
-                        .iter()
-                        .chain(self.distinct_group_exprs)
-                        .find(|a| (**a) == normalized_expr)
-                    {
-                        let agg = normalize_col(found_agg.clone(), self.plan)?;
-                        let col = Expr::Column(
-                            agg.to_field(self.input.schema())
-                                .map(|f| f.qualified_column())?,
-                        );
-                        Ok(col)
-                    } else {
-                        Ok(expr)
-                    }
-                }
-            }
+/// rewrites a sort expression in terms of the output of an [`Aggregate`].
+///
+/// Note: the SQL planner always puts a `Projection` at the output of
+/// an aggregate, but other paths such as `LogicalPlanBuilder` can
+/// create a Sort directly above an Aggregate.
+fn rewrite_in_terms_of_aggregate(

Review Comment:
   this is the same logic, just extracted into a function and using 
`rewrite_expr` to make it easier to read



##########
datafusion/expr/src/expr_rewriter/order_by.rs:
##########
@@ -55,55 +55,121 @@ pub fn rewrite_sort_cols_by_aggs(
 
 fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
     match plan {
-        LogicalPlan::Aggregate(Aggregate {
-            input,
-            aggr_expr,
-            group_expr,
-            ..
-        }) => {
-            struct Rewriter<'a> {
-                plan: &'a LogicalPlan,
-                input: &'a LogicalPlan,
-                aggr_expr: &'a Vec<Expr>,
-                distinct_group_exprs: &'a Vec<Expr>,
-            }
+        LogicalPlan::Aggregate(aggregate) => {
+            rewrite_in_terms_of_aggregate(expr, plan, aggregate)
+        }
+        LogicalPlan::Projection(projection) => {
+            rewrite_in_terms_of_projection(expr, projection)
+        }
+        _ => Ok(expr),
+    }
+}
 
-            impl<'a> ExprRewriter for Rewriter<'a> {
-                fn mutate(&mut self, expr: Expr) -> Result<Expr> {
-                    let normalized_expr = normalize_col(expr.clone(), 
self.plan);
-                    if normalized_expr.is_err() {
-                        // The expr is not based on Aggregate plan output. 
Skip it.
-                        return Ok(expr);
-                    }
-                    let normalized_expr = normalized_expr?;
-                    if let Some(found_agg) = self
-                        .aggr_expr
-                        .iter()
-                        .chain(self.distinct_group_exprs)
-                        .find(|a| (**a) == normalized_expr)
-                    {
-                        let agg = normalize_col(found_agg.clone(), self.plan)?;
-                        let col = Expr::Column(
-                            agg.to_field(self.input.schema())
-                                .map(|f| f.qualified_column())?,
-                        );
-                        Ok(col)
-                    } else {
-                        Ok(expr)
-                    }
-                }
-            }
+/// rewrites a sort expression in terms of the output of an [`Aggregate`].
+///
+///
+fn rewrite_in_terms_of_aggregate(
+    expr: Expr,
+    // the LogicalPlan::Aggregate
+    plan: &LogicalPlan,
+    aggregate: &Aggregate,
+) -> Result<Expr> {
+    let Aggregate {
+        input,
+        aggr_expr,
+        group_expr,
+        ..
+    } = aggregate;
 
-            let distinct_group_exprs = 
grouping_set_to_exprlist(group_expr.as_slice())?;
-            expr.rewrite(&mut Rewriter {
-                plan,
-                input,
-                aggr_expr,
-                distinct_group_exprs: &distinct_group_exprs,
-            })
+    let distinct_group_exprs = 
grouping_set_to_exprlist(group_expr.as_slice())?;
+
+    rewrite_expr(expr, |expr| {
+        // normalize in terms of the input plan
+        let normalized_expr = normalize_col(expr.clone(), plan);
+        if normalized_expr.is_err() {
+            // The expr is not based on Aggregate plan output. Skip it.
+            return Ok(expr);
         }
-        LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, 
plan.inputs()[0]),
-        _ => Ok(expr),
+        let normalized_expr = normalized_expr?;
+        if let Some(found_agg) = aggr_expr
+            .iter()
+            .chain(distinct_group_exprs.iter())
+            .find(|a| (**a) == normalized_expr)
+        {
+            let agg = normalize_col(found_agg.clone(), plan)?;
+            let col =
+                Expr::Column(agg.to_field(input.schema()).map(|f| 
f.qualified_column())?);
+            Ok(col)
+        } else {
+            Ok(expr)
+        }
+    })
+}
+
+/// Rewrites a sort expression in terms of the output of a [`Projection`].
+/// For example, will rewrite an input expression such as
+/// `a + b + c` into `col(a) + col("b + c")`
+///
+/// Remember that:
+/// 1. given a projection with exprs: [a, b + c]
+/// 2. it produces an output schema with two columns "a", "b + c"
+fn rewrite_in_terms_of_projection(expr: Expr, projection: &Projection) -> 
Result<Expr> {

Review Comment:
   this is the fix -- it has slightly different semantics than the raw 
aggregate case.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to