This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ad273cab8b Improve unparsing for `ORDER BY`, `UNION`, Window
functions with Aggregation (#12946)
ad273cab8b is described below
commit ad273cab8bf300a704baf005df072bb980645e51
Author: Sergei Grebnov <[email protected]>
AuthorDate: Thu Oct 17 10:23:29 2024 -0700
Improve unparsing for `ORDER BY`, `UNION`, Window functions with
Aggregation (#12946)
* Improve unparsing for ORDER BY with Aggregation functions (#38)
* Improve UNION unparsing (#39)
* Scalar functions in ORDER BY unparsing support (#41)
* Improve unparsing for complex Window functions with Aggregation (#42)
* WindowFunction order_by should respect `supports_nulls_first_in_sort`
dialect setting (#43)
* Fix plan_to_sql
* Improve
---
datafusion/sql/src/unparser/expr.rs | 10 ++---
datafusion/sql/src/unparser/plan.rs | 42 +++++++++++++------
datafusion/sql/src/unparser/utils.rs | 69 ++++++++++++++++++++++++-------
datafusion/sql/tests/cases/plan_to_sql.rs | 63 +++++++++++++++++++++++++++-
4 files changed, 148 insertions(+), 36 deletions(-)
diff --git a/datafusion/sql/src/unparser/expr.rs
b/datafusion/sql/src/unparser/expr.rs
index 1be5aa68bf..8864c97bb1 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -76,11 +76,6 @@ pub fn expr_to_sql(expr: &Expr) -> Result<ast::Expr> {
unparser.expr_to_sql(expr)
}
-pub fn sort_to_sql(sort: &Sort) -> Result<ast::OrderByExpr> {
- let unparser = Unparser::default();
- unparser.sort_to_sql(sort)
-}
-
const LOWEST: &BinaryOperator = &BinaryOperator::Or;
// Closest precedence we have to IS operator is BitwiseAnd (any other) in PG
docs
// (https://www.postgresql.org/docs/7.2/sql-precedence.html)
@@ -229,9 +224,10 @@ impl Unparser<'_> {
ast::WindowFrameUnits::Groups
}
};
- let order_by: Vec<ast::OrderByExpr> = order_by
+
+ let order_by = order_by
.iter()
- .map(sort_to_sql)
+ .map(|sort_expr| self.sort_to_sql(sort_expr))
.collect::<Result<Vec<_>>>()?;
let start_bound =
self.convert_bound(&window_frame.start_bound)?;
diff --git a/datafusion/sql/src/unparser/plan.rs
b/datafusion/sql/src/unparser/plan.rs
index 9b4818b98c..c22400f1fa 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -27,7 +27,7 @@ use super::{
},
utils::{
find_agg_node_within_select, find_window_nodes_within_select,
- unproject_window_exprs,
+ unproject_sort_expr, unproject_window_exprs,
},
Unparser,
};
@@ -352,19 +352,30 @@ impl Unparser<'_> {
if select.already_projected() {
return self.derive(plan, relation);
}
- if let Some(query_ref) = query {
- if let Some(fetch) = sort.fetch {
-
query_ref.limit(Some(ast::Expr::Value(ast::Value::Number(
- fetch.to_string(),
- false,
- ))));
- }
- query_ref.order_by(self.sorts_to_sql(sort.expr.clone())?);
- } else {
+ let Some(query_ref) = query else {
return internal_err!(
"Sort operator only valid in a statement context."
);
- }
+ };
+
+ if let Some(fetch) = sort.fetch {
+ query_ref.limit(Some(ast::Expr::Value(ast::Value::Number(
+ fetch.to_string(),
+ false,
+ ))));
+ };
+
+ let agg = find_agg_node_within_select(plan,
select.already_projected());
+ // unproject sort expressions
+ let sort_exprs: Vec<SortExpr> = sort
+ .expr
+ .iter()
+ .map(|sort_expr| {
+ unproject_sort_expr(sort_expr, agg,
sort.input.as_ref())
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ query_ref.order_by(self.sorts_to_sql(&sort_exprs)?);
self.select_to_sql_recursively(
sort.input.as_ref(),
@@ -402,7 +413,7 @@ impl Unparser<'_> {
.collect::<Result<Vec<_>>>()?;
if let Some(sort_expr) = &on.sort_expr {
if let Some(query_ref) = query {
-
query_ref.order_by(self.sorts_to_sql(sort_expr.clone())?);
+
query_ref.order_by(self.sorts_to_sql(sort_expr)?);
} else {
return internal_err!(
"Sort operator only valid in a statement
context."
@@ -546,6 +557,11 @@ impl Unparser<'_> {
);
}
+ // Covers cases where the UNION is a subquery and the
projection is at the top level
+ if select.already_projected() {
+ return self.derive(plan, relation);
+ }
+
let input_exprs: Vec<SetExpr> = union
.inputs
.iter()
@@ -691,7 +707,7 @@ impl Unparser<'_> {
}
}
- fn sorts_to_sql(&self, sort_exprs: Vec<SortExpr>) ->
Result<Vec<ast::OrderByExpr>> {
+ fn sorts_to_sql(&self, sort_exprs: &[SortExpr]) ->
Result<Vec<ast::OrderByExpr>> {
sort_exprs
.iter()
.map(|sort_expr| self.sort_to_sql(sort_expr))
diff --git a/datafusion/sql/src/unparser/utils.rs
b/datafusion/sql/src/unparser/utils.rs
index e8c4eca569..5e3a3aa600 100644
--- a/datafusion/sql/src/unparser/utils.rs
+++ b/datafusion/sql/src/unparser/utils.rs
@@ -20,10 +20,11 @@ use std::cmp::Ordering;
use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
- Column, DataFusionError, Result, ScalarValue,
+ Column, Result, ScalarValue,
};
use datafusion_expr::{
- utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window,
+ utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection,
SortExpr,
+ Window,
};
use sqlparser::ast;
@@ -118,21 +119,11 @@ pub(crate) fn unproject_agg_exprs(
if let Expr::Column(c) = sub_expr {
if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
Ok(Transformed::yes(unprojected_expr.clone()))
- } else if let Some(mut unprojected_expr) =
+ } else if let Some(unprojected_expr) =
windows.and_then(|w| find_window_expr(w, &c.name).cloned())
{
- if let Expr::WindowFunction(func) = &mut unprojected_expr {
- // Window function can contain an aggregation column,
e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
- func.args.iter_mut().try_for_each(|arg| {
- if let Expr::Column(c) = arg {
- if let Some(expr) = find_agg_expr(agg, c)? {
- *arg = expr.clone();
- }
- }
- Ok::<(), DataFusionError>(())
- })?;
- }
- Ok(Transformed::yes(unprojected_expr))
+ // A window function can contain aggregation columns,
e.g., 'avg(sum(ss_sales_price)) over ...', that need to be unprojected
+ return
Ok(Transformed::yes(unproject_agg_exprs(&unprojected_expr, agg, None)?));
} else {
internal_err!(
"Tried to unproject agg expr for column '{}' that was
not found in the provided Aggregate!", &c.name
@@ -200,6 +191,54 @@ fn find_window_expr<'a>(
.find(|expr| expr.schema_name().to_string() == column_name)
}
+/// Transforms a Column expression into the actual expression from aggregation
or projection if found.
+/// This is required because if an ORDER BY expression is present in an
Aggregate or Select, it is replaced
+/// with a Column expression (e.g., "sum(catalog_returns.cr_net_loss)"). We
need to transform it back to
+/// the actual expression, such as sum("catalog_returns"."cr_net_loss").
+pub(crate) fn unproject_sort_expr(
+ sort_expr: &SortExpr,
+ agg: Option<&Aggregate>,
+ input: &LogicalPlan,
+) -> Result<SortExpr> {
+ let mut sort_expr = sort_expr.clone();
+
+ // Remove alias if present, because ORDER BY cannot use aliases
+ if let Expr::Alias(alias) = &sort_expr.expr {
+ sort_expr.expr = *alias.expr.clone();
+ }
+
+ let Expr::Column(ref col_ref) = sort_expr.expr else {
+ return Ok(sort_expr);
+ };
+
+ if col_ref.relation.is_some() {
+ return Ok(sort_expr);
+ };
+
+ // In case of aggregation there could be columns containing aggregation
functions we need to unproject
+ if let Some(agg) = agg {
+ if agg.schema.is_column_from_schema(col_ref) {
+ let new_expr = unproject_agg_exprs(&sort_expr.expr, agg, None)?;
+ sort_expr.expr = new_expr;
+ return Ok(sort_expr);
+ }
+ }
+
+ // If SELECT and ORDER BY contain the same expression with a scalar
function, the ORDER BY expression will
+ // be replaced by a Column expression (e.g., "substr(customer.c_last_name,
Int64(0), Int64(5))"), and we need
+ // to transform it back to the actual expression.
+ if let LogicalPlan::Projection(Projection { expr, schema, .. }) = input {
+ if let Ok(idx) = schema.index_of_column(col_ref) {
+ if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) {
+ sort_expr.expr = Expr::ScalarFunction(scalar_fn.clone());
+ }
+ }
+ return Ok(sort_expr);
+ }
+
+ Ok(sort_expr)
+}
+
/// Converts a date_part function to SQL, tailoring it to the supported date
field extraction style.
pub(crate) fn date_part_to_sql(
unparser: &Unparser,
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs
b/datafusion/sql/tests/cases/plan_to_sql.rs
index e4e5d6a929..74abdf075f 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -22,6 +22,9 @@ use arrow_schema::*;
use datafusion_common::{DFSchema, Result, TableReference};
use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf,
sum_udaf};
use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder};
+use datafusion_functions::unicode;
+use datafusion_functions_aggregate::grouping::grouping_udaf;
+use datafusion_functions_window::rank::rank_udwf;
use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_sql::unparser::dialect::{
DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect,
@@ -139,6 +142,13 @@ fn roundtrip_statement() -> Result<()> {
SELECT j2_string as string FROM j2
ORDER BY string DESC
LIMIT 10"#,
+ r#"SELECT col1, id FROM (
+ SELECT j1_string AS col1, j1_id AS id FROM j1
+ UNION ALL
+ SELECT j2_string AS col1, j2_id AS id FROM j2
+ UNION ALL
+ SELECT j3_string AS col1, j3_id AS id FROM j3
+ ) AS subquery GROUP BY col1, id ORDER BY col1 ASC, id ASC"#,
"SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
first_name from person",
@@ -657,7 +667,12 @@ where
.unwrap();
let context = MockContextProvider {
- state: MockSessionState::default(),
+ state: MockSessionState::default()
+ .with_aggregate_function(sum_udaf())
+ .with_aggregate_function(max_udaf())
+ .with_aggregate_function(grouping_udaf())
+ .with_window_function(rank_udwf())
+
.with_scalar_function(Arc::new(unicode::substr().as_ref().clone())),
};
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
@@ -969,3 +984,49 @@ fn test_with_offset0() {
fn test_with_offset95() {
sql_round_trip(MySqlDialect {}, "select 1 offset 95", "SELECT 1 OFFSET
95");
}
+
+#[test]
+fn test_order_by_to_sql() {
+ // order by aggregation function
+ sql_round_trip(
+ GenericDialect {},
+ r#"SELECT id, first_name, SUM(id) FROM person GROUP BY id, first_name
ORDER BY SUM(id) ASC, first_name DESC, id, first_name LIMIT 10"#,
+ r#"SELECT person.id, person.first_name, sum(person.id) FROM person
GROUP BY person.id, person.first_name ORDER BY sum(person.id) ASC NULLS LAST,
person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name
ASC NULLS LAST LIMIT 10"#,
+ );
+
+ // order by aggregation function alias
+ sql_round_trip(
+ GenericDialect {},
+ r#"SELECT id, first_name, SUM(id) as total_sum FROM person GROUP BY
id, first_name ORDER BY total_sum ASC, first_name DESC, id, first_name LIMIT
10"#,
+ r#"SELECT person.id, person.first_name, sum(person.id) AS total_sum
FROM person GROUP BY person.id, person.first_name ORDER BY total_sum ASC NULLS
LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST,
person.first_name ASC NULLS LAST LIMIT 10"#,
+ );
+
+ // order by scalar function from projection
+ sql_round_trip(
+ GenericDialect {},
+ r#"SELECT id, first_name, substr(first_name,0,5) FROM person ORDER BY
id, substr(first_name,0,5)"#,
+ r#"SELECT person.id, person.first_name, substr(person.first_name, 0,
5) FROM person ORDER BY person.id ASC NULLS LAST, substr(person.first_name, 0,
5) ASC NULLS LAST"#,
+ );
+}
+
+#[test]
+fn test_aggregation_to_sql() {
+ sql_round_trip(
+ GenericDialect {},
+ r#"SELECT id, first_name,
+ SUM(id) AS total_sum,
+ SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2
FOLLOWING) AS moving_sum,
+ MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED
PRECEDING AND UNBOUNDED FOLLOWING) AS max_total,
+ rank() OVER (PARTITION BY grouping(id) + grouping(age), CASE WHEN
grouping(age) = 0 THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_1,
+ rank() OVER (PARTITION BY grouping(age) + grouping(id), CASE WHEN
(CAST(grouping(age) AS BIGINT) = 0) THEN id END ORDER BY sum(id) DESC) AS
rank_within_parent_2
+ FROM person
+ GROUP BY id, first_name;"#,
+ r#"SELECT person.id, person.first_name,
+sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY
person.first_name ROWS BETWEEN '5' PRECEDING AND '2' FOLLOWING) AS moving_sum,
+max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total,
+rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE
WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC
NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS
rank_within_parent_1,
+rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE
WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY
sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT
ROW) AS rank_within_parent_2
+FROM person
+GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(),
+ );
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]