alamb commented on code in PR #9780: URL: https://github.com/apache/arrow-datafusion/pull/9780#discussion_r1552005102
########## datafusion/expr/src/logical_plan/mutate.rs: ########## @@ -0,0 +1,346 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::plan::*; +use crate::expr::{Exists, InSubquery}; +use crate::{Expr, UserDefinedLogicalNode}; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{internal_err, Result}; +use datafusion_common::{Column, DFSchema, DFSchemaRef}; +use std::sync::{Arc, OnceLock}; + +impl LogicalPlan { + /// applies `f` to each expression of this node, potentially rewriting it in + /// place + /// + /// If `f` returns an error, the error is returned and the expressions are + /// left in a partially modified state + pub fn rewrite_exprs<F>(&mut self, mut f: F) -> Result<Transformed<()>> + where + F: FnMut(&mut Expr) -> Result<Transformed<()>>, + { + match self { + LogicalPlan::Projection(Projection { expr, .. }) => { + rewrite_expr_iter_mut(expr.iter_mut(), f) + } + LogicalPlan::Values(Values { values, .. }) => { + rewrite_expr_iter_mut(values.iter_mut().flatten(), f) + } + LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), + LogicalPlan::Repartition(Repartition { + partitioning_scheme, + .. 
+ }) => match partitioning_scheme { + Partitioning::Hash(expr, _) => rewrite_expr_iter_mut(expr.iter_mut(), f), + Partitioning::DistributeBy(expr) => { + rewrite_expr_iter_mut(expr.iter_mut(), f) + } + Partitioning::RoundRobinBatch(_) => Ok(Transformed::no(())), + }, + LogicalPlan::Window(Window { window_expr, .. }) => { + rewrite_expr_iter_mut(window_expr.iter_mut(), f) + } + LogicalPlan::Aggregate(Aggregate { + group_expr, + aggr_expr, + .. + }) => { + let exprs = group_expr.iter_mut().chain(aggr_expr.iter_mut()); + rewrite_expr_iter_mut(exprs, f) + } + // There are two part of expression for join, equijoin(on) and non-equijoin(filter). + // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. + // 2. the second part is non-equijoin(filter). + LogicalPlan::Join(Join { on, filter, .. }) => { + let exprs = on + .iter_mut() + .flat_map(|(e1, e2)| std::iter::once(e1).chain(std::iter::once(e2))); + + let result = rewrite_expr_iter_mut(exprs, &mut f)?; + + if let Some(filter) = filter.as_mut() { + result.and_then(|| f(filter)) + } else { + Ok(result) + } + } + LogicalPlan::Sort(Sort { expr, .. }) => { + rewrite_expr_iter_mut(expr.iter_mut(), f) + } + LogicalPlan::Extension(extension) => { + rewrite_extension_exprs(&mut extension.node, f) + } + LogicalPlan::TableScan(TableScan { filters, .. }) => { + rewrite_expr_iter_mut(filters.iter_mut(), f) + } + LogicalPlan::Unnest(Unnest { column, .. }) => rewrite_column(column, f), + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. 
+ })) => { + let exprs = on_expr + .iter_mut() + .chain(select_expr.iter_mut()) + .chain(sort_expr.iter_mut().flat_map(|x| x.iter_mut())); + + rewrite_expr_iter_mut(exprs, f) + } + // plans without expressions + LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Statement(_) + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Union(_) + | LogicalPlan::Distinct(Distinct::All(_)) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Prepare(_) => Ok(Transformed::no(())), + } + } + + /// applies `f` to each input of this node, rewriting them in place. + /// + /// # Notes + /// Inputs include both direct children as well as any embedded subquery + /// `LogicalPlan`s, for example such as are in [`Expr::Exists`]. + /// + /// If `f` returns an `Err`, that Err is returned, and the inputs are left + /// in a partially modified state + pub fn rewrite_inputs<F>(&mut self, mut f: F) -> Result<Transformed<()>> + where + F: FnMut(&mut LogicalPlan) -> Result<Transformed<()>>, + { + let children_result = match self { + LogicalPlan::Projection(Projection { input, .. }) => { + rewrite_arc(input, &mut f) + } + LogicalPlan::Filter(Filter { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::Repartition(Repartition { input, .. }) => { + rewrite_arc(input, &mut f) + } + LogicalPlan::Window(Window { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::Aggregate(Aggregate { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::Sort(Sort { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::Join(Join { left, right, .. }) => { + rewrite_arc(left, &mut f)?.and_then(|| rewrite_arc(right, &mut f)) + } + LogicalPlan::CrossJoin(CrossJoin { left, right, .. 
}) => { + rewrite_arc(left, &mut f)?.and_then(|| rewrite_arc(right, &mut f)) + } + LogicalPlan::Limit(Limit { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::Subquery(Subquery { subquery, .. }) => { + rewrite_arc(subquery, &mut f) + } + LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { + rewrite_arc(input, &mut f) + } + LogicalPlan::Extension(extension) => { + rewrite_extension_inputs(&mut extension.node, &mut f) + } + LogicalPlan::Union(Union { inputs, .. }) => inputs + .iter_mut() + .try_fold(Transformed::no(()), |acc, input| { + acc.and_then(|| rewrite_arc(input, &mut f)) + }), + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => rewrite_arc(input, &mut f), + LogicalPlan::Explain(explain) => rewrite_arc(&mut explain.plan, &mut f), + LogicalPlan::Analyze(analyze) => rewrite_arc(&mut analyze.input, &mut f), + LogicalPlan::Dml(write) => rewrite_arc(&mut write.input, &mut f), + LogicalPlan::Copy(copy) => rewrite_arc(&mut copy.input, &mut f), + LogicalPlan::Ddl(ddl) => { + if let Some(input) = ddl.input_mut() { + rewrite_arc(input, &mut f) + } else { + Ok(Transformed::no(())) + } + } + LogicalPlan::Unnest(Unnest { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::Prepare(Prepare { input, .. }) => rewrite_arc(input, &mut f), + LogicalPlan::RecursiveQuery(RecursiveQuery { + static_term, + recursive_term, + .. + }) => rewrite_arc(static_term, &mut f)? + .and_then(|| rewrite_arc(recursive_term, &mut f)), + // plans without inputs + LogicalPlan::TableScan { .. } + | LogicalPlan::Statement { .. } + | LogicalPlan::EmptyRelation { .. } + | LogicalPlan::Values { .. 
} + | LogicalPlan::DescribeTable(_) => Ok(Transformed::no(())), + }?; + + // after visiting the actual children we we need to visit any subqueries + // that are inside the expressions + children_result.and_then(|| self.rewrite_subqueries(&mut f)) + } + + /// applies `f` to LogicalPlans in any subquery expressions + /// + /// If Err is returned, the plan may be left in a partially modified state + fn rewrite_subqueries<F>(&mut self, mut f: F) -> Result<Transformed<()>> Review Comment: When I combined both https://github.com/apache/arrow-datafusion/pull/9948 and https://github.com/apache/arrow-datafusion/pull/9946, tpch planning time improves by 12% (basically the same as this PR). Thus I conclude we do not need the tree mutator API quite yet. Let's continue the discussions on those other PRs. <details><summary>Details</summary> <p> ``` group both_tests main optimizer_tree_node2 ----- ---------- ---- -------------------- logical_aggregate_with_join 1.01 1203.3±30.57µs ? ?/sec 1.00 1197.3±17.89µs ? ?/sec 1.00 1194.4±22.14µs ? ?/sec logical_plan_tpcds_all 1.00 157.6±2.32ms ? ?/sec 1.00 157.3±3.68ms ? ?/sec 1.00 157.6±3.67ms ? ?/sec logical_plan_tpch_all 1.01 16.9±0.34ms ? ?/sec 1.00 16.7±0.22ms ? ?/sec 1.04 17.5±0.37ms ? ?/sec logical_select_all_from_1000 1.01 19.6±0.17ms ? ?/sec 1.00 19.4±0.17ms ? ?/sec 1.00 19.4±0.17ms ? ?/sec logical_select_one_from_700 1.01 790.8±12.36µs ? ?/sec 1.00 784.0±12.23µs ? ?/sec 1.00 786.1±12.87µs ? ?/sec logical_trivial_join_high_numbered_columns 1.00 740.8±12.72µs ? ?/sec 1.00 737.2±28.09µs ? ?/sec 1.00 739.5±12.25µs ? ?/sec logical_trivial_join_low_numbered_columns 1.00 719.9±14.19µs ? ?/sec 1.02 731.2±11.31µs ? ?/sec 1.00 721.4±19.28µs ? ?/sec physical_plan_tpcds_all 1.00 1666.1±16.04ms ? ?/sec 1.12 1872.6±16.17ms ? ?/sec 1.20 2.0±0.02s ? ?/sec physical_plan_tpch_all 1.00 110.2±1.86ms ? ?/sec 1.10 121.3±1.90ms ? ?/sec 1.21 133.3±2.21ms ? ?/sec physical_plan_tpch_q1 1.00 6.2±0.11ms ? ?/sec 1.19 7.4±0.12ms ? ?/sec 1.23 7.6±0.07ms ?
?/sec physical_plan_tpch_q10 1.00 5.1±0.07ms ? ?/sec 1.11 5.6±0.11ms ? ?/sec 1.20 6.1±0.10ms ? ?/sec physical_plan_tpch_q11 1.00 4.5±0.09ms ? ?/sec 1.10 5.0±0.07ms ? ?/sec 1.20 5.4±0.10ms ? ?/sec physical_plan_tpch_q12 1.00 3.6±0.06ms ? ?/sec 1.11 4.0±0.07ms ? ?/sec 1.21 4.4±0.11ms ? ?/sec physical_plan_tpch_q13 1.00 2.4±0.04ms ? ?/sec 1.10 2.7±0.06ms ? ?/sec 1.18 2.9±0.08ms ? ?/sec physical_plan_tpch_q14 1.00 3.1±0.05ms ? ?/sec 1.09 3.4±0.06ms ? ?/sec 1.19 3.7±0.09ms ? ?/sec physical_plan_tpch_q16 1.00 4.5±0.06ms ? ?/sec 1.10 4.9±0.08ms ? ?/sec 1.24 5.6±0.10ms ? ?/sec physical_plan_tpch_q17 1.00 4.3±0.07ms ? ?/sec 1.09 4.7±0.05ms ? ?/sec 1.21 5.2±0.10ms ? ?/sec physical_plan_tpch_q18 1.00 4.6±0.09ms ? ?/sec 1.10 5.1±0.09ms ? ?/sec 1.22 5.6±0.10ms ? ?/sec physical_plan_tpch_q19 1.00 9.0±0.11ms ? ?/sec 1.06 9.5±0.11ms ? ?/sec 1.23 11.1±0.14ms ? ?/sec physical_plan_tpch_q2 1.00 9.6±0.12ms ? ?/sec 1.11 10.6±0.18ms ? ?/sec 1.23 11.7±0.14ms ? ?/sec physical_plan_tpch_q20 1.00 5.5±0.08ms ? ?/sec 1.11 6.1±0.08ms ? ?/sec 1.24 6.9±0.14ms ? ?/sec physical_plan_tpch_q21 1.00 7.6±0.11ms ? ?/sec 1.12 8.5±0.14ms ? ?/sec 1.22 9.3±0.13ms ? ?/sec physical_plan_tpch_q22 1.00 4.1±0.10ms ? ?/sec 1.10 4.5±0.10ms ? ?/sec 1.23 5.0±0.12ms ? ?/sec physical_plan_tpch_q3 1.00 3.6±0.07ms ? ?/sec 1.08 3.9±0.06ms ? ?/sec 1.19 4.3±0.11ms ? ?/sec physical_plan_tpch_q4 1.00 2.7±0.06ms ? ?/sec 1.11 3.0±0.05ms ? ?/sec 1.21 3.2±0.11ms ? ?/sec physical_plan_tpch_q5 1.00 5.3±0.08ms ? ?/sec 1.08 5.7±0.12ms ? ?/sec 1.17 6.2±0.08ms ? ?/sec physical_plan_tpch_q6 1.00 1905.1±30.83µs ? ?/sec 1.06 2.0±0.03ms ? ?/sec 1.15 2.2±0.05ms ? ?/sec physical_plan_tpch_q7 1.00 7.0±0.11ms ? ?/sec 1.08 7.6±0.12ms ? ?/sec 1.19 8.4±0.14ms ? ?/sec physical_plan_tpch_q8 1.00 8.9±0.14ms ? ?/sec 1.10 9.7±0.18ms ? ?/sec 1.19 10.6±0.16ms ? ?/sec physical_plan_tpch_q9 1.00 6.7±0.11ms ? ?/sec 1.10 7.3±0.10ms ? ?/sec 1.21 8.0±0.10ms ? ?/sec physical_select_all_from_1000 1.00 114.1±0.77ms ? ?/sec 1.13 129.2±1.14ms ? 
?/sec 1.12 128.3±0.79ms ? ?/sec physical_select_one_from_700 1.00 3.9±0.08ms ? ?/sec 1.03 4.1±0.05ms ? ?/sec 1.04 4.1±0.04ms ? ?/sec ``` </p> </details> -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
