This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 2f55003214 Simplify Expr::map_children (#9876)
2f55003214 is described below
commit 2f550032140d42d1ee6d8ed86f7790766fa7302e
Author: Peter Toth <[email protected]>
AuthorDate: Wed Apr 3 22:20:01 2024 +0200
Simplify Expr::map_children (#9876)
* add map_until_stop_and_collect macro
* fix clippy
* simplify
* Update datafusion/common/src/tree_node.rs
Co-authored-by: Andrew Lamb <[email protected]>
* add documentation
* fix macro
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/common/src/tree_node.rs | 82 +++++++++---
datafusion/expr/src/tree_node/expr.rs | 226 ++++++++++++++++------------------
2 files changed, 171 insertions(+), 137 deletions(-)
diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs
index 2d653a27c4..554722f37b 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -532,8 +532,20 @@ impl<T> Transformed<T> {
}
}
-/// Transformation helper to process tree nodes that are siblings.
+/// Transformation helper to process a sequence of iterable tree nodes that are siblings.
pub trait TransformedIterator: Iterator {
+ /// Applies `f` to each item in this iterator.
+ ///
+ /// Visits all items in the iterator; iteration ends early if `f` returns an
+ /// error, and once `f` returns `TreeNodeRecursion::Stop` the remaining items
+ /// are passed through unchanged without calling `f`.
+ ///
+ /// # Returns
+ /// Error if `f` returns an error
+ ///
+ /// Ok(Transformed) such that:
+ /// 1. `transformed` is true if any return from `f` had transformed true
+ /// 2. `data` holds, in order, the `data` returned by `f` for each visited item,
+ ///    followed by the remaining items unchanged after a `Stop`
+ /// 3. `tnr` from the last invocation of `f` or `Continue` if the iterator is empty
fn map_until_stop_and_collect<
F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
>(
@@ -551,22 +563,64 @@ impl<I: Iterator> TransformedIterator for I {
) -> Result<Transformed<Vec<Self::Item>>> {
let mut tnr = TreeNodeRecursion::Continue;
let mut transformed = false;
- let data = self
- .map(|item| match tnr {
- TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
- f(item).map(|result| {
- tnr = result.tnr;
- transformed |= result.transformed;
- result.data
- })
- }
- TreeNodeRecursion::Stop => Ok(item),
- })
- .collect::<Result<Vec<_>>>()?;
- Ok(Transformed::new(data, transformed, tnr))
+ self.map(|item| match tnr {
+ TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
+ f(item).map(|result| {
+ tnr = result.tnr;
+ transformed |= result.transformed;
+ result.data
+ })
+ }
+ TreeNodeRecursion::Stop => Ok(item),
+ })
+ .collect::<Result<Vec<_>>>()
+ .map(|data| Transformed::new(data, transformed, tnr))
}
}
+/// Transformation helper to process a heterogeneous sequence of tree nodes containing
+/// expressions.
+///
+/// This macro is very similar to [`TransformedIterator::map_until_stop_and_collect`] in
+/// that it processes nodes that are siblings, but it accepts an initial transformation
+/// (`F0`) followed by a sequence of pairs. Each pair is made of an expression (`EXPR`)
+/// and its transformation (`F`).
+///
+/// The macro builds up a tuple whose first element is the `Transformed.data` result of
+/// `F0` and whose further elements come from the sequence of pairs. The element for a
+/// pair is either the `Transformed.data` result of `F` (when the most recent `tnr` is
+/// `Continue` or `Jump`) or the untouched value of `EXPR` (after a `Stop`). Only the
+/// selected one of `F` and `EXPR` is evaluated, so both may move the same value.
+///
+/// `F0` and every `F` must evaluate to a `Result` whose `Ok` value is a `Transformed`.
+/// `Transformed` and `TreeNodeRecursion` are resolved through `$crate`, so call sites
+/// do not need to import them. A trailing comma after the last argument is accepted.
+///
+/// # Returns
+/// Error if any of the evaluated transformations returns an error
+///
+/// Ok(Transformed<(data0, ..., dataN)>) such that:
+/// 1. `transformed` is true if any of the evaluated transformations had transformed true
+/// 2. `(data0, ..., dataN)`, where `data0` is the `Transformed.data` from `F0` and
+///    `data1` ... `dataN` are from either `EXPR` or the `Transformed.data` of `F`
+/// 3. `tnr` from `F0` or the last evaluated `F`
+#[macro_export]
+macro_rules! map_until_stop_and_collect {
+    ($F0:expr $(, $EXPR:expr, $F:expr)* $(,)?) => {{
+        $F0.and_then(
+            |$crate::tree_node::Transformed {
+                 data: data0,
+                 mut transformed,
+                 mut tnr,
+             }| {
+                let all_datas = (
+                    data0,
+                    $(
+                        // Only evaluate `F` while recursion may continue; after a
+                        // `Stop` the original `EXPR` is passed through untouched.
+                        if matches!(
+                            tnr,
+                            $crate::tree_node::TreeNodeRecursion::Continue
+                                | $crate::tree_node::TreeNodeRecursion::Jump
+                        ) {
+                            $F.map(|result| {
+                                tnr = result.tnr;
+                                transformed |= result.transformed;
+                                result.data
+                            })?
+                        } else {
+                            $EXPR
+                        },
+                    )*
+                );
+                Ok($crate::tree_node::Transformed::new(all_datas, transformed, tnr))
+            },
+        )
+    }}
+}
+
/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
pub trait TransformedResult<T> {
fn data(self) -> Result<T>;
diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs
index 0909d8f662..df1585e5a5 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -27,7 +27,9 @@ use crate::{Expr, GetFieldAccess};
use datafusion_common::tree_node::{
Transformed, TransformedIterator, TreeNode, TreeNodeRecursion,
};
-use datafusion_common::{handle_visit_recursion, internal_err, Result};
+use datafusion_common::{
+ handle_visit_recursion, internal_err, map_until_stop_and_collect, Result,
+};
impl TreeNode for Expr {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
@@ -167,15 +169,14 @@ impl TreeNode for Expr {
Expr::InSubquery(InSubquery::new(be, subquery, negated))
}),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
- transform_box(left, &mut f)?
- .update_data(|new_left| (new_left, right))
- .try_transform_node(|(new_left, right)| {
- Ok(transform_box(right, &mut f)?
- .update_data(|new_right| (new_left, new_right)))
- })?
- .update_data(|(new_left, new_right)| {
- Expr::BinaryExpr(BinaryExpr::new(new_left, op,
new_right))
- })
+ map_until_stop_and_collect!(
+ transform_box(left, &mut f),
+ right,
+ transform_box(right, &mut f)
+ )?
+ .update_data(|(new_left, new_right)| {
+ Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
+ })
}
Expr::Like(Like {
negated,
@@ -183,42 +184,40 @@ impl TreeNode for Expr {
pattern,
escape_char,
case_insensitive,
- }) => transform_box(expr, &mut f)?
- .update_data(|new_expr| (new_expr, pattern))
- .try_transform_node(|(new_expr, pattern)| {
- Ok(transform_box(pattern, &mut f)?
- .update_data(|new_pattern| (new_expr, new_pattern)))
- })?
- .update_data(|(new_expr, new_pattern)| {
- Expr::Like(Like::new(
- negated,
- new_expr,
- new_pattern,
- escape_char,
- case_insensitive,
- ))
- }),
+ }) => map_until_stop_and_collect!(
+ transform_box(expr, &mut f),
+ pattern,
+ transform_box(pattern, &mut f)
+ )?
+ .update_data(|(new_expr, new_pattern)| {
+ Expr::Like(Like::new(
+ negated,
+ new_expr,
+ new_pattern,
+ escape_char,
+ case_insensitive,
+ ))
+ }),
Expr::SimilarTo(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
- }) => transform_box(expr, &mut f)?
- .update_data(|new_expr| (new_expr, pattern))
- .try_transform_node(|(new_expr, pattern)| {
- Ok(transform_box(pattern, &mut f)?
- .update_data(|new_pattern| (new_expr, new_pattern)))
- })?
- .update_data(|(new_expr, new_pattern)| {
- Expr::SimilarTo(Like::new(
- negated,
- new_expr,
- new_pattern,
- escape_char,
- case_insensitive,
- ))
- }),
+ }) => map_until_stop_and_collect!(
+ transform_box(expr, &mut f),
+ pattern,
+ transform_box(pattern, &mut f)
+ )?
+ .update_data(|(new_expr, new_pattern)| {
+ Expr::SimilarTo(Like::new(
+ negated,
+ new_expr,
+ new_pattern,
+ escape_char,
+ case_insensitive,
+ ))
+ }),
Expr::Not(expr) => transform_box(expr, &mut
f)?.update_data(Expr::Not),
Expr::IsNotNull(expr) => {
transform_box(expr, &mut f)?.update_data(Expr::IsNotNull)
@@ -248,48 +247,38 @@ impl TreeNode for Expr {
negated,
low,
high,
- }) => transform_box(expr, &mut f)?
- .update_data(|new_expr| (new_expr, low, high))
- .try_transform_node(|(new_expr, low, high)| {
- Ok(transform_box(low, &mut f)?
- .update_data(|new_low| (new_expr, new_low, high)))
- })?
- .try_transform_node(|(new_expr, new_low, high)| {
- Ok(transform_box(high, &mut f)?
- .update_data(|new_high| (new_expr, new_low, new_high)))
- })?
- .update_data(|(new_expr, new_low, new_high)| {
- Expr::Between(Between::new(new_expr, negated, new_low,
new_high))
- }),
+ }) => map_until_stop_and_collect!(
+ transform_box(expr, &mut f),
+ low,
+ transform_box(low, &mut f),
+ high,
+ transform_box(high, &mut f)
+ )?
+ .update_data(|(new_expr, new_low, new_high)| {
+ Expr::Between(Between::new(new_expr, negated, new_low,
new_high))
+ }),
Expr::Case(Case {
expr,
when_then_expr,
else_expr,
- }) => transform_option_box(expr, &mut f)?
- .update_data(|new_expr| (new_expr, when_then_expr, else_expr))
- .try_transform_node(|(new_expr, when_then_expr, else_expr)| {
- Ok(when_then_expr
- .into_iter()
- .map_until_stop_and_collect(|(when, then)| {
- transform_box(when, &mut f)?
- .update_data(|new_when| (new_when, then))
- .try_transform_node(|(new_when, then)| {
- Ok(transform_box(then, &mut f)?
- .update_data(|new_then| (new_when,
new_then)))
- })
- })?
- .update_data(|new_when_then_expr| {
- (new_expr, new_when_then_expr, else_expr)
- }))
- })?
- .try_transform_node(|(new_expr, new_when_then_expr,
else_expr)| {
- Ok(transform_option_box(else_expr, &mut f)?.update_data(
- |new_else_expr| (new_expr, new_when_then_expr,
new_else_expr),
- ))
- })?
- .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
- Expr::Case(Case::new(new_expr, new_when_then_expr,
new_else_expr))
- }),
+ }) => map_until_stop_and_collect!(
+ transform_option_box(expr, &mut f),
+ when_then_expr,
+ when_then_expr
+ .into_iter()
+ .map_until_stop_and_collect(|(when, then)| {
+ map_until_stop_and_collect!(
+ transform_box(when, &mut f),
+ then,
+ transform_box(then, &mut f)
+ )
+ }),
+ else_expr,
+ transform_option_box(else_expr, &mut f)
+ )?
+ .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
+ Expr::Case(Case::new(new_expr, new_when_then_expr,
new_else_expr))
+ }),
Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut
f)?
.update_data(|be| Expr::Cast(Cast::new(be, data_type))),
Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr,
&mut f)?
@@ -320,30 +309,23 @@ impl TreeNode for Expr {
order_by,
window_frame,
null_treatment,
- }) => transform_vec(args, &mut f)?
- .update_data(|new_args| (new_args, partition_by, order_by))
- .try_transform_node(|(new_args, partition_by, order_by)| {
- Ok(transform_vec(partition_by, &mut f)?.update_data(
- |new_partition_by| (new_args, new_partition_by,
order_by),
- ))
- })?
- .try_transform_node(|(new_args, new_partition_by, order_by)| {
- Ok(
- transform_vec(order_by, &mut
f)?.update_data(|new_order_by| {
- (new_args, new_partition_by, new_order_by)
- }),
- )
- })?
- .update_data(|(new_args, new_partition_by, new_order_by)| {
- Expr::WindowFunction(WindowFunction::new(
- fun,
- new_args,
- new_partition_by,
- new_order_by,
- window_frame,
- null_treatment,
- ))
- }),
+ }) => map_until_stop_and_collect!(
+ transform_vec(args, &mut f),
+ partition_by,
+ transform_vec(partition_by, &mut f),
+ order_by,
+ transform_vec(order_by, &mut f)
+ )?
+ .update_data(|(new_args, new_partition_by, new_order_by)| {
+ Expr::WindowFunction(WindowFunction::new(
+ fun,
+ new_args,
+ new_partition_by,
+ new_order_by,
+ window_frame,
+ null_treatment,
+ ))
+ }),
Expr::AggregateFunction(AggregateFunction {
args,
func_def,
@@ -351,17 +333,15 @@ impl TreeNode for Expr {
filter,
order_by,
null_treatment,
- }) => transform_vec(args, &mut f)?
- .update_data(|new_args| (new_args, filter, order_by))
- .try_transform_node(|(new_args, filter, order_by)| {
- Ok(transform_option_box(filter, &mut f)?
- .update_data(|new_filter| (new_args, new_filter,
order_by)))
- })?
- .try_transform_node(|(new_args, new_filter, order_by)| {
- Ok(transform_option_vec(order_by, &mut f)?
- .update_data(|new_order_by| (new_args, new_filter,
new_order_by)))
- })?
- .map_data(|(new_args, new_filter, new_order_by)| match
func_def {
+ }) => map_until_stop_and_collect!(
+ transform_vec(args, &mut f),
+ filter,
+ transform_option_box(filter, &mut f),
+ order_by,
+ transform_option_vec(order_by, &mut f)
+ )?
+ .map_data(
+ |(new_args, new_filter, new_order_by)| match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun,
@@ -385,7 +365,8 @@ impl TreeNode for Expr {
AggregateFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be
resolved.")
}
- })?,
+ },
+ )?,
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)?
.update_data(|ve|
Expr::GroupingSet(GroupingSet::Rollup(ve))),
@@ -402,15 +383,14 @@ impl TreeNode for Expr {
expr,
list,
negated,
- }) => transform_box(expr, &mut f)?
- .update_data(|new_expr| (new_expr, list))
- .try_transform_node(|(new_expr, list)| {
- Ok(transform_vec(list, &mut f)?
- .update_data(|new_list| (new_expr, new_list)))
- })?
- .update_data(|(new_expr, new_list)| {
- Expr::InList(InList::new(new_expr, new_list, negated))
- }),
+ }) => map_until_stop_and_collect!(
+ transform_box(expr, &mut f),
+ list,
+ transform_vec(list, &mut f)
+ )?
+ .update_data(|(new_expr, new_list)| {
+ Expr::InList(InList::new(new_expr, new_list, negated))
+ }),
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
transform_box(expr, &mut f)?.update_data(|be| {
Expr::GetIndexedField(GetIndexedField::new(be, field))