jonahgao commented on code in PR #15876: URL: https://github.com/apache/datafusion/pull/15876#discussion_r2075625947
########## datafusion/expr/src/logical_plan/builder.rs: ########## @@ -797,26 +807,146 @@ impl LogicalPlanBuilder { } // remove pushed down sort columns - let new_expr = schema.columns().into_iter().map(Expr::Column).collect(); + let sort_output = schema.columns().into_iter().map(Expr::Column).collect(); + + let plan = Arc::unwrap_or_clone(self.plan); + + let (plan, agg_func_to_col, sorts) = if missing_agg_funcs.is_empty() { + (plan, HashMap::new(), sorts) + } else { + { + let (plan, agg_func_to_col) = + Self::add_missing_agg_funcs_to_logical_agg(plan, &missing_agg_funcs)?; + + let sorts = sorts + .iter() + .map(|x| { + Self::replace_subexpr_to_col(&x.expr, &agg_func_to_col) + .map(|expr| x.with_expr(expr)) + }) + .collect::<Result<Vec<_>, _>>()?; + + (plan, agg_func_to_col, sorts) + } + }; + + // we need downstream filter/project return missing col(agg_funcs) + missing_cols.extend(agg_func_to_col.into_values()); let is_distinct = false; - let plan = Self::add_missing_columns( - Arc::unwrap_or_clone(self.plan), - &missing_cols, - is_distinct, - )?; + let plan = Self::add_missing_columns(plan, &missing_cols, is_distinct)?; let sort_plan = LogicalPlan::Sort(Sort { expr: normalize_sorts(sorts, &plan)?, input: Arc::new(plan), fetch, }); - Projection::try_new(new_expr, Arc::new(sort_plan)) + Projection::try_new(sort_output, Arc::new(sort_plan)) .map(LogicalPlan::Projection) .map(Self::new) } + fn replace_subexpr_to_col( + expr: &Expr, + func_to_col: &HashMap<Expr, Column>, + ) -> Result<Expr> { + Ok(expr + .clone() + .transform_down(|nested_expr| { + if let Some(col) = func_to_col.get(&nested_expr) { + Ok(Transformed::yes(Expr::Column(col.clone()))) + } else { + Ok(Transformed::no(nested_expr)) + } + })? 
+ .data) + } + + fn add_missing_agg_funcs_to_logical_agg( + plan: LogicalPlan, + missing_agg_funcs: &IndexSet<Expr>, + ) -> Result<(LogicalPlan, HashMap<Expr, Column>)> { + let mut agg_func_to_output_col: HashMap<Expr, Column> = + HashMap::with_capacity(missing_agg_funcs.len()); + + if missing_agg_funcs.is_empty() { + return Ok((plan, agg_func_to_output_col)); + } + + let plan = plan + .transform_down(|plan| { Review Comment: We should only consider aggregates from the select list of the current query. Traversing the plan tree does not guarantee this: we might encounter aggregate functions that do not originate from the select list, which leads to an invalid logical plan. For example, the following query is illegal, but it is accepted with this PR. ```sh > create table t(a int); 0 row(s) fetched. Elapsed 0.007 seconds. > select * from t, (select max(a) from t) order by max(a); +---+----------+ | a | max(t.a) | +---+----------+ +---+----------+ 0 row(s) fetched. Elapsed 0.017 seconds. ``` There may be more complex cases with similar problems. I think a safer approach is to directly add the missing aggregate functions to the select list at an earlier stage. That is, transform `select min(a) from t order by max(a)` into `select min(a), max(a) from t order by max(a)`, similar to what was done in #14180. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org