This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 256ea91b1 Allow sorting by aggregated groups (#3280)
256ea91b1 is described below

commit 256ea91b1c0864449f9b41520c808695aa00460b
Author: Batuhan Taskaya <[email protected]>
AuthorDate: Tue Aug 30 23:51:05 2022 +0300

    Allow sorting by aggregated groups (#3280)
    
    * Add the test for mix of order by/group by on a complex expr
    
    * Allow sorting by aggregated groups
    
    * Prevent duplicate sort expressions with mismatched aliases from being included
---
 datafusion/core/tests/sql/group_by.rs       | 72 +++++++++++++++++++++++++++++
 datafusion/expr/src/expr_rewriter.rs        | 16 +++++--
 datafusion/expr/src/logical_plan/builder.rs | 11 +++--
 3 files changed, 91 insertions(+), 8 deletions(-)

diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs
index e3da1b021..2e1007be8 100644
--- a/datafusion/core/tests/sql/group_by.rs
+++ b/datafusion/core/tests/sql/group_by.rs
@@ -681,3 +681,75 @@ async fn group_by_dictionary() {
     run_test_case::<UInt32Type>().await;
     run_test_case::<UInt64Type>().await;
 }
+
+#[tokio::test]
+async fn csv_query_group_by_order_by_substr() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT substr(c1, 1, 1), avg(c12) \
+        FROM aggregate_test_100 \
+        GROUP BY substr(c1, 1, 1) \
+        ORDER BY substr(c1, 1, 1)";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-------------------------------------------------+-----------------------------+",
+        "| substr(aggregate_test_100.c1,Int64(1),Int64(1)) | AVG(aggregate_test_100.c12) |",
+        "+-------------------------------------------------+-----------------------------+",
+        "| a                                               | 0.48754517466109415         |",
+        "| b                                               | 0.41040709263815384         |",
+        "| c                                               | 0.6600456536439784          |",
+        "| d                                               | 0.48855379387549824         |",
+        "| e                                               | 0.48600669271341534         |",
+        "+-------------------------------------------------+-----------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual); // exact row order: ORDER BY is under test
+    Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_group_by_order_by_substr_aliased_projection() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT substr(c1, 1, 1) as name, avg(c12) as average \
+        FROM aggregate_test_100 \
+        GROUP BY substr(c1, 1, 1) \
+        ORDER BY substr(c1, 1, 1)";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+------+---------------------+",
+        "| name | average             |",
+        "+------+---------------------+",
+        "| a    | 0.48754517466109415 |",
+        "| b    | 0.41040709263815384 |",
+        "| c    | 0.6600456536439784  |",
+        "| d    | 0.48855379387549824 |",
+        "| e    | 0.48600669271341534 |",
+        "+------+---------------------+",
+    ];
+    assert_batches_eq!(expected, &actual); // exact row order: ORDER BY is under test
+    Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT substr(c1, 1, 1) as name, avg(c12) as average \
+        FROM aggregate_test_100 \
+        GROUP BY substr(c1, 1, 1) \
+        ORDER BY avg(c12)";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+------+---------------------+",
+        "| name | average             |",
+        "+------+---------------------+",
+        "| b    | 0.41040709263815384 |",
+        "| e    | 0.48600669271341534 |",
+        "| a    | 0.48754517466109415 |",
+        "| d    | 0.48855379387549824 |",
+        "| c    | 0.6600456536439784  |",
+        "+------+---------------------+",
+    ];
+    assert_batches_eq!(expected, &actual); // exact row order: ORDER BY is under test
+    Ok(())
+}
diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs
index e8cf049dd..9e8fa8a7e 100644
--- a/datafusion/expr/src/expr_rewriter.rs
+++ b/datafusion/expr/src/expr_rewriter.rs
@@ -19,6 +19,7 @@
 
 use crate::expr::GroupingSet;
 use crate::logical_plan::Aggregate;
+use crate::utils::grouping_set_to_exprlist;
 use crate::{Expr, ExprSchemable, LogicalPlan};
 use datafusion_common::Result;
 use datafusion_common::{Column, DFSchema};
@@ -325,12 +326,16 @@ pub fn rewrite_sort_cols_by_aggs(
 fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
     match plan {
         LogicalPlan::Aggregate(Aggregate {
-            input, aggr_expr, ..
+            input,
+            aggr_expr,
+            group_expr,
+            ..
         }) => {
             struct Rewriter<'a> {
                 plan: &'a LogicalPlan,
                 input: &'a LogicalPlan,
                 aggr_expr: &'a Vec<Expr>,
+                distinct_group_exprs: &'a Vec<Expr>,
             }
 
             impl<'a> ExprRewriter for Rewriter<'a> {
@@ -341,8 +346,11 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
                         return Ok(expr);
                     }
                     let normalized_expr = normalized_expr.unwrap();
-                    if let Some(found_agg) =
-                        self.aggr_expr.iter().find(|a| (**a) == 
normalized_expr)
+                    if let Some(found_agg) = self
+                        .aggr_expr
+                        .iter()
+                        .chain(self.distinct_group_exprs)
+                        .find(|a| (**a) == normalized_expr)
                     {
                         let agg = normalize_col(found_agg.clone(), self.plan)?;
                         let col = Expr::Column(
@@ -356,10 +364,12 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
                 }
             }
 
+            let distinct_group_exprs = 
grouping_set_to_exprlist(group_expr.as_slice())?;
             expr.rewrite(&mut Rewriter {
                 plan,
                 input,
                 aggr_expr,
+                distinct_group_exprs: &distinct_group_exprs,
             })
         }
         LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, 
plan.inputs()[0]),
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 9eb379142..2946a74af 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -335,16 +335,17 @@ impl LogicalPlanBuilder {
                 .iter()
                 .all(|c| input.schema().field_from_column(c).is_ok()) =>
             {
-                let missing_exprs = missing_cols
+                let mut missing_exprs = missing_cols
                     .iter()
                     .map(|c| normalize_col(Expr::Column(c.clone()), &input))
                     .collect::<Result<Vec<_>>>()?;
 
+                // Do not let duplicate columns to be added, some of the
+                // missing_cols may be already present but without the new
+                // projected alias.
+                missing_exprs.retain(|e| !expr.contains(e));
                 expr.extend(missing_exprs);
-
-                Ok(LogicalPlan::Projection(Projection::try_new(
-                    expr, input, alias,
-                )?))
+                Ok(project_with_alias((*input).clone(), expr, alias)?)
             }
             _ => {
                 let new_inputs = curr_plan

Reply via email to