jonahgao commented on code in PR #15876:
URL: https://github.com/apache/datafusion/pull/15876#discussion_r2079814968

##########
datafusion/expr/src/logical_plan/builder.rs:
##########
@@ -797,26 +807,146 @@ impl LogicalPlanBuilder {
         }
 
         // remove pushed down sort columns
-        let new_expr = schema.columns().into_iter().map(Expr::Column).collect();
+        let sort_output = schema.columns().into_iter().map(Expr::Column).collect();
+
+        let plan = Arc::unwrap_or_clone(self.plan);
+
+        let (plan, agg_func_to_col, sorts) = if missing_agg_funcs.is_empty() {
+            (plan, HashMap::new(), sorts)
+        } else {
+            {
+                let (plan, agg_func_to_col) =
+                    Self::add_missing_agg_funcs_to_logical_agg(plan, &missing_agg_funcs)?;
+
+                let sorts = sorts
+                    .iter()
+                    .map(|x| {
+                        Self::replace_subexpr_to_col(&x.expr, &agg_func_to_col)
+                            .map(|expr| x.with_expr(expr))
+                    })
+                    .collect::<Result<Vec<_>, _>>()?;
+
+                (plan, agg_func_to_col, sorts)
+            }
+        };
+
+        // we need downstream filter/project return missing col(agg_funcs)
+        missing_cols.extend(agg_func_to_col.into_values());
 
         let is_distinct = false;
-        let plan = Self::add_missing_columns(
-            Arc::unwrap_or_clone(self.plan),
-            &missing_cols,
-            is_distinct,
-        )?;
+        let plan = Self::add_missing_columns(plan, &missing_cols, is_distinct)?;
 
         let sort_plan = LogicalPlan::Sort(Sort {
             expr: normalize_sorts(sorts, &plan)?,
             input: Arc::new(plan),
             fetch,
         });
 
-        Projection::try_new(new_expr, Arc::new(sort_plan))
+        Projection::try_new(sort_output, Arc::new(sort_plan))
             .map(LogicalPlan::Projection)
             .map(Self::new)
     }
 
+    fn replace_subexpr_to_col(
+        expr: &Expr,
+        func_to_col: &HashMap<Expr, Column>,
+    ) -> Result<Expr> {
+        Ok(expr
+            .clone()
+            .transform_down(|nested_expr| {
+                if let Some(col) = func_to_col.get(&nested_expr) {
+                    Ok(Transformed::yes(Expr::Column(col.clone())))
+                } else {
+                    Ok(Transformed::no(nested_expr))
+                }
+            })?
+            .data)
+    }
+
+    fn add_missing_agg_funcs_to_logical_agg(
+        plan: LogicalPlan,
+        missing_agg_funcs: &IndexSet<Expr>,
+    ) -> Result<(LogicalPlan, HashMap<Expr, Column>)> {
+        let mut agg_func_to_output_col: HashMap<Expr, Column> =
+            HashMap::with_capacity(missing_agg_funcs.len());
+
+        if missing_agg_funcs.is_empty() {
+            return Ok((plan, agg_func_to_output_col));
+        }
+
+        let plan = plan
+            .transform_down(|plan| {

Review Comment:
   We need to first check whether the current query has a GROUP BY or whether there are other aggregate functions in the select list. If so, we can add the missing aggregate functions. Otherwise, we should return an error.

   The current approach does not perform this check correctly, because it cannot tell whether a `LogicalPlan::Aggregate` node in the plan comes from the SELECT list or from somewhere else.

   I created [another example](https://github.com/jonahgao/datafusion/commit/aec8ea022d87915fb23c14b0b7cb39fce59e779a) using DataFrame to illustrate this problem. In this example, the missing `min(a)` is added to the right side of the join, which I find confusing.
   ```sh
   Projection: b, test.a [b:Utf8, a:Utf8]
     Sort: min(test.a) ASC NULLS LAST [b:Utf8, a:Utf8, min(test.a):Utf8;N]
       Inner Join: b = test.a [b:Utf8, a:Utf8, min(test.a):Utf8;N]
         Projection: test.a AS b [b:Utf8]
           TableScan: test projection=[a] [a:Utf8]
         Aggregate: groupBy=[[test.a]], aggr=[[min(test.a)]] [a:Utf8, min(test.a):Utf8;N]
           TableScan: test projection=[a] [a:Utf8]
   ```

   At an earlier stage, such as during select_to_plan, we are able to know the exact select list.

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

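A minimal sketch of the check proposed in the review comment above, assuming it runs at a stage where the exact select list is known. `ToyExpr`, `SelectInfo`, and `can_add_missing_aggregate` are hypothetical stand-ins invented for illustration; they are not DataFusion types or functions, and the real planner would work on `Expr` and `LogicalPlan` instead.

```rust
/// Toy expression type: a hypothetical stand-in for DataFusion's `Expr`.
#[derive(Debug, Clone, PartialEq)]
enum ToyExpr {
    /// A plain column reference such as `a`.
    Column(String),
    /// An aggregate call such as `min(a)`.
    Aggregate { func: String, arg: String },
}

/// What the planner is assumed to know about the query's own SELECT
/// (hypothetical structure, not part of DataFusion).
struct SelectInfo {
    select_list: Vec<ToyExpr>,
    has_group_by: bool,
}

/// Decide whether an aggregate that appears only in ORDER BY may be added
/// to the query's aggregation. The decision looks only at the query's own
/// GROUP BY and select list, never at Aggregate nodes found elsewhere in
/// the plan (for example under a join input).
fn can_add_missing_aggregate(info: &SelectInfo) -> Result<(), String> {
    let select_has_aggregate = info
        .select_list
        .iter()
        .any(|e| matches!(e, ToyExpr::Aggregate { .. }));

    if info.has_group_by || select_has_aggregate {
        Ok(())
    } else {
        Err("ORDER BY aggregate requires a GROUP BY or an aggregate in the select list"
            .to_string())
    }
}
```

With this shape, a query whose select list is only plain columns and has no GROUP BY is rejected, even if some unrelated `Aggregate` node exists deeper in the plan.
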
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org