This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new cb21404bd3 Avoid `LogicalPlan::clone()` in `LogicalPlan::map_children` 
when possible (#9999)
cb21404bd3 is described below

commit cb21404bd3736ff9a6d8d443a67c64ece4c551a9
Author: Andrew Lamb <[email protected]>
AuthorDate: Tue Apr 9 07:22:05 2024 -0400

    Avoid `LogicalPlan::clone()` in `LogicalPlan::map_children` when possible 
(#9999)
    
    * Implement TreeNode::map_children in place
    
    * fix doc
    
    * Avoid explicit unwrap
---
 datafusion/expr/src/logical_plan/tree_node.rs | 383 ++++++++++++++++++++++++--
 1 file changed, 363 insertions(+), 20 deletions(-)

diff --git a/datafusion/expr/src/logical_plan/tree_node.rs 
b/datafusion/expr/src/logical_plan/tree_node.rs
index 482fc96b51..ce26cac797 100644
--- a/datafusion/expr/src/logical_plan/tree_node.rs
+++ b/datafusion/expr/src/logical_plan/tree_node.rs
@@ -17,12 +17,19 @@
 
 //! Tree node implementation for logical plan
 
-use crate::LogicalPlan;
+use crate::{
+    Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, 
DdlStatement, Distinct,
+    DistinctOn, DmlStatement, Explain, Extension, Filter, Join, Limit, 
LogicalPlan,
+    Prepare, Projection, RecursiveQuery, Repartition, Sort, Subquery, 
SubqueryAlias,
+    Union, Unnest, Window,
+};
+use std::sync::Arc;
 
+use crate::dml::CopyTo;
 use datafusion_common::tree_node::{
     Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion,
 };
-use datafusion_common::Result;
+use datafusion_common::{map_until_stop_and_collect, Result};
 
 impl TreeNode for LogicalPlan {
     fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
@@ -32,23 +39,359 @@ impl TreeNode for LogicalPlan {
         self.inputs().into_iter().apply_until_stop(f)
     }
 
-    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
-        self,
-        f: F,
-    ) -> Result<Transformed<Self>> {
-        let new_children = self
-            .inputs()
-            .into_iter()
-            .cloned()
-            .map_until_stop_and_collect(f)?;
-        // Propagate up `new_children.transformed` and `new_children.tnr`
-        // along with the node containing transformed children.
-        if new_children.transformed {
-            new_children.map_data(|new_children| {
-                self.with_new_exprs(self.expressions(), new_children)
-            })
-        } else {
-            Ok(new_children.update_data(|_| self))
-        }
+    /// Applies `f` to each child (input) of this plan node, rewriting them 
*in place.*
+    ///
+    /// # Notes
+    ///
+    /// Inputs include ONLY direct children, not embedded `LogicalPlan`s for
+    /// subqueries such as those found in [`Expr::Exists`].
+    ///
+    /// [`Expr::Exists`]: crate::Expr::Exists
+    fn map_children<F>(self, mut f: F) -> Result<Transformed<Self>>
+    where
+        F: FnMut(Self) -> Result<Transformed<Self>>,
+    {
+        Ok(match self {
+            LogicalPlan::Projection(Projection {
+                expr,
+                input,
+                schema,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Projection(Projection {
+                    expr,
+                    input,
+                    schema,
+                })
+            }),
+            LogicalPlan::Filter(Filter { predicate, input }) => 
rewrite_arc(input, f)?
+                .update_data(|input| LogicalPlan::Filter(Filter { predicate, 
input })),
+            LogicalPlan::Repartition(Repartition {
+                input,
+                partitioning_scheme,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Repartition(Repartition {
+                    input,
+                    partitioning_scheme,
+                })
+            }),
+            LogicalPlan::Window(Window {
+                input,
+                window_expr,
+                schema,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Window(Window {
+                    input,
+                    window_expr,
+                    schema,
+                })
+            }),
+            LogicalPlan::Aggregate(Aggregate {
+                input,
+                group_expr,
+                aggr_expr,
+                schema,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Aggregate(Aggregate {
+                    input,
+                    group_expr,
+                    aggr_expr,
+                    schema,
+                })
+            }),
+            LogicalPlan::Sort(Sort { expr, input, fetch }) => 
rewrite_arc(input, f)?
+                .update_data(|input| LogicalPlan::Sort(Sort { expr, input, 
fetch })),
+            LogicalPlan::Join(Join {
+                left,
+                right,
+                on,
+                filter,
+                join_type,
+                join_constraint,
+                schema,
+                null_equals_null,
+            }) => map_until_stop_and_collect!(
+                rewrite_arc(left, &mut f),
+                right,
+                rewrite_arc(right, &mut f)
+            )?
+            .update_data(|(left, right)| {
+                LogicalPlan::Join(Join {
+                    left,
+                    right,
+                    on,
+                    filter,
+                    join_type,
+                    join_constraint,
+                    schema,
+                    null_equals_null,
+                })
+            }),
+            LogicalPlan::CrossJoin(CrossJoin {
+                left,
+                right,
+                schema,
+            }) => map_until_stop_and_collect!(
+                rewrite_arc(left, &mut f),
+                right,
+                rewrite_arc(right, &mut f)
+            )?
+            .update_data(|(left, right)| {
+                LogicalPlan::CrossJoin(CrossJoin {
+                    left,
+                    right,
+                    schema,
+                })
+            }),
+            LogicalPlan::Limit(Limit { skip, fetch, input }) => 
rewrite_arc(input, f)?
+                .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, 
input })),
+            LogicalPlan::Subquery(Subquery {
+                subquery,
+                outer_ref_columns,
+            }) => rewrite_arc(subquery, f)?.update_data(|subquery| {
+                LogicalPlan::Subquery(Subquery {
+                    subquery,
+                    outer_ref_columns,
+                })
+            }),
+            LogicalPlan::SubqueryAlias(SubqueryAlias {
+                input,
+                alias,
+                schema,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::SubqueryAlias(SubqueryAlias {
+                    input,
+                    alias,
+                    schema,
+                })
+            }),
+            LogicalPlan::Extension(extension) => 
rewrite_extension_inputs(extension, f)?
+                .update_data(LogicalPlan::Extension),
+            LogicalPlan::Union(Union { inputs, schema }) => 
rewrite_arcs(inputs, f)?
+                .update_data(|inputs| LogicalPlan::Union(Union { inputs, 
schema })),
+            LogicalPlan::Distinct(distinct) => match distinct {
+                Distinct::All(input) => rewrite_arc(input, 
f)?.update_data(Distinct::All),
+                Distinct::On(DistinctOn {
+                    on_expr,
+                    select_expr,
+                    sort_expr,
+                    input,
+                    schema,
+                }) => rewrite_arc(input, f)?.update_data(|input| {
+                    Distinct::On(DistinctOn {
+                        on_expr,
+                        select_expr,
+                        sort_expr,
+                        input,
+                        schema,
+                    })
+                }),
+            }
+            .update_data(LogicalPlan::Distinct),
+            LogicalPlan::Explain(Explain {
+                verbose,
+                plan,
+                stringified_plans,
+                schema,
+                logical_optimization_succeeded,
+            }) => rewrite_arc(plan, f)?.update_data(|plan| {
+                LogicalPlan::Explain(Explain {
+                    verbose,
+                    plan,
+                    stringified_plans,
+                    schema,
+                    logical_optimization_succeeded,
+                })
+            }),
+            LogicalPlan::Analyze(Analyze {
+                verbose,
+                input,
+                schema,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Analyze(Analyze {
+                    verbose,
+                    input,
+                    schema,
+                })
+            }),
+            LogicalPlan::Dml(DmlStatement {
+                table_name,
+                table_schema,
+                op,
+                input,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Dml(DmlStatement {
+                    table_name,
+                    table_schema,
+                    op,
+                    input,
+                })
+            }),
+            LogicalPlan::Copy(CopyTo {
+                input,
+                output_url,
+                partition_by,
+                format_options,
+                options,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Copy(CopyTo {
+                    input,
+                    output_url,
+                    partition_by,
+                    format_options,
+                    options,
+                })
+            }),
+            LogicalPlan::Ddl(ddl) => {
+                match ddl {
+                    DdlStatement::CreateMemoryTable(CreateMemoryTable {
+                        name,
+                        constraints,
+                        input,
+                        if_not_exists,
+                        or_replace,
+                        column_defaults,
+                    }) => rewrite_arc(input, f)?.update_data(|input| {
+                        DdlStatement::CreateMemoryTable(CreateMemoryTable {
+                            name,
+                            constraints,
+                            input,
+                            if_not_exists,
+                            or_replace,
+                            column_defaults,
+                        })
+                    }),
+                    DdlStatement::CreateView(CreateView {
+                        name,
+                        input,
+                        or_replace,
+                        definition,
+                    }) => rewrite_arc(input, f)?.update_data(|input| {
+                        DdlStatement::CreateView(CreateView {
+                            name,
+                            input,
+                            or_replace,
+                            definition,
+                        })
+                    }),
+                    // no inputs in these statements
+                    DdlStatement::CreateExternalTable(_)
+                    | DdlStatement::CreateCatalogSchema(_)
+                    | DdlStatement::CreateCatalog(_)
+                    | DdlStatement::DropTable(_)
+                    | DdlStatement::DropView(_)
+                    | DdlStatement::DropCatalogSchema(_)
+                    | DdlStatement::CreateFunction(_)
+                    | DdlStatement::DropFunction(_) => Transformed::no(ddl),
+                }
+                .update_data(LogicalPlan::Ddl)
+            }
+            LogicalPlan::Unnest(Unnest {
+                input,
+                column,
+                schema,
+                options,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Unnest(Unnest {
+                    input,
+                    column,
+                    schema,
+                    options,
+                })
+            }),
+            LogicalPlan::Prepare(Prepare {
+                name,
+                data_types,
+                input,
+            }) => rewrite_arc(input, f)?.update_data(|input| {
+                LogicalPlan::Prepare(Prepare {
+                    name,
+                    data_types,
+                    input,
+                })
+            }),
+            LogicalPlan::RecursiveQuery(RecursiveQuery {
+                name,
+                static_term,
+                recursive_term,
+                is_distinct,
+            }) => map_until_stop_and_collect!(
+                rewrite_arc(static_term, &mut f),
+                recursive_term,
+                rewrite_arc(recursive_term, &mut f)
+            )?
+            .update_data(|(static_term, recursive_term)| {
+                LogicalPlan::RecursiveQuery(RecursiveQuery {
+                    name,
+                    static_term,
+                    recursive_term,
+                    is_distinct,
+                })
+            }),
+            // plans without inputs
+            LogicalPlan::TableScan { .. }
+            | LogicalPlan::Statement { .. }
+            | LogicalPlan::EmptyRelation { .. }
+            | LogicalPlan::Values { .. }
+            | LogicalPlan::DescribeTable(_) => Transformed::no(self),
+        })
     }
 }
