This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 7607ace  Fix ORDER BY on aggregate (#1506)
7607ace is described below

commit 7607ace992a5a42840bf546221a8635e70e10885
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Jan 1 04:04:58 2022 -0800

    Fix ORDER BY on aggregate (#1506)
    
    * Fix sort on aggregate
    
    * Use ExprRewriter.
    
    * For review comment
    
    * Update datafusion/src/logical_plan/expr.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Update datafusion/src/logical_plan/expr.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Update datafusion/src/logical_plan/expr.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Fix format.
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/src/logical_plan/builder.rs |  8 ++--
 datafusion/src/logical_plan/expr.rs    | 79 +++++++++++++++++++++++++++++++++-
 datafusion/src/logical_plan/mod.rs     | 10 ++---
 datafusion/tests/sql/order.rs          | 21 +++++++++
 4 files changed, 108 insertions(+), 10 deletions(-)

diff --git a/datafusion/src/logical_plan/builder.rs 
b/datafusion/src/logical_plan/builder.rs
index 90d2ae2..fc60939 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -46,8 +46,8 @@ use std::{
 use super::dfschema::ToDFSchema;
 use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, 
PlanType};
 use crate::logical_plan::{
-    columnize_expr, normalize_col, normalize_cols, Column, CrossJoin, DFField, 
DFSchema,
-    DFSchemaRef, Limit, Partitioning, Repartition, Values,
+    columnize_expr, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs, 
Column,
+    CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, 
Repartition, Values,
 };
 use crate::sql::utils::group_window_expr_by_sort_keys;
 
@@ -521,6 +521,8 @@ impl LogicalPlanBuilder {
         &self,
         exprs: impl IntoIterator<Item = impl Into<Expr>> + Clone,
     ) -> Result<Self> {
+        let exprs = rewrite_sort_cols_by_aggs(exprs, &self.plan)?;
+
         let schema = self.plan.schema();
 
         // Collect sort columns that are missing in the input plan's schema
@@ -530,7 +532,7 @@ impl LogicalPlanBuilder {
             .into_iter()
             .try_for_each::<_, Result<()>>(|expr| {
                 let mut columns: HashSet<Column> = HashSet::new();
-                utils::expr_to_columns(&expr.into(), &mut columns)?;
+                utils::expr_to_columns(&expr, &mut columns)?;
 
                 columns.into_iter().for_each(|c| {
                     if schema.field_from_column(&c).is_err() {
diff --git a/datafusion/src/logical_plan/expr.rs 
b/datafusion/src/logical_plan/expr.rs
index fc862cd..dadc168 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -21,7 +21,9 @@
 pub use super::Operator;
 use crate::error::{DataFusionError, Result};
 use crate::field_util::get_indexed_field;
-use crate::logical_plan::{window_frames, DFField, DFSchema, LogicalPlan};
+use crate::logical_plan::{
+    plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan,
+};
 use crate::physical_plan::functions::Volatility;
 use crate::physical_plan::{
     aggregates, expressions::binary_operator_data_type, functions, 
udf::ScalarUDF,
@@ -1306,7 +1308,6 @@ fn normalize_col_with_schemas(
 }
 
 /// Recursively normalize all Column expressions in a list of expression trees
-#[inline]
 pub fn normalize_cols(
     exprs: impl IntoIterator<Item = impl Into<Expr>>,
     plan: &LogicalPlan,
@@ -1317,6 +1318,80 @@ pub fn normalize_cols(
         .collect()
 }
 
+/// Rewrite sort on aggregate expressions to sort on the column of aggregate 
output.
+/// For example, `max(x)` is rewritten to `col("MAX(x)")`
+pub fn rewrite_sort_cols_by_aggs(
+    exprs: impl IntoIterator<Item = impl Into<Expr>>,
+    plan: &LogicalPlan,
+) -> Result<Vec<Expr>> {
+    exprs
+        .into_iter()
+        .map(|e| {
+            let expr = e.into();
+            match expr {
+                Expr::Sort {
+                    expr,
+                    asc,
+                    nulls_first,
+                } => {
+                    let sort = Expr::Sort {
+                        expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?),
+                        asc,
+                        nulls_first,
+                    };
+                    Ok(sort)
+                }
+                expr => Ok(expr),
+            }
+        })
+        .collect()
+}
+
+fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
+    match plan {
+        LogicalPlan::Aggregate(Aggregate {
+            input, aggr_expr, ..
+        }) => {
+            struct Rewriter<'a> {
+                plan: &'a LogicalPlan,
+                input: &'a LogicalPlan,
+                aggr_expr: &'a Vec<Expr>,
+            }
+
+            impl<'a> ExprRewriter for Rewriter<'a> {
+                fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+                    let normalized_expr = normalize_col(expr.clone(), 
self.plan);
+                    if normalized_expr.is_err() {
+                        // The expr is not based on Aggregate plan output. 
Skip it.
+                        return Ok(expr);
+                    }
+                    let normalized_expr = normalized_expr.unwrap();
+                    if let Some(found_agg) =
+                        self.aggr_expr.iter().find(|a| (**a) == 
normalized_expr)
+                    {
+                        let agg = normalize_col(found_agg.clone(), self.plan)?;
+                        let col = Expr::Column(
+                            agg.to_field(self.input.schema())
+                                .map(|f| f.qualified_column())?,
+                        );
+                        Ok(col)
+                    } else {
+                        Ok(expr)
+                    }
+                }
+            }
+
+            expr.rewrite(&mut Rewriter {
+                plan,
+                input,
+                aggr_expr,
+            })
+        }
+        LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, 
plan.inputs()[0]),
+        _ => Ok(expr),
+    }
+}
+
 /// Recursively 'unnormalize' (remove all qualifiers) from an
 /// expression tree.
 ///
diff --git a/datafusion/src/logical_plan/mod.rs 
b/datafusion/src/logical_plan/mod.rs
index a20d572..56fec3c 100644
--- a/datafusion/src/logical_plan/mod.rs
+++ b/datafusion/src/logical_plan/mod.rs
@@ -42,11 +42,11 @@ pub use expr::{
     create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, 
in_list,
     initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, 
lpad, ltrim,
     max, md5, min, normalize_col, normalize_cols, now, octet_length, or, 
random,
-    regexp_match, regexp_replace, repeat, replace, replace_col, reverse, 
right, round,
-    rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
-    starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, 
unalias,
-    unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter,
-    ExpressionVisitor, Literal, Recursion, RewriteRecursion,
+    regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
+    rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, 
sha384, sha512,
+    signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, 
to_hex,
+    translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, 
when,
+    Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, 
RewriteRecursion,
 };
 pub use extension::UserDefinedLogicalNode;
 pub use operators::Operator;
diff --git a/datafusion/tests/sql/order.rs b/datafusion/tests/sql/order.rs
index 631b6af..fa59d9d 100644
--- a/datafusion/tests/sql/order.rs
+++ b/datafusion/tests/sql/order.rs
@@ -33,6 +33,27 @@ async fn test_sort_unprojected_col() -> Result<()> {
 }
 
 #[tokio::test]
+async fn test_order_by_agg_expr() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+    let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12)";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+-----------------------------+",
+        "| MIN(aggregate_test_100.c12) |",
+        "+-----------------------------+",
+        "| 0.01479305307777301         |",
+        "+-----------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12) + 
0.1";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
 async fn test_nulls_first_asc() -> Result<()> {
     let mut ctx = ExecutionContext::new();
     let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) 
AS t (num,letter) ORDER BY num";

Reply via email to