alamb commented on code in PR #9999:
URL: https://github.com/apache/arrow-datafusion/pull/9999#discussion_r1557339008


##########
datafusion/expr/src/logical_plan/tree_node.rs:
##########
@@ -32,23 +39,364 @@ impl TreeNode for LogicalPlan {
         self.inputs().into_iter().apply_until_stop(f)
     }
 
-    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
-        self,
-        f: F,
-    ) -> Result<Transformed<Self>> {
-        let new_children = self
-            .inputs()
-            .into_iter()
-            .cloned()
-            .map_until_stop_and_collect(f)?;
-        // Propagate up `new_children.transformed` and `new_children.tnr`
-        // along with the node containing transformed children.
-        if new_children.transformed {
-            new_children.map_data(|new_children| {
-                self.with_new_exprs(self.expressions(), new_children)
-            })
-        } else {
-            Ok(new_children.update_data(|_| self))
-        }
+    /// Applies `f` to each child (input) of this plan node, rewriting them 
*in place.*
+    ///
+    /// # Notes
+    ///
+    /// Inputs include ONLY direct children, not embedded `LogicalPlan`s for
+    /// subqueries, for example such as are in [`Expr::Exists`].
+    ///
+    /// [`Expr::Exists`]: crate::Expr::Exists
+    fn map_children<F>(self, mut f: F) -> Result<Transformed<Self>>
+    where
+        F: FnMut(Self) -> Result<Transformed<Self>>,
+    {
+        Ok(match self {
+            LogicalPlan::Projection(Projection {
+                expr,
+                input,
+                schema,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Projection(Projection {
+                    expr,
+                    input,
+                    schema,
+                })
+            }),
+            LogicalPlan::Filter(Filter { predicate, input }) => 
rewrite_arc(input, f)?
+                .update_data(|input| LogicalPlan::Filter(Filter { predicate, 
input })),
+            LogicalPlan::Repartition(Repartition {
+                input,
+                partitioning_scheme,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Repartition(Repartition {
+                    input,
+                    partitioning_scheme,
+                })
+            }),
+            LogicalPlan::Window(Window {
+                input,
+                window_expr,
+                schema,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Window(Window {
+                    input,
+                    window_expr,
+                    schema,
+                })
+            }),
+            LogicalPlan::Aggregate(Aggregate {
+                input,
+                group_expr,
+                aggr_expr,
+                schema,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Aggregate(Aggregate {
+                    input,
+                    group_expr,
+                    aggr_expr,
+                    schema,
+                })
+            }),
+            LogicalPlan::Sort(Sort { expr, input, fetch }) => 
rewrite_arc(input, f)?
+                .update_data(|input| LogicalPlan::Sort(Sort { expr, input, 
fetch })),
+            LogicalPlan::Join(Join {
+                left,
+                right,
+                on,
+                filter,
+                join_type,
+                join_constraint,
+                schema,
+                null_equals_null,
+            }) => map_until_stop_and_collect!(
+                rewrite_arc(left, &mut f),
+                right,
+                rewrite_arc(right, &mut f)
+            )?
+            .update_data(|(left, right)| {
+                LogicalPlan::Join(Join {
+                    left,
+                    right,
+                    on,
+                    filter,
+                    join_type,
+                    join_constraint,
+                    schema,
+                    null_equals_null,
+                })
+            }),
+            LogicalPlan::CrossJoin(CrossJoin {
+                left,
+                right,
+                schema,
+            }) => map_until_stop_and_collect!(
+                rewrite_arc(left, &mut f),
+                right,
+                rewrite_arc(right, &mut f)
+            )?
+            .update_data(|(left, right)| {
+                LogicalPlan::CrossJoin(CrossJoin {
+                    left,
+                    right,
+                    schema,
+                })
+            }),
+            LogicalPlan::Limit(Limit { skip, fetch, input }) => 
rewrite_arc(input, f)?
+                .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, 
input })),
+            LogicalPlan::Subquery(Subquery {
+                subquery,
+                outer_ref_columns,
+            }) => rewrite_arc(subquery, f)?.update_data(|subquery| {
+                LogicalPlan::Subquery(Subquery {
+                    subquery,
+                    outer_ref_columns,
+                })
+            }),
+            LogicalPlan::SubqueryAlias(SubqueryAlias {
+                input,
+                alias,
+                schema,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::SubqueryAlias(SubqueryAlias {
+                    input,
+                    alias,
+                    schema,
+                })
+            }),
+            LogicalPlan::Extension(extension) => 
rewrite_extension_inputs(extension, f)?
+                .update_data(LogicalPlan::Extension),
+            LogicalPlan::Union(Union { inputs, schema }) => 
rewrite_arcs(inputs, f)?
+                .update_data(|inputs| LogicalPlan::Union(Union { inputs, 
schema })),
+            LogicalPlan::Distinct(distinct) => match distinct {
+                Distinct::All(input) => rewrite_arc(input, 
f)?.update_data(Distinct::All),
+                Distinct::On(DistinctOn {
+                    on_expr,
+                    select_expr,
+                    sort_expr,
+                    input,
+                    schema,
+                }) => rewrite_arc(input, f)?.update_data(|input| {
+                    Distinct::On(DistinctOn {
+                        on_expr,
+                        select_expr,
+                        sort_expr,
+                        input,
+                        schema,
+                    })
+                }),
+            }
+            .update_data(LogicalPlan::Distinct),
+            LogicalPlan::Explain(Explain {
+                verbose,
+                plan,
+                stringified_plans,
+                schema,
+                logical_optimization_succeeded,
+            }) => rewrite_arc(plan, f)?.update_data(|plan| {
+                LogicalPlan::Explain(Explain {
+                    verbose,
+                    plan,
+                    stringified_plans,
+                    schema,
+                    logical_optimization_succeeded,
+                })
+            }),
+            LogicalPlan::Analyze(Analyze {
+                verbose,
+                input,
+                schema,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Analyze(Analyze {
+                    verbose,
+                    input,
+                    schema,
+                })
+            }),
+            LogicalPlan::Dml(DmlStatement {
+                table_name,
+                table_schema,
+                op,
+                input,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Dml(DmlStatement {
+                    table_name,
+                    table_schema,
+                    op,
+                    input,
+                })
+            }),
+            LogicalPlan::Copy(CopyTo {
+                input,
+                output_url,
+                partition_by,
+                format_options,
+                options,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Copy(CopyTo {
+                    input,
+                    output_url,
+                    partition_by,
+                    format_options,
+                    options,
+                })
+            }),
+            LogicalPlan::Ddl(ddl) => {
+                match ddl {
+                    DdlStatement::CreateMemoryTable(CreateMemoryTable {
+                        name,
+                        constraints,
+                        input,
+                        if_not_exists,
+                        or_replace,
+                        column_defaults,
+                    }) => rewrite_arc(input, f)?.update_data(|input| {
+                        DdlStatement::CreateMemoryTable(CreateMemoryTable {
+                            name,
+                            constraints,
+                            input,
+                            if_not_exists,
+                            or_replace,
+                            column_defaults,
+                        })
+                    }),
+                    DdlStatement::CreateView(CreateView {
+                        name,
+                        input,
+                        or_replace,
+                        definition,
+                    }) => rewrite_arc(input, f)?.update_data(|input| {
+                        DdlStatement::CreateView(CreateView {
+                            name,
+                            input,
+                            or_replace,
+                            definition,
+                        })
+                    }),
+                    // no inputs in these statements
+                    DdlStatement::CreateExternalTable(_)
+                    | DdlStatement::CreateCatalogSchema(_)
+                    | DdlStatement::CreateCatalog(_)
+                    | DdlStatement::DropTable(_)
+                    | DdlStatement::DropView(_)
+                    | DdlStatement::DropCatalogSchema(_)
+                    | DdlStatement::CreateFunction(_)
+                    | DdlStatement::DropFunction(_) => Transformed::no(ddl),
+                }
+                .update_data(LogicalPlan::Ddl)
+            }
+            LogicalPlan::Unnest(Unnest {
+                input,
+                column,
+                schema,
+                options,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Unnest(Unnest {
+                    input,
+                    column,
+                    schema,
+                    options,
+                })
+            }),
+            LogicalPlan::Prepare(Prepare {
+                name,
+                data_types,
+                input,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Prepare(Prepare {
+                    name,
+                    data_types,
+                    input,
+                })
+            }),
+            LogicalPlan::RecursiveQuery(RecursiveQuery {
+                name,
+                static_term,
+                recursive_term,
+                is_distinct,
+            }) => map_until_stop_and_collect!(
+                rewrite_arc(static_term, &mut f),
+                recursive_term,
+                rewrite_arc(recursive_term, &mut f)
+            )?
+            .update_data(|(static_term, recursive_term)| {
+                LogicalPlan::RecursiveQuery(RecursiveQuery {
+                    name,
+                    static_term,
+                    recursive_term,
+                    is_distinct,
+                })
+            }),
+            // plans without inputs
+            LogicalPlan::TableScan { .. }
+            | LogicalPlan::Statement { .. }
+            | LogicalPlan::EmptyRelation { .. }
+            | LogicalPlan::Values { .. }
+            | LogicalPlan::DescribeTable(_) => Transformed::no(self),
+        })
     }
 }
