This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 26a8000fe2 Fix group by aliased expression in LogicalPLanBuilder::aggregate (#8629)
26a8000fe2 is described below
commit 26a8000fe2343e6a187dcd6e4e8fc037d55e213f
Author: Andrew Lamb <[email protected]>
AuthorDate: Tue Dec 26 07:04:43 2023 -0500
Fix group by aliased expression in LogicalPLanBuilder::aggregate (#8629)
---
datafusion/core/src/dataframe/mod.rs | 36 +++++++++++++++++-
datafusion/expr/src/logical_plan/builder.rs | 58 +++++++++++++++++++----------
2 files changed, 73 insertions(+), 21 deletions(-)
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 2ae4a7c21a..3c3bcd497b 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -1769,8 +1769,8 @@ mod tests {
let df_results = df.collect().await?;
#[rustfmt::skip]
- assert_batches_sorted_eq!(
- [ "+----+",
+ assert_batches_sorted_eq!([
+ "+----+",
"| id |",
"+----+",
"| 1 |",
@@ -1781,6 +1781,38 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn test_aggregate_alias() -> Result<()> {
+ let df = test_table().await?;
+
+ let df = df
+ // GROUP BY `c2 + 1`
+ .aggregate(vec![col("c2") + lit(1)], vec![])?
+ // SELECT `c2 + 1` as c2
+ .select(vec![(col("c2") + lit(1)).alias("c2")])?
+ // GROUP BY c2 as "c2" (alias in expr is not supported by SQL)
+ .aggregate(vec![col("c2").alias("c2")], vec![])?;
+
+ let df_results = df.collect().await?;
+
+ #[rustfmt::skip]
+ assert_batches_sorted_eq!([
+ "+----+",
+ "| c2 |",
+ "+----+",
+ "| 2 |",
+ "| 3 |",
+ "| 4 |",
+ "| 5 |",
+ "| 6 |",
+ "+----+",
+ ],
+ &df_results
+ );
+
+ Ok(())
+ }
+
#[tokio::test]
async fn test_distinct() -> Result<()> {
let t = test_table().await?;
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 88310dab82..549c25f89b 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -904,27 +904,11 @@ impl LogicalPlanBuilder {
group_expr: impl IntoIterator<Item = impl Into<Expr>>,
aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
) -> Result<Self> {
- let mut group_expr = normalize_cols(group_expr, &self.plan)?;
+ let group_expr = normalize_cols(group_expr, &self.plan)?;
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
- // Rewrite groupby exprs according to functional dependencies
- let group_by_expr_names = group_expr
- .iter()
- .map(|group_by_expr| group_by_expr.display_name())
- .collect::<Result<Vec<_>>>()?;
- let schema = self.plan.schema();
- if let Some(target_indices) =
- get_target_functional_dependencies(schema, &group_by_expr_names)
- {
- for idx in target_indices {
- let field = schema.field(idx);
- let expr =
- Expr::Column(Column::new(field.qualifier().cloned(),
field.name()));
- if !group_expr.contains(&expr) {
- group_expr.push(expr);
- }
- }
- }
+ let group_expr =
+ add_group_by_exprs_from_dependencies(group_expr,
self.plan.schema())?;
Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr)
.map(LogicalPlan::Aggregate)
.map(Self::from)
@@ -1189,6 +1173,42 @@ pub fn build_join_schema(
schema.with_functional_dependencies(func_dependencies)
}
+/// Add additional "synthetic" group by expressions based on functional
+/// dependencies.
+///
+/// For example, if we are grouping on `[c1]`, and we know from
+/// functional dependencies that column `c1` determines `c2`, this function
+/// adds `c2` to the group by list.
+///
+/// This allows MySQL style selects like
+/// `SELECT col FROM t WHERE pk = 5` if col is unique
+fn add_group_by_exprs_from_dependencies(
+ mut group_expr: Vec<Expr>,
+ schema: &DFSchemaRef,
+) -> Result<Vec<Expr>> {
+ // Names of the fields produced by the GROUP BY exprs for example, `GROUP
BY
+ // c1 + 1` produces an output field named `"c1 + 1"`
+ let mut group_by_field_names = group_expr
+ .iter()
+ .map(|e| e.display_name())
+ .collect::<Result<Vec<_>>>()?;
+
+ if let Some(target_indices) =
+ get_target_functional_dependencies(schema, &group_by_field_names)
+ {
+ for idx in target_indices {
+ let field = schema.field(idx);
+ let expr =
+ Expr::Column(Column::new(field.qualifier().cloned(),
field.name()));
+ let expr_name = expr.display_name()?;
+ if !group_by_field_names.contains(&expr_name) {
+ group_by_field_names.push(expr_name);
+ group_expr.push(expr);
+ }
+ }
+ }
+ Ok(group_expr)
+}
/// Errors if one or more expressions have equal names.
pub(crate) fn validate_unique_names<'a>(
node_name: &str,