This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 7607ace Fix ORDER BY on aggregate (#1506)
7607ace is described below
commit 7607ace992a5a42840bf546221a8635e70e10885
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Jan 1 04:04:58 2022 -0800
Fix ORDER BY on aggregate (#1506)
* Fix sort on aggregate
* Use ExprRewriter.
* For review comment
* Update datafusion/src/logical_plan/expr.rs
Co-authored-by: Andrew Lamb <[email protected]>
* Update datafusion/src/logical_plan/expr.rs
Co-authored-by: Andrew Lamb <[email protected]>
* Update datafusion/src/logical_plan/expr.rs
Co-authored-by: Andrew Lamb <[email protected]>
* Fix format.
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/src/logical_plan/builder.rs | 8 ++--
datafusion/src/logical_plan/expr.rs | 79 +++++++++++++++++++++++++++++++++-
datafusion/src/logical_plan/mod.rs | 10 ++---
datafusion/tests/sql/order.rs | 21 +++++++++
4 files changed, 108 insertions(+), 10 deletions(-)
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index 90d2ae2..fc60939 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -46,8 +46,8 @@ use std::{
use super::dfschema::ToDFSchema;
use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan,
PlanType};
use crate::logical_plan::{
- columnize_expr, normalize_col, normalize_cols, Column, CrossJoin, DFField,
DFSchema,
- DFSchemaRef, Limit, Partitioning, Repartition, Values,
+ columnize_expr, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs,
Column,
+ CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning,
Repartition, Values,
};
use crate::sql::utils::group_window_expr_by_sort_keys;
@@ -521,6 +521,8 @@ impl LogicalPlanBuilder {
&self,
exprs: impl IntoIterator<Item = impl Into<Expr>> + Clone,
) -> Result<Self> {
+ let exprs = rewrite_sort_cols_by_aggs(exprs, &self.plan)?;
+
let schema = self.plan.schema();
// Collect sort columns that are missing in the input plan's schema
@@ -530,7 +532,7 @@ impl LogicalPlanBuilder {
.into_iter()
.try_for_each::<_, Result<()>>(|expr| {
let mut columns: HashSet<Column> = HashSet::new();
- utils::expr_to_columns(&expr.into(), &mut columns)?;
+ utils::expr_to_columns(&expr, &mut columns)?;
columns.into_iter().for_each(|c| {
if schema.field_from_column(&c).is_err() {
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index fc862cd..dadc168 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -21,7 +21,9 @@
pub use super::Operator;
use crate::error::{DataFusionError, Result};
use crate::field_util::get_indexed_field;
-use crate::logical_plan::{window_frames, DFField, DFSchema, LogicalPlan};
+use crate::logical_plan::{
+ plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan,
+};
use crate::physical_plan::functions::Volatility;
use crate::physical_plan::{
aggregates, expressions::binary_operator_data_type, functions,
udf::ScalarUDF,
@@ -1306,7 +1308,6 @@ fn normalize_col_with_schemas(
}
/// Recursively normalize all Column expressions in a list of expression trees
-#[inline]
pub fn normalize_cols(
exprs: impl IntoIterator<Item = impl Into<Expr>>,
plan: &LogicalPlan,
@@ -1317,6 +1318,80 @@ pub fn normalize_cols(
.collect()
}
+/// Rewrite the keys of `Expr::Sort` expressions so that aggregate expressions
+/// are sorted via the aggregate's output column rather than recomputed.
+///
+/// For example, a sort on `max(x)` is rewritten to a sort on `col("MAX(x)")`.
+/// Only `Expr::Sort` items are rewritten (their inner key is passed to
+/// `rewrite_sort_col_by_aggs` together with `plan`). The `asc` and
+/// `nulls_first` flags are kept. Every other expression is returned unchanged.
+///
+/// # Errors
+/// Returns the first error produced while rewriting a sort key.
+pub fn rewrite_sort_cols_by_aggs(
+    exprs: impl IntoIterator<Item = impl Into<Expr>>,
+    plan: &LogicalPlan,
+) -> Result<Vec<Expr>> {
+    let mut rewritten = Vec::new();
+    for item in exprs {
+        let sort_expr = match item.into() {
+            Expr::Sort {
+                expr,
+                asc,
+                nulls_first,
+            } => Expr::Sort {
+                expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?),
+                asc,
+                nulls_first,
+            },
+            other => other,
+        };
+        rewritten.push(sort_expr);
+    }
+    Ok(rewritten)
+}
+
+/// Rewrite one sort key so that aggregate expressions inside it refer to the
+/// aggregate's output column.
+///
+/// The behaviour depends on `plan`:
+///
+/// * `LogicalPlan::Aggregate`: `expr` is rewritten with an `ExprRewriter`.
+///   A (sub-)expression that equals one of `aggr_expr`, after column
+///   normalization against the aggregate plan, is replaced by an
+///   `Expr::Column` naming that aggregate's output field. The field is
+///   resolved against the aggregate's input schema.
+/// * `LogicalPlan::Projection`: the rewrite is delegated to the projection's
+///   input. This presumably covers the common case of a projection sitting
+///   directly above the aggregate (TODO: confirm for nested projections).
+/// * Any other plan: `expr` is returned unchanged.
+///
+/// # Errors
+/// Returns an error if a matched aggregate expression cannot be normalized,
+/// or if its output field cannot be derived from the input schema.
+fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
+    match plan {
+        LogicalPlan::Aggregate(Aggregate {
+            input, aggr_expr, ..
+        }) => {
+            /// Replaces sub-expressions that match an aggregate expression of
+            /// `plan` with a column reference to that aggregate's output.
+            struct Rewriter<'a> {
+                /// The `Aggregate` plan, used to normalize column references.
+                plan: &'a LogicalPlan,
+                /// The aggregate's input. Its schema resolves output fields.
+                input: &'a LogicalPlan,
+                /// The aggregate expressions computed by `plan`.
+                aggr_expr: &'a [Expr],
+            }
+
+            impl<'a> ExprRewriter for Rewriter<'a> {
+                fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+                    // If `expr` cannot be normalized against the Aggregate
+                    // plan, it cannot match one of its aggregates, so keep it.
+                    let normalized_expr = match normalize_col(expr.clone(), self.plan) {
+                        Ok(normalized) => normalized,
+                        Err(_) => return Ok(expr),
+                    };
+
+                    match self.aggr_expr.iter().find(|agg| **agg == normalized_expr) {
+                        Some(found_agg) => {
+                            let agg = normalize_col(found_agg.clone(), self.plan)?;
+                            let field = agg.to_field(self.input.schema())?;
+                            Ok(Expr::Column(field.qualified_column()))
+                        }
+                        None => Ok(expr),
+                    }
+                }
+            }
+
+            expr.rewrite(&mut Rewriter {
+                plan,
+                input,
+                aggr_expr,
+            })
+        }
+        // Look through the projection to its (single) input.
+        LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]),
+        _ => Ok(expr),
+    }
+}
+
/// Recursively 'unnormalize' (remove all qualifiers) from an
/// expression tree.
///
diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs
index a20d572..56fec3c 100644
--- a/datafusion/src/logical_plan/mod.rs
+++ b/datafusion/src/logical_plan/mod.rs
@@ -42,11 +42,11 @@ pub use expr::{
create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor,
in_list,
initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower,
lpad, ltrim,
max, md5, min, normalize_col, normalize_cols, now, octet_length, or,
random,
- regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
right, round,
- rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
- starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc,
unalias,
- unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter,
- ExpressionVisitor, Literal, Recursion, RewriteRecursion,
+ regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
+ rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256,
sha384, sha512,
+ signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan,
to_hex,
+ translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper,
when,
+ Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion,
RewriteRecursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
diff --git a/datafusion/tests/sql/order.rs b/datafusion/tests/sql/order.rs
index 631b6af..fa59d9d 100644
--- a/datafusion/tests/sql/order.rs
+++ b/datafusion/tests/sql/order.rs
@@ -33,6 +33,27 @@ async fn test_sort_unprojected_col() -> Result<()> {
}
#[tokio::test]
+async fn test_order_by_agg_expr() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+
+    // Every row of the pretty-printed table is padded to the column width
+    // (29 characters between the borders), including the data row.
+    let expected = vec![
+        "+-----------------------------+",
+        "| MIN(aggregate_test_100.c12) |",
+        "+-----------------------------+",
+        "| 0.01479305307777301         |",
+        "+-----------------------------+",
+    ];
+
+    // Neither query has a GROUP BY, so each yields a single row. These cases
+    // check that an aggregate sort key is planned and executed successfully;
+    // they do not distinguish between sort orders.
+
+    // Sort key is exactly an aggregate that also appears in the projection.
+    let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12)";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+
+    // Sort key nests the aggregate inside a larger expression.
+    let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12) + 0.1";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
async fn test_nulls_first_asc() -> Result<()> {
let mut ctx = ExecutionContext::new();
let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three'))
AS t (num,letter) ORDER BY num";