+
+/// Converts a `Arc<LogicalPlan>` without copying, if possible. Copies the plan
+/// if there is a shared reference
+fn unwrap_arc(plan: Arc<LogicalPlan>) -> LogicalPlan {
+    Arc::try_unwrap(plan)
+        // if None is returned, there is another reference to this
+        // LogicalPlan, so we can not own it, and must clone instead
+        .unwrap_or_else(|node| node.as_ref().clone())
+}
+
+/// Applies `f` to rewrite a `Arc<LogicalPlan>` without copying, if possible
+fn rewrite_arc<F>(
+    plan: Arc<LogicalPlan>,
+    mut f: F,
+) -> Result<Transformed<Arc<LogicalPlan>>>
+where
+    F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
+{
+    f(unwrap_arc(plan))?.map_data(|new_plan| Ok(Arc::new(new_plan)))
+}
+
+/// rewrite a `Vec` of `Arc<LogicalPlan>` without copying, if possible
+fn rewrite_arcs<F>(
+    input_plans: Vec<Arc<LogicalPlan>>,
+    mut f: F,
+) -> Result<Transformed<Vec<Arc<LogicalPlan>>>>
+where
+    F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
+{
+    Ok(input_plans
+        .into_iter()
+        .map(unwrap_arc)

Review Comment:
   This is an excellent suggestion -- the code is less verbose, and it also 
saves having to make a second `Vec`
   
   I made this change in b29ebd26e



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to