+
+/// Converts an `Arc<LogicalPlan>` into an owned `LogicalPlan` without copying,
+/// if possible. Clones the plan if the `Arc` is shared with other references
+fn unwrap_arc(plan: Arc<LogicalPlan>) -> LogicalPlan {
+    Arc::try_unwrap(plan)
+        // if `Err` is returned, there is another reference to this
+        // LogicalPlan, so we cannot take ownership and must clone instead
+        .unwrap_or_else(|node| node.as_ref().clone())
+}
+
+/// Applies `f` to rewrite an `Arc<LogicalPlan>` without copying, if possible
+fn rewrite_arc<F>(
+    plan: Arc<LogicalPlan>,
+    mut f: F,
+) -> Result<Transformed<Arc<LogicalPlan>>>
+where
+    F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
+{
+    f(unwrap_arc(plan))?.map_data(|new_plan| Ok(Arc::new(new_plan)))
+}
+
+/// Applies `f` to rewrite a `Vec` of `Arc<LogicalPlan>` without copying, if possible
+fn rewrite_arcs<F>(
+    input_plans: Vec<Arc<LogicalPlan>>,
+    mut f: F,
+) -> Result<Transformed<Vec<Arc<LogicalPlan>>>>
+where
+    F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
+{
+    input_plans
+        .into_iter()
+        .map_until_stop_and_collect(|plan| rewrite_arc(plan, &mut f))
+}
+
+/// Rewrites all inputs for an Extension node "in place"
+/// (it currently has to copy values because there are no APIs for in place 
modification)
+///
+/// Should be removed when we have an API for in place modifications of the
+/// extension to avoid these copies
+fn rewrite_extension_inputs<F>(
+    extension: Extension,
+    f: F,
+) -> Result<Transformed<Extension>>
+where
+    F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
+{
+    let Extension { node } = extension;
+
+    node.inputs()
+        .into_iter()
+        .cloned()
+        .map_until_stop_and_collect(f)?
+        .map_data(|new_inputs| {
+            let exprs = node.expressions();
+            Ok(Extension {
+                node: node.from_template(&exprs, &new_inputs),
+            })
+        })
+}

Reply via email to