This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 385d9dbe71 try to remove redundant alias in expression rewriter and
select (#20867)
385d9dbe71 is described below
commit 385d9dbe7106077b8f07922e9f0c2cf6d0d7ce6e
Author: Burak Şen <[email protected]>
AuthorDate: Thu Mar 12 15:48:29 2026 +0300
try to remove redundant alias in expression rewriter and select (#20867)
## Which issue does this PR close?
None (this is a follow-up to a review discussion; see the rationale below).
## Rationale for this change
In
https://github.com/apache/datafusion/pull/20780#discussion_r2911482011
@alamb asked whether we could remove the redundant alias in `count(*) AS
count(*)`, leaving just `count(*)`, and I gave it a try.
### I'm not sure about the implications at the moment, so it would be great
to have input on this PR.
## What changes are included in this PR?
Main changes are in:
- order_by.rs: match only top-level projection expressions instead of recursively
searching sub-expressions (otherwise we may match the wrong expression)
- select.rs: strip the alias before comparing; otherwise the existing
alias is never used
## Are these changes tested?
I've added some tests for aliases. Existing tests and plan outputs changed
as well, as you can see in the PR.
## Are there any user-facing changes?
Plans will change, but I'm not sure whether this has any user-facing impact.
---
datafusion/core/tests/dataframe/mod.rs | 2 +-
datafusion/expr/src/expr_rewriter/order_by.rs | 240 +++++++++++++++++++---
datafusion/sql/src/select.rs | 17 +-
datafusion/sqllogictest/test_files/clickbench.slt | 8 +-
datafusion/sqllogictest/test_files/order.slt | 48 +++++
5 files changed, 279 insertions(+), 36 deletions(-)
diff --git a/datafusion/core/tests/dataframe/mod.rs
b/datafusion/core/tests/dataframe/mod.rs
index b1ee8b09b9..80bbde1f6b 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -3004,7 +3004,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
+---------------+------------------------------------------------------------------------------------+
| plan_type | plan
|
+---------------+------------------------------------------------------------------------------------+
- | logical_plan | Sort: count(*) AS count(*) ASC NULLS LAST
|
+ | logical_plan | Sort: count(*) ASC NULLS LAST
|
| | Projection: t1.b, count(Int64(1)) AS count(*)
|
| | Aggregate: groupBy=[[t1.b]],
aggr=[[count(Int64(1))]] |
| | TableScan: t1 projection=[b]
|
diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs
b/datafusion/expr/src/expr_rewriter/order_by.rs
index a897e56d27..720788113c 100644
--- a/datafusion/expr/src/expr_rewriter/order_by.rs
+++ b/datafusion/expr/src/expr_rewriter/order_by.rs
@@ -21,9 +21,7 @@ use crate::expr::Alias;
use crate::expr_rewriter::normalize_col;
use crate::{Cast, Expr, LogicalPlan, TryCast, expr::Sort};
-use datafusion_common::tree_node::{
- Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
-};
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{Column, Result};
/// Rewrite sort on aggregate expressions to sort on the column of aggregate
output
@@ -104,29 +102,27 @@ fn rewrite_in_terms_of_projection(
let search_col = Expr::Column(Column::new_unqualified(name));
- // look for the column named the same as this expr
- let mut found = None;
- for proj_expr in proj_exprs {
- proj_expr.apply(|e| {
- if expr_match(&search_col, e) {
- found = Some(e.clone());
- return Ok(TreeNodeRecursion::Stop);
- }
- Ok(TreeNodeRecursion::Continue)
- })?;
- }
+ // Search only top-level projection expressions for a match.
+ // We intentionally avoid a recursive search (e.g. `apply`) to
+ // prevent matching sub-expressions of composites like
+ // `min(c2) + max(c3)` when the ORDER BY is just `min(c2)`.
+ let found = proj_exprs
+ .iter()
+ .find(|proj_expr| expr_match(&search_col, proj_expr));
if let Some(found) = found {
+ let (qualifier, field_name) = found.qualified_name();
+ let col = Expr::Column(Column::new(qualifier, field_name));
return Ok(Transformed::yes(match normalized_expr {
Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast {
- expr: Box::new(found),
+ expr: Box::new(col),
field,
}),
Expr::TryCast(TryCast { expr: _, field }) =>
Expr::TryCast(TryCast {
- expr: Box::new(found),
+ expr: Box::new(col),
field,
}),
- _ => found,
+ _ => col,
}));
}
@@ -160,7 +156,10 @@ mod test {
use super::*;
use crate::test::function_stub::avg;
+ use crate::test::function_stub::count;
+ use crate::test::function_stub::max;
use crate::test::function_stub::min;
+ use crate::test::function_stub::sum;
#[test]
fn rewrite_sort_cols_by_agg() {
@@ -242,17 +241,14 @@ mod test {
TestCase {
desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named*
"min(t.c2)"!)"#,
input: sort(col("c1") + min(col("c2"))),
- // should be "c1" not t.c1
expected: sort(
col("c1") +
Expr::Column(Column::new_unqualified("min(t.c2)")),
),
},
TestCase {
- desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named*
"avg(t.c3)", aliased)"#,
+ desc: r#"avg(c3) --> "average" (column *named* "average", from
alias)"#,
input: sort(avg(col("c3"))),
- expected: sort(
-
Expr::Column(Column::new_unqualified("avg(t.c3)")).alias("average"),
- ),
+ expected: sort(col("average")),
},
];
@@ -261,6 +257,202 @@ mod test {
}
}
+ /// When an aggregate is aliased in the projection,
+ /// ORDER BY on the original aggregate expression should resolve to
+ /// a Column reference using the alias name — not leak the inner
+ /// Alias expression node or resolve to a descendant subtree.
+ #[test]
+ fn rewrite_sort_resolves_alias_to_column_ref() {
+ let plan = make_input()
+ .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
+ .unwrap()
+ .project(vec![
+ col("c1"),
+ min(col("c2")).alias("min_val"),
+ max(col("c3")).alias("max_val"),
+ ])
+ .unwrap()
+ .build()
+ .unwrap();
+
+ let cases = vec![
+ TestCase {
+ desc: "min(c2) with alias 'min_val' should resolve to
col(min_val)",
+ input: sort(min(col("c2"))),
+ expected: sort(col("min_val")),
+ },
+ TestCase {
+ desc: "max(c3) with alias 'max_val' should resolve to
col(max_val)",
+ input: sort(max(col("c3"))),
+ expected: sort(col("max_val")),
+ },
+ ];
+
+ for case in cases {
+ case.run(&plan)
+ }
+ }
+
+ #[test]
+ fn composite_proj_expr_containing_sort_col_as_subexpr() {
+ let plan = make_input()
+ .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
+ .unwrap()
+ .project(vec![
+ col("c1"),
+ (min(col("c2")) + max(col("c3"))).alias("range"),
+ min(col("c2")).alias("min_val"),
+ max(col("c3")).alias("max_val"),
+ ])
+ .unwrap()
+ .build()
+ .unwrap();
+
+ let cases = vec![
+ TestCase {
+ desc: "sort by min(c2) should resolve to col(min_val), not
col(range)",
+ input: sort(min(col("c2"))),
+ expected: sort(col("min_val")),
+ },
+ TestCase {
+ desc: "sort by max(c3) should resolve to col(max_val), not
col(range)",
+ input: sort(max(col("c3"))),
+ expected: sort(col("max_val")),
+ },
+ ];
+
+ for case in cases {
+ case.run(&plan)
+ }
+ }
+
+ #[test]
+ fn composite_before_standalone_should_not_shadow() {
+ let plan = make_input()
+ .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c2"))])
+ .unwrap()
+ .project(vec![
+ col("c1"),
+ (min(col("c2")) + max(col("c2"))).alias("combined"),
+ min(col("c2")),
+ ])
+ .unwrap()
+ .build()
+ .unwrap();
+
+ let cases = vec![TestCase {
+ desc: "sort by min(c2) should resolve to col(min(t.c2)), not
col(combined)",
+ input: sort(min(col("c2"))),
+ expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))),
+ }];
+
+ for case in cases {
+ case.run(&plan)
+ }
+ }
+
+ #[test]
+ fn duplicate_aggregate_in_multiple_proj_exprs() {
+ let plan = make_input()
+ .aggregate(vec![col("c1")], vec![min(col("c2"))])
+ .unwrap()
+ .project(vec![
+ col("c1"),
+ min(col("c2")).alias("first_alias"),
+ min(col("c2")).alias("second_alias"),
+ ])
+ .unwrap()
+ .build()
+ .unwrap();
+
+ let cases = vec![TestCase {
+ desc: "sort by min(c2) with two aliases picks first_alias",
+ input: sort(min(col("c2"))),
+ expected: sort(col("first_alias")),
+ }];
+
+ for case in cases {
+ case.run(&plan)
+ }
+ }
+
+ #[test]
+ fn sort_agg_not_in_select_with_aliased_aggs() {
+ let plan = make_input()
+ .aggregate(
+ vec![col("c1")],
+ vec![min(col("c2")), max(col("c3")), sum(col("c3"))],
+ )
+ .unwrap()
+ .project(vec![
+ col("c1"),
+ min(col("c2")).alias("min_val"),
+ max(col("c3")).alias("max_val"),
+ ])
+ .unwrap()
+ .build()
+ .unwrap();
+
+ let cases = vec![TestCase {
+ desc: "sort by sum(c3) not in projection should not be rewritten",
+ input: sort(sum(col("c3"))),
+ expected: sort(sum(col("c3"))),
+ }];
+
+ for case in cases {
+ case.run(&plan)
+ }
+ }
+
+ #[test]
+ fn cast_on_aliased_aggregate() {
+ let plan = make_input()
+ .aggregate(vec![col("c1")], vec![min(col("c2"))])
+ .unwrap()
+ .project(vec![col("c1"), min(col("c2")).alias("min_val")])
+ .unwrap()
+ .build()
+ .unwrap();
+
+ let cases = vec![
+ TestCase {
+ desc: "CAST on aliased aggregate should preserve cast and
resolve alias",
+ input: sort(cast(min(col("c2")), DataType::Int64)),
+ expected: sort(cast(col("min_val"), DataType::Int64)),
+ },
+ TestCase {
+ desc: "TryCast on aliased aggregate should preserve try_cast
and resolve alias",
+ input: sort(try_cast(min(col("c2")), DataType::Int64)),
+ expected: sort(try_cast(col("min_val"), DataType::Int64)),
+ },
+ ];
+
+ for case in cases {
+ case.run(&plan)
+ }
+ }
+
+ #[test]
+ fn count_star_with_alias() {
+ let plan = make_input()
+ .aggregate(vec![col("c1")], vec![count(lit(1))])
+ .unwrap()
+ .project(vec![col("c1"), count(lit(1)).alias("cnt")])
+ .unwrap()
+ .build()
+ .unwrap();
+
+ let cases = vec![TestCase {
+ desc: "sort by count(1) should resolve to cnt alias",
+ input: sort(count(lit(1))),
+ expected: sort(col("cnt")),
+ }];
+
+ for case in cases {
+ case.run(&plan)
+ }
+ }
+
#[test]
fn preserve_cast() {
let plan = make_input()
@@ -275,12 +467,12 @@ mod test {
TestCase {
desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
input: sort(cast(col("c2"), DataType::Int64)),
- expected: sort(cast(col("c2").alias("c2"), DataType::Int64)),
+ expected: sort(cast(col("c2"), DataType::Int64)),
},
TestCase {
desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
input: sort(try_cast(col("c2"), DataType::Int64)),
- expected: sort(try_cast(col("c2").alias("c2"),
DataType::Int64)),
+ expected: sort(try_cast(col("c2"), DataType::Int64)),
},
];
diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs
index edf4b9ef79..7e291afa04 100644
--- a/datafusion/sql/src/select.rs
+++ b/datafusion/sql/src/select.rs
@@ -1056,13 +1056,16 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
.iter()
.find_map(|select_expr| {
// Only consider aliased expressions
- if let Expr::Alias(alias) = select_expr
- && alias.expr.as_ref() == &rewritten_expr
- {
- // Use the alias name
- return Some(Expr::Column(Column::new_unqualified(
- alias.name.clone(),
- )));
+ if let Expr::Alias(alias) = select_expr {
+ let rewritten_unaliased = match &rewritten_expr {
+ Expr::Alias(a) => a.expr.as_ref(),
+ other => other,
+ };
+ if alias.expr.as_ref() == rewritten_unaliased {
+ return
Some(Expr::Column(Column::new_unqualified(
+ alias.name.clone(),
+ )));
+ }
}
None
})
diff --git a/datafusion/sqllogictest/test_files/clickbench.slt
b/datafusion/sqllogictest/test_files/clickbench.slt
index 881e49cdeb..e14d28d5ef 100644
--- a/datafusion/sqllogictest/test_files/clickbench.slt
+++ b/datafusion/sqllogictest/test_files/clickbench.slt
@@ -205,7 +205,7 @@ query TT
EXPLAIN SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0
GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC;
----
logical_plan
-01)Sort: count(*) AS count(*) DESC NULLS FIRST
+01)Sort: count(*) DESC NULLS FIRST
02)--Projection: hits.AdvEngineID, count(Int64(1)) AS count(*)
03)----Aggregate: groupBy=[[hits.AdvEngineID]], aggr=[[count(Int64(1))]]
04)------SubqueryAlias: hits
@@ -431,7 +431,7 @@ query TT
EXPLAIN SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY
COUNT(*) DESC LIMIT 10;
----
logical_plan
-01)Sort: count(*) AS count(*) DESC NULLS FIRST, fetch=10
+01)Sort: count(*) DESC NULLS FIRST, fetch=10
02)--Projection: hits.UserID, count(Int64(1)) AS count(*)
03)----Aggregate: groupBy=[[hits.UserID]], aggr=[[count(Int64(1))]]
04)------SubqueryAlias: hits
@@ -459,7 +459,7 @@ query TT
EXPLAIN SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID",
"SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10;
----
logical_plan
-01)Sort: count(*) AS count(*) DESC NULLS FIRST, fetch=10
+01)Sort: count(*) DESC NULLS FIRST, fetch=10
02)--Projection: hits.UserID, hits.SearchPhrase, count(Int64(1)) AS count(*)
03)----Aggregate: groupBy=[[hits.UserID, hits.SearchPhrase]],
aggr=[[count(Int64(1))]]
04)------SubqueryAlias: hits
@@ -514,7 +514,7 @@ query TT
EXPLAIN SELECT "UserID", extract(minute FROM
to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits
GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10;
----
logical_plan
-01)Sort: count(*) AS count(*) DESC NULLS FIRST, fetch=10
+01)Sort: count(*) DESC NULLS FIRST, fetch=10
02)--Projection: hits.UserID,
date_part(Utf8("MINUTE"),to_timestamp_seconds(hits.EventTime)) AS m,
hits.SearchPhrase, count(Int64(1)) AS count(*)
03)----Aggregate: groupBy=[[hits.UserID, date_part(Utf8("MINUTE"),
to_timestamp_seconds(hits.EventTime)), hits.SearchPhrase]],
aggr=[[count(Int64(1))]]
04)------SubqueryAlias: hits
diff --git a/datafusion/sqllogictest/test_files/order.slt
b/datafusion/sqllogictest/test_files/order.slt
index 7c857cae36..892a42ad61 100644
--- a/datafusion/sqllogictest/test_files/order.slt
+++ b/datafusion/sqllogictest/test_files/order.slt
@@ -471,6 +471,54 @@ select column1 from foo order by column2 % 2, column2;
3
5
+# ORDER BY aggregate expression that is aliased in SELECT
+query II
+select column1, min(column2) as min_val from foo group by column1 order by
min(column2);
+----
+1 2
+3 4
+5 6
+
+# ORDER BY aggregate with alias, using DESC
+query II rowsort
+select column1, count(*) as cnt from foo group by column1 order by count(*)
desc;
+----
+1 1
+3 1
+5 1
+
+# ORDER BY aggregate not in SELECT, while other aggregates in SELECT are
aliased
+query I
+select column1 from foo group by column1 order by max(column2);
+----
+1
+3
+5
+
+# SELECT has composite expression containing the aggregate, plus standalone
alias
+query III
+select column1, min(column2) + max(column2) as range_val, min(column2) as
min_val from foo group by column1 order by min(column2);
+----
+1 4 2
+3 8 4
+5 12 6
+
+# ORDER BY aggregate that matches multiple aliased SELECT expressions
+query III
+select column1, min(column2) as first_min, min(column2) as second_min from foo
group by column1 order by min(column2);
+----
+1 2 2
+3 4 4
+5 6 6
+
+# ORDER BY with CAST on aliased aggregate
+query II
+select column1, min(column2) as min_val from foo group by column1 order by
CAST(min(column2) AS BIGINT);
+----
+1 2
+3 4
+5 6
+
# Cleanup
statement ok
drop table foo;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]