This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f55003214 Simplify Expr::map_children (#9876)
2f55003214 is described below

commit 2f550032140d42d1ee6d8ed86f7790766fa7302e
Author: Peter Toth <[email protected]>
AuthorDate: Wed Apr 3 22:20:01 2024 +0200

    Simplify Expr::map_children (#9876)
    
    * add map_until_stop_and_collect macro
    
    * fix clippy
    
    * simplify
    
    * Update datafusion/common/src/tree_node.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * add documentation
    
    * fix macro
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/common/src/tree_node.rs    |  82 +++++++++---
 datafusion/expr/src/tree_node/expr.rs | 226 ++++++++++++++++------------------
 2 files changed, 171 insertions(+), 137 deletions(-)

diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs
index 2d653a27c4..554722f37b 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -532,8 +532,20 @@ impl<T> Transformed<T> {
     }
 }
 
-/// Transformation helper to process tree nodes that are siblings.
+/// Transformation helper to process a sequence of iterable tree nodes that are siblings.
 pub trait TransformedIterator: Iterator {
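+    /// Applies `f` to each item in this iterator, in order.
+    ///
+    /// `f` is applied until it returns an error or a result whose `tnr` is
+    /// `TreeNodeRecursion::Stop`. `TreeNodeRecursion::Jump` does not end the
+    /// iteration: `f` is still applied to the next item. Items that follow a
+    /// `Stop` are not passed to `f` and are kept unchanged.
+    ///
+    /// # Returns
+    /// Error if `f` returns an error (the first error ends the iteration).
+    ///
+    /// Ok(Transformed) such that:
+    /// 1. `transformed` is true if any return from `f` had transformed true
+    /// 2. `data` is a `Vec` with one element per item, in the original order:
+    ///    the `data` returned by `f` for items `f` was applied to, and the
+    ///    unchanged item for every item after a `Stop`
+    /// 3. `tnr` from the last invocation of `f`, or `Continue` if the iterator is empty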
     fn map_until_stop_and_collect<
         F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
     >(
@@ -551,22 +563,64 @@ impl<I: Iterator> TransformedIterator for I {
     ) -> Result<Transformed<Vec<Self::Item>>> {
         let mut tnr = TreeNodeRecursion::Continue;
         let mut transformed = false;
-        let data = self
-            .map(|item| match tnr {
-                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
-                    f(item).map(|result| {
-                        tnr = result.tnr;
-                        transformed |= result.transformed;
-                        result.data
-                    })
-                }
-                TreeNodeRecursion::Stop => Ok(item),
-            })
-            .collect::<Result<Vec<_>>>()?;
-        Ok(Transformed::new(data, transformed, tnr))
+        self.map(|item| match tnr {
+            TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
+                f(item).map(|result| {
+                    tnr = result.tnr;
+                    transformed |= result.transformed;
+                    result.data
+                })
+            }
+            TreeNodeRecursion::Stop => Ok(item),
+        })
+        .collect::<Result<Vec<_>>>()
+        .map(|data| Transformed::new(data, transformed, tnr))
     }
 }
 
+/// Transformation helper to process a heterogeneous sequence of sibling tree nodes
+/// (their values may have different types).
+///
+/// This macro is very similar to [`TransformedIterator::map_until_stop_and_collect`],
+/// but instead of one closure applied to every item of an iterator it accepts an initial
+/// transformation (`F0`) followed by a sequence of pairs. Each pair is made of an
+/// expression (`EXPR`) and its transformation (`F`). `F0` and each `F` must evaluate to
+/// a `Result<Transformed<_>>`.
+///
+/// The macro builds up a tuple whose first element is the `Transformed.data` result of
+/// `F0`, followed by one element per pair. Each pair's element depends on the `tnr`
+/// produced by the transformations evaluated so far (`F0` initially):
+/// - `Continue` or `Jump`: `F` is evaluated and its `Transformed.data` is used;
+/// - `Stop`: `F` is not evaluated at all and the value of `EXPR` is used unchanged.
+///
+/// The expansion refers to `Transformed` and `TreeNodeRecursion` by their bare names,
+/// so both must be in scope at the call site.
+///
+/// # Returns
+/// Error if `F0` or any evaluated `F` returns an error
+///
+/// Ok(Transformed<(data0, ..., dataN)>) such that:
+/// 1. `transformed` is true if any of the evaluated transformations had transformed true
+/// 2. `(data0, ..., dataN)`, where `data0` is the `Transformed.data` from `F0` and
+///    `data1` ... `dataN` are from either `EXPR` or the `Transformed.data` of `F`
+/// 3. `tnr` from `F0` or the last evaluated `F`
+#[macro_export]
+macro_rules! map_until_stop_and_collect {
+    ($F0:expr, $($EXPR:expr, $F:expr),*) => {{
+        $F0.and_then(|Transformed { data: data0, mut transformed, mut tnr }| {
+            let all_datas = (
+                data0,
+                $(
+                    if tnr == TreeNodeRecursion::Continue || tnr == 
TreeNodeRecursion::Jump {
+                        $F.map(|result| {
+                            tnr = result.tnr;
+                            transformed |= result.transformed;
+                            result.data
+                        })?
+                    } else {
+                        $EXPR
+                    },
+                )*
+            );
+            Ok(Transformed::new(all_datas, transformed, tnr))
+        })
+    }}
+}
+
 /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
 pub trait TransformedResult<T> {
     fn data(self) -> Result<T>;
diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs
index 0909d8f662..df1585e5a5 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -27,7 +27,9 @@ use crate::{Expr, GetFieldAccess};
 use datafusion_common::tree_node::{
     Transformed, TransformedIterator, TreeNode, TreeNodeRecursion,
 };
-use datafusion_common::{handle_visit_recursion, internal_err, Result};
+use datafusion_common::{
+    handle_visit_recursion, internal_err, map_until_stop_and_collect, Result,
+};
 
 impl TreeNode for Expr {
     fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
@@ -167,15 +169,14 @@ impl TreeNode for Expr {
                 Expr::InSubquery(InSubquery::new(be, subquery, negated))
             }),
             Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-                transform_box(left, &mut f)?
-                    .update_data(|new_left| (new_left, right))
-                    .try_transform_node(|(new_left, right)| {
-                        Ok(transform_box(right, &mut f)?
-                            .update_data(|new_right| (new_left, new_right)))
-                    })?
-                    .update_data(|(new_left, new_right)| {
-                        Expr::BinaryExpr(BinaryExpr::new(new_left, op, 
new_right))
-                    })
+                map_until_stop_and_collect!(
+                    transform_box(left, &mut f),
+                    right,
+                    transform_box(right, &mut f)
+                )?
+                .update_data(|(new_left, new_right)| {
+                    Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
+                })
             }
             Expr::Like(Like {
                 negated,
@@ -183,42 +184,40 @@ impl TreeNode for Expr {
                 pattern,
                 escape_char,
                 case_insensitive,
-            }) => transform_box(expr, &mut f)?
-                .update_data(|new_expr| (new_expr, pattern))
-                .try_transform_node(|(new_expr, pattern)| {
-                    Ok(transform_box(pattern, &mut f)?
-                        .update_data(|new_pattern| (new_expr, new_pattern)))
-                })?
-                .update_data(|(new_expr, new_pattern)| {
-                    Expr::Like(Like::new(
-                        negated,
-                        new_expr,
-                        new_pattern,
-                        escape_char,
-                        case_insensitive,
-                    ))
-                }),
+            }) => map_until_stop_and_collect!(
+                transform_box(expr, &mut f),
+                pattern,
+                transform_box(pattern, &mut f)
+            )?
+            .update_data(|(new_expr, new_pattern)| {
+                Expr::Like(Like::new(
+                    negated,
+                    new_expr,
+                    new_pattern,
+                    escape_char,
+                    case_insensitive,
+                ))
+            }),
             Expr::SimilarTo(Like {
                 negated,
                 expr,
                 pattern,
                 escape_char,
                 case_insensitive,
-            }) => transform_box(expr, &mut f)?
-                .update_data(|new_expr| (new_expr, pattern))
-                .try_transform_node(|(new_expr, pattern)| {
-                    Ok(transform_box(pattern, &mut f)?
-                        .update_data(|new_pattern| (new_expr, new_pattern)))
-                })?
-                .update_data(|(new_expr, new_pattern)| {
-                    Expr::SimilarTo(Like::new(
-                        negated,
-                        new_expr,
-                        new_pattern,
-                        escape_char,
-                        case_insensitive,
-                    ))
-                }),
+            }) => map_until_stop_and_collect!(
+                transform_box(expr, &mut f),
+                pattern,
+                transform_box(pattern, &mut f)
+            )?
+            .update_data(|(new_expr, new_pattern)| {
+                Expr::SimilarTo(Like::new(
+                    negated,
+                    new_expr,
+                    new_pattern,
+                    escape_char,
+                    case_insensitive,
+                ))
+            }),
             Expr::Not(expr) => transform_box(expr, &mut 
f)?.update_data(Expr::Not),
             Expr::IsNotNull(expr) => {
                 transform_box(expr, &mut f)?.update_data(Expr::IsNotNull)
@@ -248,48 +247,38 @@ impl TreeNode for Expr {
                 negated,
                 low,
                 high,
-            }) => transform_box(expr, &mut f)?
-                .update_data(|new_expr| (new_expr, low, high))
-                .try_transform_node(|(new_expr, low, high)| {
-                    Ok(transform_box(low, &mut f)?
-                        .update_data(|new_low| (new_expr, new_low, high)))
-                })?
-                .try_transform_node(|(new_expr, new_low, high)| {
-                    Ok(transform_box(high, &mut f)?
-                        .update_data(|new_high| (new_expr, new_low, new_high)))
-                })?
-                .update_data(|(new_expr, new_low, new_high)| {
-                    Expr::Between(Between::new(new_expr, negated, new_low, 
new_high))
-                }),
+            }) => map_until_stop_and_collect!(
+                transform_box(expr, &mut f),
+                low,
+                transform_box(low, &mut f),
+                high,
+                transform_box(high, &mut f)
+            )?
+            .update_data(|(new_expr, new_low, new_high)| {
+                Expr::Between(Between::new(new_expr, negated, new_low, 
new_high))
+            }),
             Expr::Case(Case {
                 expr,
                 when_then_expr,
                 else_expr,
-            }) => transform_option_box(expr, &mut f)?
-                .update_data(|new_expr| (new_expr, when_then_expr, else_expr))
-                .try_transform_node(|(new_expr, when_then_expr, else_expr)| {
-                    Ok(when_then_expr
-                        .into_iter()
-                        .map_until_stop_and_collect(|(when, then)| {
-                            transform_box(when, &mut f)?
-                                .update_data(|new_when| (new_when, then))
-                                .try_transform_node(|(new_when, then)| {
-                                    Ok(transform_box(then, &mut f)?
-                                        .update_data(|new_then| (new_when, 
new_then)))
-                                })
-                        })?
-                        .update_data(|new_when_then_expr| {
-                            (new_expr, new_when_then_expr, else_expr)
-                        }))
-                })?
-                .try_transform_node(|(new_expr, new_when_then_expr, 
else_expr)| {
-                    Ok(transform_option_box(else_expr, &mut f)?.update_data(
-                        |new_else_expr| (new_expr, new_when_then_expr, 
new_else_expr),
-                    ))
-                })?
-                .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
-                    Expr::Case(Case::new(new_expr, new_when_then_expr, 
new_else_expr))
-                }),
+            }) => map_until_stop_and_collect!(
+                transform_option_box(expr, &mut f),
+                when_then_expr,
+                when_then_expr
+                    .into_iter()
+                    .map_until_stop_and_collect(|(when, then)| {
+                        map_until_stop_and_collect!(
+                            transform_box(when, &mut f),
+                            then,
+                            transform_box(then, &mut f)
+                        )
+                    }),
+                else_expr,
+                transform_option_box(else_expr, &mut f)
+            )?
+            .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
+                Expr::Case(Case::new(new_expr, new_when_then_expr, 
new_else_expr))
+            }),
             Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut 
f)?
                 .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
             Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, 
&mut f)?
@@ -320,30 +309,23 @@ impl TreeNode for Expr {
                 order_by,
                 window_frame,
                 null_treatment,
-            }) => transform_vec(args, &mut f)?
-                .update_data(|new_args| (new_args, partition_by, order_by))
-                .try_transform_node(|(new_args, partition_by, order_by)| {
-                    Ok(transform_vec(partition_by, &mut f)?.update_data(
-                        |new_partition_by| (new_args, new_partition_by, 
order_by),
-                    ))
-                })?
-                .try_transform_node(|(new_args, new_partition_by, order_by)| {
-                    Ok(
-                        transform_vec(order_by, &mut 
f)?.update_data(|new_order_by| {
-                            (new_args, new_partition_by, new_order_by)
-                        }),
-                    )
-                })?
-                .update_data(|(new_args, new_partition_by, new_order_by)| {
-                    Expr::WindowFunction(WindowFunction::new(
-                        fun,
-                        new_args,
-                        new_partition_by,
-                        new_order_by,
-                        window_frame,
-                        null_treatment,
-                    ))
-                }),
+            }) => map_until_stop_and_collect!(
+                transform_vec(args, &mut f),
+                partition_by,
+                transform_vec(partition_by, &mut f),
+                order_by,
+                transform_vec(order_by, &mut f)
+            )?
+            .update_data(|(new_args, new_partition_by, new_order_by)| {
+                Expr::WindowFunction(WindowFunction::new(
+                    fun,
+                    new_args,
+                    new_partition_by,
+                    new_order_by,
+                    window_frame,
+                    null_treatment,
+                ))
+            }),
             Expr::AggregateFunction(AggregateFunction {
                 args,
                 func_def,
@@ -351,17 +333,15 @@ impl TreeNode for Expr {
                 filter,
                 order_by,
                 null_treatment,
-            }) => transform_vec(args, &mut f)?
-                .update_data(|new_args| (new_args, filter, order_by))
-                .try_transform_node(|(new_args, filter, order_by)| {
-                    Ok(transform_option_box(filter, &mut f)?
-                        .update_data(|new_filter| (new_args, new_filter, 
order_by)))
-                })?
-                .try_transform_node(|(new_args, new_filter, order_by)| {
-                    Ok(transform_option_vec(order_by, &mut f)?
-                        .update_data(|new_order_by| (new_args, new_filter, 
new_order_by)))
-                })?
-                .map_data(|(new_args, new_filter, new_order_by)| match 
func_def {
+            }) => map_until_stop_and_collect!(
+                transform_vec(args, &mut f),
+                filter,
+                transform_option_box(filter, &mut f),
+                order_by,
+                transform_option_vec(order_by, &mut f)
+            )?
+            .map_data(
+                |(new_args, new_filter, new_order_by)| match func_def {
                     AggregateFunctionDefinition::BuiltIn(fun) => {
                         Ok(Expr::AggregateFunction(AggregateFunction::new(
                             fun,
@@ -385,7 +365,8 @@ impl TreeNode for Expr {
                     AggregateFunctionDefinition::Name(_) => {
                         internal_err!("Function `Expr` with name should be 
resolved.")
                     }
-                })?,
+                },
+            )?,
             Expr::GroupingSet(grouping_set) => match grouping_set {
                 GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)?
                     .update_data(|ve| 
Expr::GroupingSet(GroupingSet::Rollup(ve))),
@@ -402,15 +383,14 @@ impl TreeNode for Expr {
                 expr,
                 list,
                 negated,
-            }) => transform_box(expr, &mut f)?
-                .update_data(|new_expr| (new_expr, list))
-                .try_transform_node(|(new_expr, list)| {
-                    Ok(transform_vec(list, &mut f)?
-                        .update_data(|new_list| (new_expr, new_list)))
-                })?
-                .update_data(|(new_expr, new_list)| {
-                    Expr::InList(InList::new(new_expr, new_list, negated))
-                }),
+            }) => map_until_stop_and_collect!(
+                transform_box(expr, &mut f),
+                list,
+                transform_vec(list, &mut f)
+            )?
+            .update_data(|(new_expr, new_list)| {
+                Expr::InList(InList::new(new_expr, new_list, negated))
+            }),
             Expr::GetIndexedField(GetIndexedField { expr, field }) => {
                 transform_box(expr, &mut f)?.update_data(|be| {
                     Expr::GetIndexedField(GetIndexedField::new(be, field))

Reply via email to