This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new e4f7b9811f feat: support unparsing LogicalPlan::Window nodes (#10767)
e4f7b9811f is described below
commit e4f7b9811f245f0ccf8d0289f7d5edfe1499947a
Author: Devin D'Angelo <[email protected]>
AuthorDate: Mon Jun 3 15:00:32 2024 -0400
feat: support unparsing LogicalPlan::Window nodes (#10767)
* unparse window plans
* new tests + fixes
* fmt
---
datafusion/sql/src/unparser/expr.rs | 32 +++++++++-----
datafusion/sql/src/unparser/plan.rs | 71 +++++++++++++++++++++----------
datafusion/sql/src/unparser/utils.rs | 68 +++++++++++++++++++++++------
datafusion/sql/tests/cases/plan_to_sql.rs | 8 +++-
4 files changed, 132 insertions(+), 47 deletions(-)
diff --git a/datafusion/sql/src/unparser/expr.rs
b/datafusion/sql/src/unparser/expr.rs
index df390ce6ea..1ba6638e73 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -236,8 +236,8 @@ impl Unparser<'_> {
.map(|expr| expr_to_unparsed(expr)?.into_order_by_expr())
.collect::<Result<Vec<_>>>()?;
- let start_bound =
self.convert_bound(&window_frame.start_bound);
- let end_bound = self.convert_bound(&window_frame.end_bound);
+ let start_bound =
self.convert_bound(&window_frame.start_bound)?;
+ let end_bound = self.convert_bound(&window_frame.end_bound)?;
let over = Some(ast::WindowType::WindowSpec(ast::WindowSpec {
window_name: None,
partition_by: partition_by
@@ -513,20 +513,30 @@ impl Unparser<'_> {
fn convert_bound(
&self,
bound: &datafusion_expr::window_frame::WindowFrameBound,
- ) -> ast::WindowFrameBound {
+ ) -> Result<ast::WindowFrameBound> {
match bound {
datafusion_expr::window_frame::WindowFrameBound::Preceding(val) =>
{
- ast::WindowFrameBound::Preceding(
- self.scalar_to_sql(val).map(Box::new).ok(),
- )
+ Ok(ast::WindowFrameBound::Preceding({
+ let val = self.scalar_to_sql(val)?;
+ if let ast::Expr::Value(ast::Value::Null) = &val {
+ None
+ } else {
+ Some(Box::new(val))
+ }
+ }))
}
datafusion_expr::window_frame::WindowFrameBound::Following(val) =>
{
- ast::WindowFrameBound::Following(
- self.scalar_to_sql(val).map(Box::new).ok(),
- )
+ Ok(ast::WindowFrameBound::Following({
+ let val = self.scalar_to_sql(val)?;
+ if let ast::Expr::Value(ast::Value::Null) = &val {
+ None
+ } else {
+ Some(Box::new(val))
+ }
+ }))
}
datafusion_expr::window_frame::WindowFrameBound::CurrentRow => {
- ast::WindowFrameBound::CurrentRow
+ Ok(ast::WindowFrameBound::CurrentRow)
}
}
}
@@ -1148,7 +1158,7 @@ mod tests {
window_frame: WindowFrame::new(None),
null_treatment: None,
}),
- r#"ROW_NUMBER(col) OVER (ROWS BETWEEN NULL PRECEDING AND NULL
FOLLOWING)"#,
+ r#"ROW_NUMBER(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND
UNBOUNDED FOLLOWING)"#,
),
(
Expr::WindowFunction(WindowFunction {
diff --git a/datafusion/sql/src/unparser/plan.rs
b/datafusion/sql/src/unparser/plan.rs
index e7e4d7700a..183bb1f7fb 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -28,7 +28,7 @@ use super::{
BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder,
SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
},
- utils::find_agg_node_within_select,
+ utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
Unparser,
};
@@ -162,23 +162,42 @@ impl Unparser<'_> {
// A second projection implies a derived tablefactor
if !select.already_projected() {
// Special handling when projecting an aggregation plan
- if let Some(agg) = find_agg_node_within_select(plan, true)
{
- let items = p
- .expr
- .iter()
- .map(|proj_expr| {
- let unproj = unproject_agg_exprs(proj_expr,
agg)?;
- self.select_item_to_sql(&unproj)
- })
- .collect::<Result<Vec<_>>>()?;
-
- select.projection(items);
- select.group_by(ast::GroupByExpr::Expressions(
- agg.group_expr
- .iter()
- .map(|expr| self.expr_to_sql(expr))
- .collect::<Result<Vec<_>>>()?,
- ));
+ if let Some(aggvariant) =
+ find_agg_node_within_select(plan, None, true)
+ {
+ match aggvariant {
+ AggVariant::Aggregate(agg) => {
+ let items = p
+ .expr
+ .iter()
+ .map(|proj_expr| {
+ let unproj =
unproject_agg_exprs(proj_expr, agg)?;
+ self.select_item_to_sql(&unproj)
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ select.projection(items);
+ select.group_by(ast::GroupByExpr::Expressions(
+ agg.group_expr
+ .iter()
+ .map(|expr| self.expr_to_sql(expr))
+ .collect::<Result<Vec<_>>>()?,
+ ));
+ }
+ AggVariant::Window(window) => {
+ let items = p
+ .expr
+ .iter()
+ .map(|proj_expr| {
+ let unproj =
+ unproject_window_exprs(proj_expr,
&window)?;
+ self.select_item_to_sql(&unproj)
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ select.projection(items);
+ }
+ }
} else {
let items = p
.expr
@@ -210,8 +229,8 @@ impl Unparser<'_> {
}
}
LogicalPlan::Filter(filter) => {
- if let Some(agg) =
- find_agg_node_within_select(plan,
select.already_projected())
+ if let Some(AggVariant::Aggregate(agg)) =
+ find_agg_node_within_select(plan, None,
select.already_projected())
{
let unprojected = unproject_agg_exprs(&filter.predicate,
agg)?;
let filter_expr = self.expr_to_sql(&unprojected)?;
@@ -265,7 +284,7 @@ impl Unparser<'_> {
)
}
LogicalPlan::Aggregate(agg) => {
- // Aggregate nodes are handled simulatenously with Projection
nodes
+ // Aggregate nodes are handled simultaneously with Projection
nodes
self.select_to_sql_recursively(
agg.input.as_ref(),
query,
@@ -441,8 +460,14 @@ impl Unparser<'_> {
Ok(())
}
- LogicalPlan::Window(_window) => {
- not_impl_err!("Unsupported operator: {plan:?}")
+ LogicalPlan::Window(window) => {
+ // Window nodes are handled simultaneously with Projection
nodes
+ self.select_to_sql_recursively(
+ window.input.as_ref(),
+ query,
+ select,
+ relation,
+ )
}
LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator:
{plan:?}"),
_ => not_impl_err!("Unsupported operator: {plan:?}"),
diff --git a/datafusion/sql/src/unparser/utils.rs
b/datafusion/sql/src/unparser/utils.rs
index c1b02c330f..326cd15ba1 100644
--- a/datafusion/sql/src/unparser/utils.rs
+++ b/datafusion/sql/src/unparser/utils.rs
@@ -20,16 +20,24 @@ use datafusion_common::{
tree_node::{Transformed, TreeNode},
Result,
};
-use datafusion_expr::{Aggregate, Expr, LogicalPlan};
+use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window};
-/// Recursively searches children of [LogicalPlan] to find an Aggregate node
if one exists
+/// One of the possible aggregation plans which can be found within a single
select query.
+pub(crate) enum AggVariant<'a> {
+ Aggregate(&'a Aggregate),
+ Window(Vec<&'a Window>),
+}
+
+/// Recursively searches children of [LogicalPlan] to find an Aggregate or
window node if one exists
/// prior to encountering a Join, TableScan, or a nested subquery (derived
table factor).
-/// If an Aggregate node is not found prior to this or at all before reaching
the end
-/// of the tree, None is returned.
-pub(crate) fn find_agg_node_within_select(
- plan: &LogicalPlan,
+/// If an Aggregate or window node is not found prior to this or at all before
reaching the end
+/// of the tree, None is returned. It is assumed that a Window and Aggregate
node cannot both
+/// be found in a single select query.
+pub(crate) fn find_agg_node_within_select<'a>(
+ plan: &'a LogicalPlan,
+ mut prev_windows: Option<AggVariant<'a>>,
already_projected: bool,
-) -> Option<&Aggregate> {
+) -> Option<AggVariant<'a>> {
// Note that none of the nodes that have a corresponding agg node can have
more
// than 1 input node. E.g. Projection / Filter always have 1 input node.
let input = plan.inputs();
@@ -38,18 +46,29 @@ pub(crate) fn find_agg_node_within_select(
} else {
input.first()?
};
+ // Agg nodes explicitly return immediately with a single node
+ // Window nodes accumulate in a vec until encountering a TableScan or 2nd
projection
if let LogicalPlan::Aggregate(agg) = input {
- Some(agg)
+ Some(AggVariant::Aggregate(agg))
+ } else if let LogicalPlan::Window(window) = input {
+ prev_windows = match &mut prev_windows {
+ Some(AggVariant::Window(windows)) => {
+ windows.push(window);
+ prev_windows
+ }
+ _ => Some(AggVariant::Window(vec![window])),
+ };
+ find_agg_node_within_select(input, prev_windows, already_projected)
} else if let LogicalPlan::TableScan(_) = input {
- None
+ prev_windows
} else if let LogicalPlan::Projection(_) = input {
if already_projected {
- None
+ prev_windows
} else {
- find_agg_node_within_select(input, true)
+ find_agg_node_within_select(input, prev_windows, true)
}
} else {
- find_agg_node_within_select(input, already_projected)
+ find_agg_node_within_select(input, prev_windows, already_projected)
}
}
@@ -82,3 +101,28 @@ pub(crate) fn unproject_agg_exprs(expr: &Expr, agg:
&Aggregate) -> Result<Expr>
})
.map(|e| e.data)
}
+
+/// Recursively identify all Column expressions and transform them into the
appropriate
+/// window expression contained in window.
+///
+/// For example, if expr contains the column expr "COUNT(*) PARTITION BY id"
it will be transformed
+/// into an actual window expression as identified in the window node.
+pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) ->
Result<Expr> {
+ expr.clone()
+ .transform(|sub_expr| {
+ if let Expr::Column(c) = sub_expr {
+ if let Some(unproj) = windows
+ .iter()
+ .flat_map(|w| w.window_expr.iter())
+ .find(|window_expr| window_expr.display_name().unwrap() ==
c.name)
+ {
+ Ok(Transformed::yes(unproj.clone()))
+ } else {
+ Ok(Transformed::no(Expr::Column(c)))
+ }
+ } else {
+ Ok(Transformed::no(sub_expr))
+ }
+ })
+ .map(|e| e.data)
+}
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs
b/datafusion/sql/tests/cases/plan_to_sql.rs
index 1bf441351a..4a430bdc80 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -127,7 +127,13 @@ fn roundtrip_statement() -> Result<()> {
UNION ALL
SELECT j2_string as string FROM j2
ORDER BY string DESC
- LIMIT 10"#
+ LIMIT 10"#,
+ "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
+ last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
+ first_name from person",
+ r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED
PRECEDING AND UNBOUNDED FOLLOWING),
+ sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED
PRECEDING AND UNBOUNDED FOLLOWING) from person"#,
+ "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5
PRECEDING AND 2 FOLLOWING) from person",
];
// For each test sql string, we transform as follows:
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]