This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 0a4d9a6c78 Consistent LogicalPlan subquery handling in TreeNode::apply and TreeNode::visit (#9913)
0a4d9a6c78 is described below
commit 0a4d9a6c788c1e4ad340943492abb823bd31c4f9
Author: Peter Toth <[email protected]>
AuthorDate: Mon Apr 8 11:28:59 2024 +0200
Consistent LogicalPlan subquery handling in TreeNode::apply and TreeNode::visit (#9913)
* fix
* clippy
* remove accidental extra apply
* minor fixes
* fix `LogicalPlan::apply_expressions()` and `LogicalPlan::map_subqueries()`
* fix `LogicalPlan::visit_with_subqueries()`
* Add deprecated LogicalPlan::inspect_expressions
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/common/src/tree_node.rs | 3 +-
datafusion/core/src/execution/context/mod.rs | 4 +-
datafusion/expr/src/logical_plan/plan.rs | 558 +++++++++++++++++++++-----
datafusion/expr/src/tree_node/expr.rs | 2 +-
datafusion/expr/src/tree_node/plan.rs | 53 +--
datafusion/optimizer/src/analyzer/mod.rs | 15 +-
datafusion/optimizer/src/analyzer/subquery.rs | 2 +-
datafusion/optimizer/src/plan_signature.rs | 4 +-
8 files changed, 475 insertions(+), 166 deletions(-)
diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs
index 8e088e7a0b..42514537e2 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -25,10 +25,9 @@ use crate::Result;
/// These macros are used to determine continuation during transforming
traversals.
macro_rules! handle_transform_recursion {
($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{
- #[allow(clippy::redundant_closure_call)]
$F_DOWN?
.transform_children(|n| n.map_children($F_CHILD))?
- .transform_parent(|n| $F_UP(n))
+ .transform_parent($F_UP)
}};
}
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index f15c1c218d..9e48c7b8a6 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -67,7 +67,7 @@ use datafusion_common::{
alias::AliasGenerator,
config::{ConfigExtension, TableOptions},
exec_err, not_impl_err, plan_datafusion_err, plan_err,
- tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor},
+ tree_node::{TreeNodeRecursion, TreeNodeVisitor},
SchemaReference, TableReference,
};
use datafusion_execution::registry::SerializerRegistry;
@@ -2298,7 +2298,7 @@ impl SQLOptions {
/// Return an error if the [`LogicalPlan`] has any nodes that are
/// incompatible with this [`SQLOptions`].
pub fn verify_plan(&self, plan: &LogicalPlan) -> Result<()> {
- plan.visit(&mut BadPlanVisitor::new(self))?;
+ plan.visit_with_subqueries(&mut BadPlanVisitor::new(self))?;
Ok(())
}
}
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 3d40dcae0e..4f55bbfe3f 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -34,8 +34,7 @@ use crate::logical_plan::extension::UserDefinedLogicalNode;
use crate::logical_plan::{DmlStatement, Statement};
use crate::utils::{
enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs,
- grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre,
- split_conjunction,
+ grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction,
};
use crate::{
build_join_schema, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction,
@@ -45,16 +44,19 @@ use crate::{
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::tree_node::{
- Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
TreeNodeVisitor,
+ Transformed, TransformedResult, TreeNode, TreeNodeIterator,
TreeNodeRecursion,
+ TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{
- aggregate_functional_dependencies, internal_err, plan_err, Column,
Constraints,
- DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence,
- FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions,
+ aggregate_functional_dependencies, internal_err,
map_until_stop_and_collect,
+ plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError,
Dependency,
+ FunctionalDependence, FunctionalDependencies, ParamValues, Result,
TableReference,
+ UnnestOptions,
};
// backwards compatibility
use crate::display::PgJsonVisitor;
+use crate::tree_node::expr::transform_option_vec;
pub use datafusion_common::display::{PlanType, StringifiedPlan,
ToStringifiedPlan};
pub use datafusion_common::{JoinConstraint, JoinType};
@@ -248,9 +250,9 @@ impl LogicalPlan {
/// DataFusion's optimizer attempts to optimize them away.
pub fn expressions(self: &LogicalPlan) -> Vec<Expr> {
let mut exprs = vec![];
- self.inspect_expressions(|e| {
+ self.apply_expressions(|e| {
exprs.push(e.clone());
- Ok(()) as Result<()>
+ Ok(TreeNodeRecursion::Continue)
})
// closure always returns OK
.unwrap();
@@ -261,13 +263,13 @@ impl LogicalPlan {
/// logical plan nodes and all its descendant nodes.
pub fn all_out_ref_exprs(self: &LogicalPlan) -> Vec<Expr> {
let mut exprs = vec![];
- self.inspect_expressions(|e| {
+ self.apply_expressions(|e| {
find_out_reference_exprs(e).into_iter().for_each(|e| {
if !exprs.contains(&e) {
exprs.push(e)
}
});
- Ok(()) as Result<(), DataFusionError>
+ Ok(TreeNodeRecursion::Continue)
})
// closure always returns OK
.unwrap();
@@ -282,60 +284,81 @@ impl LogicalPlan {
exprs
}
- /// Calls `f` on all expressions (non-recursively) in the current
- /// logical plan node. This does not include expressions in any
- /// children.
+ #[deprecated(since = "37.0.0", note = "Use `apply_expressions` instead")]
pub fn inspect_expressions<F, E>(self: &LogicalPlan, mut f: F) ->
Result<(), E>
where
F: FnMut(&Expr) -> Result<(), E>,
{
+ let mut err = Ok(());
+ self.apply_expressions(|e| {
+ if let Err(e) = f(e) {
+ // save the error for later (it may not be a DataFusionError
+ err = Err(e);
+ Ok(TreeNodeRecursion::Stop)
+ } else {
+ Ok(TreeNodeRecursion::Continue)
+ }
+ })
+ // The closure always returns OK, so this will always too
+ .expect("no way to return error during recursion");
+
+ err
+ }
+
+ /// Calls `f` on all expressions (non-recursively) in the current
+ /// logical plan node. This does not include expressions in any
+ /// children.
+ pub fn apply_expressions<F: FnMut(&Expr) -> Result<TreeNodeRecursion>>(
+ &self,
+ mut f: F,
+ ) -> Result<TreeNodeRecursion> {
match self {
LogicalPlan::Projection(Projection { expr, .. }) => {
- expr.iter().try_for_each(f)
- }
- LogicalPlan::Values(Values { values, .. }) => {
- values.iter().flatten().try_for_each(f)
+ expr.iter().apply_until_stop(f)
}
+ LogicalPlan::Values(Values { values, .. }) => values
+ .iter()
+ .apply_until_stop(|value| value.iter().apply_until_stop(&mut
f)),
LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate),
LogicalPlan::Repartition(Repartition {
partitioning_scheme,
..
}) => match partitioning_scheme {
- Partitioning::Hash(expr, _) => expr.iter().try_for_each(f),
- Partitioning::DistributeBy(expr) =>
expr.iter().try_for_each(f),
- Partitioning::RoundRobinBatch(_) => Ok(()),
+ Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr)
=> {
+ expr.iter().apply_until_stop(f)
+ }
+ Partitioning::RoundRobinBatch(_) =>
Ok(TreeNodeRecursion::Continue),
},
LogicalPlan::Window(Window { window_expr, .. }) => {
- window_expr.iter().try_for_each(f)
+ window_expr.iter().apply_until_stop(f)
}
LogicalPlan::Aggregate(Aggregate {
group_expr,
aggr_expr,
..
- }) => group_expr.iter().chain(aggr_expr.iter()).try_for_each(f),
+ }) => group_expr
+ .iter()
+ .chain(aggr_expr.iter())
+ .apply_until_stop(f),
// There are two part of expression for join, equijoin(on) and
non-equijoin(filter).
// 1. the first part is `on.len()` equijoin expressions, and the
struct of each expr is `left-on = right-on`.
// 2. the second part is non-equijoin(filter).
LogicalPlan::Join(Join { on, filter, .. }) => {
on.iter()
+ // TODO: why we need to create an `Expr::eq`? Cloning
`Expr` is costly...
// it not ideal to create an expr here to analyze them,
but could cache it on the Join itself
.map(|(l, r)| Expr::eq(l.clone(), r.clone()))
- .try_for_each(|e| f(&e))?;
-
- if let Some(filter) = filter.as_ref() {
- f(filter)
- } else {
- Ok(())
- }
+ .apply_until_stop(|e| f(&e))?
+ .visit_sibling(|| filter.iter().apply_until_stop(f))
}
- LogicalPlan::Sort(Sort { expr, .. }) =>
expr.iter().try_for_each(f),
+ LogicalPlan::Sort(Sort { expr, .. }) =>
expr.iter().apply_until_stop(f),
LogicalPlan::Extension(extension) => {
// would be nice to avoid this copy -- maybe can
// update extension to just observer Exprs
- extension.node.expressions().iter().try_for_each(f)
+ extension.node.expressions().iter().apply_until_stop(f)
}
LogicalPlan::TableScan(TableScan { filters, .. }) => {
- filters.iter().try_for_each(f)
+ filters.iter().apply_until_stop(f)
}
LogicalPlan::Unnest(Unnest { column, .. }) => {
f(&Expr::Column(column.clone()))
@@ -348,8 +371,8 @@ impl LogicalPlan {
})) => on_expr
.iter()
.chain(select_expr.iter())
- .chain(sort_expr.clone().unwrap_or(vec![]).iter())
- .try_for_each(f),
+ .chain(sort_expr.iter().flatten())
+ .apply_until_stop(f),
// plans without expressions
LogicalPlan::EmptyRelation(_)
| LogicalPlan::RecursiveQuery(_)
@@ -366,10 +389,225 @@ impl LogicalPlan {
| LogicalPlan::Ddl(_)
| LogicalPlan::Copy(_)
| LogicalPlan::DescribeTable(_)
- | LogicalPlan::Prepare(_) => Ok(()),
+ | LogicalPlan::Prepare(_) => Ok(TreeNodeRecursion::Continue),
}
}
+ pub fn map_expressions<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
+ self,
+ mut f: F,
+ ) -> Result<Transformed<Self>> {
+ Ok(match self {
+ LogicalPlan::Projection(Projection {
+ expr,
+ input,
+ schema,
+ }) => expr
+ .into_iter()
+ .map_until_stop_and_collect(f)?
+ .update_data(|expr| {
+ LogicalPlan::Projection(Projection {
+ expr,
+ input,
+ schema,
+ })
+ }),
+ LogicalPlan::Values(Values { schema, values }) => values
+ .into_iter()
+ .map_until_stop_and_collect(|value| {
+ value.into_iter().map_until_stop_and_collect(&mut f)
+ })?
+ .update_data(|values| LogicalPlan::Values(Values { schema,
values })),
+ LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)?
+ .update_data(|predicate| {
+ LogicalPlan::Filter(Filter { predicate, input })
+ }),
+ LogicalPlan::Repartition(Repartition {
+ input,
+ partitioning_scheme,
+ }) => match partitioning_scheme {
+ Partitioning::Hash(expr, usize) => expr
+ .into_iter()
+ .map_until_stop_and_collect(f)?
+ .update_data(|expr| Partitioning::Hash(expr, usize)),
+ Partitioning::DistributeBy(expr) => expr
+ .into_iter()
+ .map_until_stop_and_collect(f)?
+ .update_data(Partitioning::DistributeBy),
+ Partitioning::RoundRobinBatch(_) =>
Transformed::no(partitioning_scheme),
+ }
+ .update_data(|partitioning_scheme| {
+ LogicalPlan::Repartition(Repartition {
+ input,
+ partitioning_scheme,
+ })
+ }),
+ LogicalPlan::Window(Window {
+ input,
+ window_expr,
+ schema,
+ }) => window_expr
+ .into_iter()
+ .map_until_stop_and_collect(f)?
+ .update_data(|window_expr| {
+ LogicalPlan::Window(Window {
+ input,
+ window_expr,
+ schema,
+ })
+ }),
+ LogicalPlan::Aggregate(Aggregate {
+ input,
+ group_expr,
+ aggr_expr,
+ schema,
+ }) => map_until_stop_and_collect!(
+ group_expr.into_iter().map_until_stop_and_collect(&mut f),
+ aggr_expr,
+ aggr_expr.into_iter().map_until_stop_and_collect(&mut f)
+ )?
+ .update_data(|(group_expr, aggr_expr)| {
+ LogicalPlan::Aggregate(Aggregate {
+ input,
+ group_expr,
+ aggr_expr,
+ schema,
+ })
+ }),
+
+ // There are two part of expression for join, equijoin(on) and
non-equijoin(filter).
+ // 1. the first part is `on.len()` equijoin expressions, and the
struct of each expr is `left-on = right-on`.
+ // 2. the second part is non-equijoin(filter).
+ LogicalPlan::Join(Join {
+ left,
+ right,
+ on,
+ filter,
+ join_type,
+ join_constraint,
+ schema,
+ null_equals_null,
+ }) => map_until_stop_and_collect!(
+ on.into_iter().map_until_stop_and_collect(
+ |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1))
+ ),
+ filter,
+ filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)),
|e| {
+ Ok(f(e)?.update_data(Some))
+ })
+ )?
+ .update_data(|(on, filter)| {
+ LogicalPlan::Join(Join {
+ left,
+ right,
+ on,
+ filter,
+ join_type,
+ join_constraint,
+ schema,
+ null_equals_null,
+ })
+ }),
+ LogicalPlan::Sort(Sort { expr, input, fetch }) => expr
+ .into_iter()
+ .map_until_stop_and_collect(f)?
+ .update_data(|expr| LogicalPlan::Sort(Sort { expr, input,
fetch })),
+ LogicalPlan::Extension(Extension { node }) => {
+ // would be nice to avoid this copy -- maybe can
+ // update extension to just observer Exprs
+ node.expressions()
+ .into_iter()
+ .map_until_stop_and_collect(f)?
+ .update_data(|exprs| {
+ LogicalPlan::Extension(Extension {
+ node: UserDefinedLogicalNode::from_template(
+ node.as_ref(),
+ exprs.as_slice(),
+ node.inputs()
+ .into_iter()
+ .cloned()
+ .collect::<Vec<_>>()
+ .as_slice(),
+ ),
+ })
+ })
+ }
+ LogicalPlan::TableScan(TableScan {
+ table_name,
+ source,
+ projection,
+ projected_schema,
+ filters,
+ fetch,
+ }) => filters
+ .into_iter()
+ .map_until_stop_and_collect(f)?
+ .update_data(|filters| {
+ LogicalPlan::TableScan(TableScan {
+ table_name,
+ source,
+ projection,
+ projected_schema,
+ filters,
+ fetch,
+ })
+ }),
+ LogicalPlan::Unnest(Unnest {
+ input,
+ column,
+ schema,
+ options,
+ }) => f(Expr::Column(column))?.map_data(|column| match column {
+ Expr::Column(column) => Ok(LogicalPlan::Unnest(Unnest {
+ input,
+ column,
+ schema,
+ options,
+ })),
+ _ => internal_err!("Transformation should return Column"),
+ })?,
+ LogicalPlan::Distinct(Distinct::On(DistinctOn {
+ on_expr,
+ select_expr,
+ sort_expr,
+ input,
+ schema,
+ })) => map_until_stop_and_collect!(
+ on_expr.into_iter().map_until_stop_and_collect(&mut f),
+ select_expr,
+ select_expr.into_iter().map_until_stop_and_collect(&mut f),
+ sort_expr,
+ transform_option_vec(sort_expr, &mut f)
+ )?
+ .update_data(|(on_expr, select_expr, sort_expr)| {
+ LogicalPlan::Distinct(Distinct::On(DistinctOn {
+ on_expr,
+ select_expr,
+ sort_expr,
+ input,
+ schema,
+ }))
+ }),
+ // plans without expressions
+ LogicalPlan::EmptyRelation(_)
+ | LogicalPlan::RecursiveQuery(_)
+ | LogicalPlan::Subquery(_)
+ | LogicalPlan::SubqueryAlias(_)
+ | LogicalPlan::Limit(_)
+ | LogicalPlan::Statement(_)
+ | LogicalPlan::CrossJoin(_)
+ | LogicalPlan::Analyze(_)
+ | LogicalPlan::Explain(_)
+ | LogicalPlan::Union(_)
+ | LogicalPlan::Distinct(Distinct::All(_))
+ | LogicalPlan::Dml(_)
+ | LogicalPlan::Ddl(_)
+ | LogicalPlan::Copy(_)
+ | LogicalPlan::DescribeTable(_)
+ | LogicalPlan::Prepare(_) => Transformed::no(self),
+ })
+ }
+
/// returns all inputs of this `LogicalPlan` node. Does not
/// include inputs to inputs, or subqueries.
pub fn inputs(&self) -> Vec<&LogicalPlan> {
@@ -417,7 +655,7 @@ impl LogicalPlan {
pub fn using_columns(&self) -> Result<Vec<HashSet<Column>>,
DataFusionError> {
let mut using_columns: Vec<HashSet<Column>> = vec![];
- self.apply(&mut |plan| {
+ self.apply_with_subqueries(&mut |plan| {
if let LogicalPlan::Join(Join {
join_constraint: JoinConstraint::Using,
on,
@@ -1079,57 +1317,178 @@ impl LogicalPlan {
}
}
+/// This macro is used to determine continuation during combined transforming
+/// traversals.
+macro_rules! handle_transform_recursion {
+ ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{
+ $F_DOWN?
+ .transform_children(|n| n.map_subqueries($F_CHILD))?
+ .transform_sibling(|n| n.map_children($F_CHILD))?
+ .transform_parent($F_UP)
+ }};
+}
+
+macro_rules! handle_transform_recursion_down {
+ ($F_DOWN:expr, $F_CHILD:expr) => {{
+ $F_DOWN?
+ .transform_children(|n| n.map_subqueries($F_CHILD))?
+ .transform_sibling(|n| n.map_children($F_CHILD))
+ }};
+}
+
+macro_rules! handle_transform_recursion_up {
+ ($SELF:expr, $F_CHILD:expr, $F_UP:expr) => {{
+ $SELF
+ .map_subqueries($F_CHILD)?
+ .transform_sibling(|n| n.map_children($F_CHILD))?
+ .transform_parent(|n| $F_UP(n))
+ }};
+}
+
impl LogicalPlan {
- /// applies `op` to any subqueries in the plan
- pub(crate) fn apply_subqueries<F>(&self, op: &mut F) -> Result<()>
- where
- F: FnMut(&Self) -> Result<TreeNodeRecursion>,
- {
- self.inspect_expressions(|expr| {
- // recursively look for subqueries
- inspect_expr_pre(expr, |expr| {
- match expr {
- Expr::Exists(Exists { subquery, .. })
- | Expr::InSubquery(InSubquery { subquery, .. })
- | Expr::ScalarSubquery(subquery) => {
- // use a synthetic plan so the collector sees a
- // LogicalPlan::Subquery (even though it is
- // actually a Subquery alias)
- let synthetic_plan =
LogicalPlan::Subquery(subquery.clone());
- synthetic_plan.apply(op)?;
- }
- _ => {}
+ pub fn visit_with_subqueries<V: TreeNodeVisitor<Node = Self>>(
+ &self,
+ visitor: &mut V,
+ ) -> Result<TreeNodeRecursion> {
+ visitor
+ .f_down(self)?
+ .visit_children(|| {
+ self.apply_subqueries(|c| c.visit_with_subqueries(visitor))
+ })?
+ .visit_sibling(|| self.apply_children(|c|
c.visit_with_subqueries(visitor)))?
+ .visit_parent(|| visitor.f_up(self))
+ }
+
+ pub fn rewrite_with_subqueries<R: TreeNodeRewriter<Node = Self>>(
+ self,
+ rewriter: &mut R,
+ ) -> Result<Transformed<Self>> {
+ handle_transform_recursion!(
+ rewriter.f_down(self),
+ |c| c.rewrite_with_subqueries(rewriter),
+ |n| rewriter.f_up(n)
+ )
+ }
+
+ pub fn apply_with_subqueries<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
+ &self,
+ f: &mut F,
+ ) -> Result<TreeNodeRecursion> {
+ f(self)?
+ .visit_children(|| self.apply_subqueries(|c|
c.apply_with_subqueries(f)))?
+ .visit_sibling(|| self.apply_children(|c|
c.apply_with_subqueries(f)))
+ }
+
+ pub fn transform_with_subqueries<F: Fn(Self) -> Result<Transformed<Self>>>(
+ self,
+ f: &F,
+ ) -> Result<Transformed<Self>> {
+ self.transform_up_with_subqueries(f)
+ }
+
+ pub fn transform_down_with_subqueries<F: Fn(Self) ->
Result<Transformed<Self>>>(
+ self,
+ f: &F,
+ ) -> Result<Transformed<Self>> {
+ handle_transform_recursion_down!(f(self), |c|
c.transform_down_with_subqueries(f))
+ }
+
+ pub fn transform_down_mut_with_subqueries<
+ F: FnMut(Self) -> Result<Transformed<Self>>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<Transformed<Self>> {
+ handle_transform_recursion_down!(f(self), |c| c
+ .transform_down_mut_with_subqueries(f))
+ }
+
+ pub fn transform_up_with_subqueries<F: Fn(Self) ->
Result<Transformed<Self>>>(
+ self,
+ f: &F,
+ ) -> Result<Transformed<Self>> {
+ handle_transform_recursion_up!(self, |c|
c.transform_up_with_subqueries(f), f)
+ }
+
+ pub fn transform_up_mut_with_subqueries<
+ F: FnMut(Self) -> Result<Transformed<Self>>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<Transformed<Self>> {
+ handle_transform_recursion_up!(self, |c|
c.transform_up_mut_with_subqueries(f), f)
+ }
+
+ pub fn transform_down_up_with_subqueries<
+ FD: FnMut(Self) -> Result<Transformed<Self>>,
+ FU: FnMut(Self) -> Result<Transformed<Self>>,
+ >(
+ self,
+ f_down: &mut FD,
+ f_up: &mut FU,
+ ) -> Result<Transformed<Self>> {
+ handle_transform_recursion!(
+ f_down(self),
+ |c| c.transform_down_up_with_subqueries(f_down, f_up),
+ f_up
+ )
+ }
+
+ fn apply_subqueries<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
+ &self,
+ mut f: F,
+ ) -> Result<TreeNodeRecursion> {
+ self.apply_expressions(|expr| {
+ expr.apply(&mut |expr| match expr {
+ Expr::Exists(Exists { subquery, .. })
+ | Expr::InSubquery(InSubquery { subquery, .. })
+ | Expr::ScalarSubquery(subquery) => {
+ // use a synthetic plan so the collector sees a
+ // LogicalPlan::Subquery (even though it is
+ // actually a Subquery alias)
+ f(&LogicalPlan::Subquery(subquery.clone()))
}
- Ok::<(), DataFusionError>(())
+ _ => Ok(TreeNodeRecursion::Continue),
})
- })?;
- Ok(())
+ })
}
- /// applies visitor to any subqueries in the plan
- pub(crate) fn visit_subqueries<V>(&self, v: &mut V) -> Result<()>
- where
- V: TreeNodeVisitor<Node = LogicalPlan>,
- {
- self.inspect_expressions(|expr| {
- // recursively look for subqueries
- inspect_expr_pre(expr, |expr| {
- match expr {
- Expr::Exists(Exists { subquery, .. })
- | Expr::InSubquery(InSubquery { subquery, .. })
- | Expr::ScalarSubquery(subquery) => {
- // use a synthetic plan so the visitor sees a
- // LogicalPlan::Subquery (even though it is
- // actually a Subquery alias)
- let synthetic_plan =
LogicalPlan::Subquery(subquery.clone());
- synthetic_plan.visit(v)?;
- }
- _ => {}
+ fn map_subqueries<F: FnMut(Self) -> Result<Transformed<Self>>>(
+ self,
+ mut f: F,
+ ) -> Result<Transformed<Self>> {
+ self.map_expressions(|expr| {
+ expr.transform_down_mut(&mut |expr| match expr {
+ Expr::Exists(Exists { subquery, negated }) => {
+ f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s {
+ LogicalPlan::Subquery(subquery) => {
+ Ok(Expr::Exists(Exists { subquery, negated }))
+ }
+ _ => internal_err!("Transformation should return
Subquery"),
+ })
}
- Ok::<(), DataFusionError>(())
+ Expr::InSubquery(InSubquery {
+ expr,
+ subquery,
+ negated,
+ }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s
{
+ LogicalPlan::Subquery(subquery) =>
Ok(Expr::InSubquery(InSubquery {
+ expr,
+ subquery,
+ negated,
+ })),
+ _ => internal_err!("Transformation should return
Subquery"),
+ }),
+ Expr::ScalarSubquery(subquery) =>
f(LogicalPlan::Subquery(subquery))?
+ .map_data(|s| match s {
+ LogicalPlan::Subquery(subquery) => {
+ Ok(Expr::ScalarSubquery(subquery))
+ }
+ _ => internal_err!("Transformation should return
Subquery"),
+ }),
+ _ => Ok(Transformed::no(expr)),
})
- })?;
- Ok(())
+ })
}
/// Return a `LogicalPlan` with all placeholders (e.g $1 $2,
@@ -1165,8 +1524,8 @@ impl LogicalPlan {
) -> Result<HashMap<String, Option<DataType>>, DataFusionError> {
let mut param_types: HashMap<String, Option<DataType>> =
HashMap::new();
- self.apply(&mut |plan| {
- plan.inspect_expressions(|expr| {
+ self.apply_with_subqueries(&mut |plan| {
+ plan.apply_expressions(|expr| {
expr.apply(&mut |expr| {
if let Expr::Placeholder(Placeholder { id, data_type }) =
expr {
let prev = param_types.get(id);
@@ -1183,13 +1542,10 @@ impl LogicalPlan {
}
}
Ok(TreeNodeRecursion::Continue)
- })?;
- Ok::<(), DataFusionError>(())
- })?;
- Ok(TreeNodeRecursion::Continue)
- })?;
-
- Ok(param_types)
+ })
+ })
+ })
+ .map(|_| param_types)
}
/// Return an Expr with all placeholders replaced with their
@@ -1257,7 +1613,7 @@ impl LogicalPlan {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let with_schema = false;
let mut visitor = IndentVisitor::new(f, with_schema);
- match self.0.visit(&mut visitor) {
+ match self.0.visit_with_subqueries(&mut visitor) {
Ok(_) => Ok(()),
Err(_) => Err(fmt::Error),
}
@@ -1300,7 +1656,7 @@ impl LogicalPlan {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let with_schema = true;
let mut visitor = IndentVisitor::new(f, with_schema);
- match self.0.visit(&mut visitor) {
+ match self.0.visit_with_subqueries(&mut visitor) {
Ok(_) => Ok(()),
Err(_) => Err(fmt::Error),
}
@@ -1320,7 +1676,7 @@ impl LogicalPlan {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let mut visitor = PgJsonVisitor::new(f);
visitor.with_schema(true);
- match self.0.visit(&mut visitor) {
+ match self.0.visit_with_subqueries(&mut visitor) {
Ok(_) => Ok(()),
Err(_) => Err(fmt::Error),
}
@@ -1369,12 +1725,16 @@ impl LogicalPlan {
visitor.start_graph()?;
visitor.pre_visit_plan("LogicalPlan")?;
- self.0.visit(&mut visitor).map_err(|_| fmt::Error)?;
+ self.0
+ .visit_with_subqueries(&mut visitor)
+ .map_err(|_| fmt::Error)?;
visitor.post_visit_plan()?;
visitor.set_with_schema(true);
visitor.pre_visit_plan("Detailed LogicalPlan")?;
- self.0.visit(&mut visitor).map_err(|_| fmt::Error)?;
+ self.0
+ .visit_with_subqueries(&mut visitor)
+ .map_err(|_| fmt::Error)?;
visitor.post_visit_plan()?;
visitor.end_graph()?;
@@ -2908,7 +3268,7 @@ digraph {
fn visit_order() {
let mut visitor = OkVisitor::default();
let plan = test_plan();
- let res = plan.visit(&mut visitor);
+ let res = plan.visit_with_subqueries(&mut visitor);
assert!(res.is_ok());
assert_eq!(
@@ -2984,7 +3344,7 @@ digraph {
..Default::default()
};
let plan = test_plan();
- let res = plan.visit(&mut visitor);
+ let res = plan.visit_with_subqueries(&mut visitor);
assert!(res.is_ok());
assert_eq!(
@@ -3000,7 +3360,7 @@ digraph {
..Default::default()
};
let plan = test_plan();
- let res = plan.visit(&mut visitor);
+ let res = plan.visit_with_subqueries(&mut visitor);
assert!(res.is_ok());
assert_eq!(
@@ -3051,7 +3411,7 @@ digraph {
..Default::default()
};
let plan = test_plan();
- let res = plan.visit(&mut visitor).unwrap_err();
+ let res = plan.visit_with_subqueries(&mut visitor).unwrap_err();
assert_eq!(
"This feature is not implemented: Error in pre_visit",
res.strip_backtrace()
@@ -3069,7 +3429,7 @@ digraph {
..Default::default()
};
let plan = test_plan();
- let res = plan.visit(&mut visitor).unwrap_err();
+ let res = plan.visit_with_subqueries(&mut visitor).unwrap_err();
assert_eq!(
"This feature is not implemented: Error in post_visit",
res.strip_backtrace()
diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs
index 97331720ce..85097f6249 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -412,7 +412,7 @@ where
}
/// &mut transform a Option<`Vec` of `Expr`s>
-fn transform_option_vec<F>(
+pub fn transform_option_vec<F>(
ove: Option<Vec<Expr>>,
f: &mut F,
) -> Result<Transformed<Option<Vec<Expr>>>>
diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs
index 7a6b1005fe..482fc96b51 100644
--- a/datafusion/expr/src/tree_node/plan.rs
+++ b/datafusion/expr/src/tree_node/plan.rs
@@ -20,58 +20,11 @@
use crate::LogicalPlan;
use datafusion_common::tree_node::{
- Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion,
TreeNodeVisitor,
+ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion,
};
use datafusion_common::Result;
impl TreeNode for LogicalPlan {
- fn apply<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
- &self,
- f: &mut F,
- ) -> Result<TreeNodeRecursion> {
- // Compared to the default implementation, we need to invoke
- // [`Self::apply_subqueries`] before visiting its children
- f(self)?.visit_children(|| {
- self.apply_subqueries(f)?;
- self.apply_children(|n| n.apply(f))
- })
- }
-
- /// To use, define a struct that implements the trait [`TreeNodeVisitor`]
and then invoke
- /// [`LogicalPlan::visit`].
- ///
- /// For example, for a logical plan like:
- ///
- /// ```text
- /// Projection: id
- /// Filter: state Eq Utf8(\"CO\")\
- /// CsvScan: employee.csv projection=Some([0, 3])";
- /// ```
- ///
- /// The sequence of visit operations would be:
- /// ```text
- /// visitor.pre_visit(Projection)
- /// visitor.pre_visit(Filter)
- /// visitor.pre_visit(CsvScan)
- /// visitor.post_visit(CsvScan)
- /// visitor.post_visit(Filter)
- /// visitor.post_visit(Projection)
- /// ```
- fn visit<V: TreeNodeVisitor<Node = Self>>(
- &self,
- visitor: &mut V,
- ) -> Result<TreeNodeRecursion> {
- // Compared to the default implementation, we need to invoke
- // [`Self::visit_subqueries`] before visiting its children
- visitor
- .f_down(self)?
- .visit_children(|| {
- self.visit_subqueries(visitor)?;
- self.apply_children(|n| n.visit(visitor))
- })?
- .visit_parent(|| visitor.f_up(self))
- }
-
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
f: F,
@@ -85,8 +38,8 @@ impl TreeNode for LogicalPlan {
) -> Result<Transformed<Self>> {
let new_children = self
.inputs()
- .iter()
- .map(|&c| c.clone())
+ .into_iter()
+ .cloned()
.map_until_stop_and_collect(f)?;
// Propagate up `new_children.transformed` and `new_children.tnr`
// along with the node containing transformed children.
diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs
index b446fe2f32..d0b83d2429 100644
--- a/datafusion/optimizer/src/analyzer/mod.rs
+++ b/datafusion/optimizer/src/analyzer/mod.rs
@@ -155,8 +155,8 @@ impl Analyzer {
/// Do necessary check and fail the invalid plan
fn check_plan(plan: &LogicalPlan) -> Result<()> {
- plan.apply(&mut |plan: &LogicalPlan| {
- plan.inspect_expressions(|expr| {
+ plan.apply_with_subqueries(&mut |plan: &LogicalPlan| {
+ plan.apply_expressions(|expr| {
// recursively look for subqueries
expr.apply(&mut |expr| {
match expr {
@@ -168,11 +168,8 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> {
_ => {}
};
Ok(TreeNodeRecursion::Continue)
- })?;
- Ok::<(), DataFusionError>(())
- })?;
- Ok(TreeNodeRecursion::Continue)
- })?;
-
- Ok(())
+ })
+ })
+ })
+ .map(|_| ())
}
diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs
index 038361c3ee..79375e52da 100644
--- a/datafusion/optimizer/src/analyzer/subquery.rs
+++ b/datafusion/optimizer/src/analyzer/subquery.rs
@@ -283,7 +283,7 @@ fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan {
fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result<Vec<Expr>> {
let mut exprs = vec![];
- inner_plan.apply(&mut |plan| {
+ inner_plan.apply_with_subqueries(&mut |plan| {
if let LogicalPlan::Filter(Filter { predicate, .. }) = plan {
let (correlated, _): (Vec<_>, Vec<_>) =
split_conjunction(predicate)
.into_iter()
diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs
index 4143d52a05..a8e323ff42 100644
--- a/datafusion/optimizer/src/plan_signature.rs
+++ b/datafusion/optimizer/src/plan_signature.rs
@@ -21,7 +21,7 @@ use std::{
num::NonZeroUsize,
};
-use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
+use datafusion_common::tree_node::TreeNodeRecursion;
use datafusion_expr::LogicalPlan;
/// Non-unique identifier of a [`LogicalPlan`].
@@ -73,7 +73,7 @@ impl LogicalPlanSignature {
/// Get total number of [`LogicalPlan`]s in the plan.
fn get_node_number(plan: &LogicalPlan) -> NonZeroUsize {
let mut node_number = 0;
- plan.apply(&mut |_plan| {
+ plan.apply_with_subqueries(&mut |_plan| {
node_number += 1;
Ok(TreeNodeRecursion::Continue)
})