This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e937abce Remove expr_sub_expressions and rewrite_expression functions (#2772)
5e937abce is described below

commit 5e937abceff5a2b4b7e7981fdabe3220ebd81c7f
Author: Mike Roberts <[email protected]>
AuthorDate: Thu Jun 23 17:22:25 2022 +0100

    Remove expr_sub_expressions and rewrite_expression functions (#2772)
    
    * Remove expr_sub_expressions and rewrite_expression functions
    
    * Simplify return logic in rewrite_column_expr function
    
    Co-authored-by: Andy Grove <[email protected]>
    
    * Fix lint error in pruning.rs
    
    Co-authored-by: Andy Grove <[email protected]>
---
 datafusion/core/src/physical_optimizer/pruning.rs  |  32 ++-
 datafusion/optimizer/src/filter_push_down.rs       |  32 ++-
 .../optimizer/src/subquery_filter_to_join.rs       |  47 +++-
 datafusion/optimizer/src/utils.rs                  | 290 +--------------------
 4 files changed, 77 insertions(+), 324 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index 0c3e20808..4b0d04b54 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -36,7 +36,6 @@ use crate::prelude::lit;
 use crate::{
     error::{DataFusionError, Result},
     logical_plan::{Column, DFSchema, Expr, Operator},
-    optimizer::utils,
     physical_plan::{ColumnarValue, PhysicalExpr},
 };
 use arrow::{
@@ -45,6 +44,7 @@ use arrow::{
     record_batch::RecordBatch,
 };
 use datafusion_expr::binary_expr;
+use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
 use datafusion_expr::utils::expr_to_columns;
 use datafusion_physical_expr::create_physical_expr;
 
@@ -283,7 +283,7 @@ impl RequiredStatColumns {
             // only add statistics column if not previously added
             self.columns.push((column.clone(), stat_type, stat_field));
         }
-        rewrite_column_expr(column_expr, column, &stat_column)
+        rewrite_column_expr(column_expr.clone(), column, &stat_column)
     }
 
     /// rewrite col --> col_min
@@ -553,22 +553,28 @@ fn is_compare_op(op: Operator) -> bool {
 
 /// replaces a column with an old name with a new name in an expression
 fn rewrite_column_expr(
-    expr: &Expr,
+    e: Expr,
     column_old: &Column,
     column_new: &Column,
 ) -> Result<Expr> {
-    let expressions = utils::expr_sub_expressions(expr)?;
-    let expressions = expressions
-        .iter()
-        .map(|e| rewrite_column_expr(e, column_old, column_new))
-        .collect::<Result<Vec<_>>>()?;
-
-    if let Expr::Column(c) = expr {
-        if c == column_old {
-            return Ok(Expr::Column(column_new.clone()));
+    struct ColumnReplacer<'a> {
+        old: &'a Column,
+        new: &'a Column,
+    }
+
+    impl<'a> ExprRewriter for ColumnReplacer<'a> {
+        fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+            match expr {
+                Expr::Column(c) if c == *self.old => Ok(Expr::Column(self.new.clone())),
+                _ => Ok(expr),
+            }
         }
     }
-    utils::rewrite_expression(expr, &expressions)
+
+    e.rewrite(&mut ColumnReplacer {
+        old: column_old,
+        new: column_new,
+    })
 }
 
 fn reverse_operator(op: Operator) -> Operator {
diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs
index eded1af3e..bd44ebea1 100644
--- a/datafusion/optimizer/src/filter_push_down.rs
+++ b/datafusion/optimizer/src/filter_push_down.rs
@@ -18,7 +18,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule};
 use datafusion_common::{Column, DFSchema, Result};
 use datafusion_expr::{
     col,
-    expr_rewriter::replace_col,
+    expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
     logical_plan::{
         Aggregate, CrossJoin, Filter, Join, JoinType, Limit, LogicalPlan, Projection,
         TableScan, Union,
@@ -393,7 +393,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
             // re-write all filters based on this projection
             // E.g. in `Filter: #b\n  Projection: #a > 1 as b`, we can swap them, but the filter must be "#a > 1"
             for (predicate, columns) in state.filters.iter_mut() {
-                *predicate = rewrite(predicate, &projection)?;
+                *predicate = replace_cols_by_name(predicate.clone(), &projection)?;
 
                 columns.clear();
                 expr_to_columns(predicate, columns)?;
@@ -443,7 +443,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
             // rewriting predicate expressions using unqualified names as replacements
             if !projection.is_empty() {
                 for (predicate, columns) in state.filters.iter_mut() {
-                    *predicate = rewrite(predicate, &projection)?;
+                    *predicate = replace_cols_by_name(predicate.clone(), &projection)?;
 
                     columns.clear();
                     expr_to_columns(predicate, columns)?;
@@ -629,21 +629,25 @@ impl FilterPushDown {
 }
 
 /// replaces columns by its name on the projection.
-fn rewrite(expr: &Expr, projection: &HashMap<String, Expr>) -> Result<Expr> {
-    let expressions = utils::expr_sub_expressions(expr)?;
-
-    let expressions = expressions
-        .iter()
-        .map(|e| rewrite(e, projection))
-        .collect::<Result<Vec<_>>>()?;
+fn replace_cols_by_name(e: Expr, replace_map: &HashMap<String, Expr>) -> Result<Expr> {
+    struct ColumnReplacer<'a> {
+        replace_map: &'a HashMap<String, Expr>,
+    }
 
-    if let Expr::Column(c) = expr {
-        if let Some(expr) = projection.get(&c.flat_name()) {
-            return Ok(expr.clone());
+    impl<'a> ExprRewriter for ColumnReplacer<'a> {
+        fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+            if let Expr::Column(c) = &expr {
+                match self.replace_map.get(&c.flat_name()) {
+                    Some(new_c) => Ok(new_c.clone()),
+                    None => Ok(expr),
+                }
+            } else {
+                Ok(expr)
+            }
         }
     }
 
-    utils::rewrite_expression(expr, &expressions)
+    e.rewrite(&mut ColumnReplacer { replace_map })
 }
 
 #[cfg(test)]
diff --git a/datafusion/optimizer/src/subquery_filter_to_join.rs b/datafusion/optimizer/src/subquery_filter_to_join.rs
index f2621e190..1a61f92df 100644
--- a/datafusion/optimizer/src/subquery_filter_to_join.rs
+++ b/datafusion/optimizer/src/subquery_filter_to_join.rs
@@ -29,6 +29,7 @@
 use crate::{utils, OptimizerConfig, OptimizerRule};
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::{
+    expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion},
     logical_plan::{
         builder::build_join_schema, Filter, Join, JoinConstraint, JoinType, LogicalPlan,
     },
@@ -177,15 +178,21 @@ impl OptimizerRule for SubqueryFilterToJoin {
 }
 
 fn extract_subquery_filters(expression: &Expr, extracted: &mut Vec<Expr>) -> Result<()> {
-    utils::expr_sub_expressions(expression)?
-        .into_iter()
-        .try_for_each(|se| match se {
-            Expr::InSubquery { .. } => {
-                extracted.push(se);
-                Ok(())
+    struct InSubqueryVisitor<'a> {
+        accum: &'a mut Vec<Expr>,
+    }
+
+    impl ExpressionVisitor for InSubqueryVisitor<'_> {
+        fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
+            if let Expr::InSubquery { .. } = expr {
+                self.accum.push(expr.to_owned());
             }
-            _ => extract_subquery_filters(&se, extracted),
-        })
+            Ok(Recursion::Continue(self))
+        }
+    }
+
+    expression.accept(InSubqueryVisitor { accum: extracted })?;
+    Ok(())
 }
 
 #[cfg(test)]
@@ -330,6 +337,30 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn in_subquery_with_and_or_filters() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(and(
+                or(
+                    binary_expr(col("a"), Operator::Eq, lit(1_u32)),
+                    in_subquery(col("b"), test_subquery_with_name("sq1")?),
+                ),
+                in_subquery(col("c"), test_subquery_with_name("sq2")?),
+            ))?
+            .project(vec![col("test.b")])?
+            .build()?;
+
+        let expected = "Projection: #test.b [b:UInt32]\
+        \n  Filter: #test.a = UInt32(1) OR #test.b IN (Subquery: Projection: #sq1.c\
+            \n  TableScan: sq1 projection=None) AND #test.c IN (Subquery: Projection: #sq2.c\
+            \n  TableScan: sq2 projection=None) [a:UInt32, b:UInt32, c:UInt32]\
+        \n    TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
     /// Test for nested IN subqueries
     #[test]
     fn in_subquery_nested() -> Result<()> {
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs
index dfde370a0..e0c988e07 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -18,22 +18,15 @@
 //! Collection of utility functions that are leveraged by the query optimizer rules
 
 use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::{DataFusionError, Result, ScalarValue};
+use datafusion_common::Result;
 use datafusion_expr::{
     and,
-    expr::GroupingSet,
-    lit,
     logical_plan::{Filter, LogicalPlan},
     utils::from_plan,
     Expr, Operator,
 };
 use std::sync::Arc;
 
-const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__";
-const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__";
-const WINDOW_PARTITION_MARKER: &str = "__DATAFUSION_WINDOW_PARTITION__";
-const WINDOW_SORT_MARKER: &str = "__DATAFUSION_WINDOW_SORT__";
-
 /// Convenience rule for writing optimizers: recursively invoke
 /// optimize on plan's children and then return a node of the same
 /// type. Useful for optimizer rules which want to leave the type
@@ -54,287 +47,6 @@ pub fn optimize_children(
     from_plan(plan, &new_exprs, &new_inputs)
 }
 
-/// Returns all direct children `Expression`s of `expr`.
-/// E.g. if the expression is "(a + 1) + 1", it returns ["a + 1", "1"] (as 
Expr objects)
-pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
-    match expr {
-        Expr::BinaryExpr { left, right, .. } => {
-            Ok(vec![left.as_ref().to_owned(), right.as_ref().to_owned()])
-        }
-        Expr::IsNull(expr)
-        | Expr::IsNotNull(expr)
-        | Expr::Cast { expr, .. }
-        | Expr::TryCast { expr, .. }
-        | Expr::Alias(expr, ..)
-        | Expr::Not(expr)
-        | Expr::Negative(expr)
-        | Expr::Sort { expr, .. }
-        | Expr::GetIndexedField { expr, .. } => 
Ok(vec![expr.as_ref().to_owned()]),
-        Expr::ScalarFunction { args, .. }
-        | Expr::ScalarUDF { args, .. }
-        | Expr::AggregateFunction { args, .. }
-        | Expr::AggregateUDF { args, .. } => Ok(args.clone()),
-        Expr::GroupingSet(grouping_set) => match grouping_set {
-            GroupingSet::Rollup(exprs) => Ok(exprs.clone()),
-            GroupingSet::Cube(exprs) => Ok(exprs.clone()),
-            GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan(
-                "GroupingSets are not supported yet".to_string(),
-            )),
-        },
-        Expr::WindowFunction {
-            args,
-            partition_by,
-            order_by,
-            ..
-        } => {
-            let mut expr_list: Vec<Expr> = vec![];
-            expr_list.extend(args.clone());
-            expr_list.push(lit(WINDOW_PARTITION_MARKER));
-            expr_list.extend(partition_by.clone());
-            expr_list.push(lit(WINDOW_SORT_MARKER));
-            expr_list.extend(order_by.clone());
-            Ok(expr_list)
-        }
-        Expr::Case {
-            expr,
-            when_then_expr,
-            else_expr,
-            ..
-        } => {
-            let mut expr_list: Vec<Expr> = vec![];
-            if let Some(e) = expr {
-                expr_list.push(lit(CASE_EXPR_MARKER));
-                expr_list.push(e.as_ref().to_owned());
-            }
-            for (w, t) in when_then_expr {
-                expr_list.push(w.as_ref().to_owned());
-                expr_list.push(t.as_ref().to_owned());
-            }
-            if let Some(e) = else_expr {
-                expr_list.push(lit(CASE_ELSE_MARKER));
-                expr_list.push(e.as_ref().to_owned());
-            }
-            Ok(expr_list)
-        }
-        Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_, _) => 
Ok(vec![]),
-        Expr::Between {
-            expr, low, high, ..
-        } => Ok(vec![
-            expr.as_ref().to_owned(),
-            low.as_ref().to_owned(),
-            high.as_ref().to_owned(),
-        ]),
-        Expr::InList { expr, list, .. } => {
-            let mut expr_list: Vec<Expr> = vec![expr.as_ref().to_owned()];
-            for list_expr in list {
-                expr_list.push(list_expr.to_owned());
-            }
-            Ok(expr_list)
-        }
-        Expr::Exists { .. } => Ok(vec![]),
-        Expr::InSubquery { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
-        Expr::ScalarSubquery(_) => Ok(vec![]),
-        Expr::Wildcard { .. } => Err(DataFusionError::Internal(
-            "Wildcard expressions are not valid in a logical query 
plan".to_owned(),
-        )),
-        Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal(
-            "QualifiedWildcard expressions are not valid in a logical query 
plan"
-                .to_owned(),
-        )),
-    }
-}
-
-/// returns a new expression where the expressions in `expr` are replaced by 
the ones in
-/// `expressions`.
-/// This is used in conjunction with ``expr_expressions`` to re-write 
expressions.
-pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
-    match expr {
-        Expr::BinaryExpr { op, .. } => Ok(Expr::BinaryExpr {
-            left: Box::new(expressions[0].clone()),
-            op: *op,
-            right: Box::new(expressions[1].clone()),
-        }),
-        Expr::IsNull(_) => Ok(Expr::IsNull(Box::new(expressions[0].clone()))),
-        Expr::IsNotNull(_) => 
Ok(Expr::IsNotNull(Box::new(expressions[0].clone()))),
-        Expr::ScalarFunction { fun, .. } => Ok(Expr::ScalarFunction {
-            fun: fun.clone(),
-            args: expressions.to_vec(),
-        }),
-        Expr::ScalarUDF { fun, .. } => Ok(Expr::ScalarUDF {
-            fun: fun.clone(),
-            args: expressions.to_vec(),
-        }),
-        Expr::WindowFunction {
-            fun, window_frame, ..
-        } => {
-            let partition_index = expressions
-                .iter()
-                .position(|expr| {
-                    matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(str)))
-            if str == WINDOW_PARTITION_MARKER)
-                })
-                .ok_or_else(|| {
-                    DataFusionError::Internal(
-                        "Ill-formed window function expressions: unexpected 
marker"
-                            .to_owned(),
-                    )
-                })?;
-
-            let sort_index = expressions
-                .iter()
-                .position(|expr| {
-                    matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(str)))
-            if str == WINDOW_SORT_MARKER)
-                })
-                .ok_or_else(|| {
-                    DataFusionError::Internal(
-                        "Ill-formed window function expressions".to_owned(),
-                    )
-                })?;
-
-            if partition_index >= sort_index {
-                Err(DataFusionError::Internal(
-                    "Ill-formed window function expressions: partition index 
too large"
-                        .to_owned(),
-                ))
-            } else {
-                Ok(Expr::WindowFunction {
-                    fun: fun.clone(),
-                    args: expressions[..partition_index].to_vec(),
-                    partition_by: expressions[partition_index + 
1..sort_index].to_vec(),
-                    order_by: expressions[sort_index + 1..].to_vec(),
-                    window_frame: *window_frame,
-                })
-            }
-        }
-        Expr::AggregateFunction { fun, distinct, .. } => 
Ok(Expr::AggregateFunction {
-            fun: fun.clone(),
-            args: expressions.to_vec(),
-            distinct: *distinct,
-        }),
-        Expr::AggregateUDF { fun, .. } => Ok(Expr::AggregateUDF {
-            fun: fun.clone(),
-            args: expressions.to_vec(),
-        }),
-        Expr::GroupingSet(grouping_set) => match grouping_set {
-            GroupingSet::Rollup(_exprs) => {
-                
Ok(Expr::GroupingSet(GroupingSet::Rollup(expressions.to_vec())))
-            }
-            GroupingSet::Cube(_exprs) => {
-                
Ok(Expr::GroupingSet(GroupingSet::Rollup(expressions.to_vec())))
-            }
-            GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan(
-                "GroupingSets are not supported yet".to_string(),
-            )),
-        },
-        Expr::Case { .. } => {
-            let mut base_expr: Option<Box<Expr>> = None;
-            let mut when_then: Vec<(Box<Expr>, Box<Expr>)> = vec![];
-            let mut else_expr: Option<Box<Expr>> = None;
-            let mut i = 0;
-
-            while i < expressions.len() {
-                match &expressions[i] {
-                    Expr::Literal(ScalarValue::Utf8(Some(str)))
-                        if str == CASE_EXPR_MARKER =>
-                    {
-                        base_expr = Some(Box::new(expressions[i + 1].clone()));
-                        i += 2;
-                    }
-                    Expr::Literal(ScalarValue::Utf8(Some(str)))
-                        if str == CASE_ELSE_MARKER =>
-                    {
-                        else_expr = Some(Box::new(expressions[i + 1].clone()));
-                        i += 2;
-                    }
-                    _ => {
-                        when_then.push((
-                            Box::new(expressions[i].clone()),
-                            Box::new(expressions[i + 1].clone()),
-                        ));
-                        i += 2;
-                    }
-                }
-            }
-
-            Ok(Expr::Case {
-                expr: base_expr,
-                when_then_expr: when_then,
-                else_expr,
-            })
-        }
-        Expr::Cast { data_type, .. } => Ok(Expr::Cast {
-            expr: Box::new(expressions[0].clone()),
-            data_type: data_type.clone(),
-        }),
-        Expr::TryCast { data_type, .. } => Ok(Expr::TryCast {
-            expr: Box::new(expressions[0].clone()),
-            data_type: data_type.clone(),
-        }),
-        Expr::Alias(_, alias) => {
-            Ok(Expr::Alias(Box::new(expressions[0].clone()), alias.clone()))
-        }
-        Expr::Not(_) => Ok(Expr::Not(Box::new(expressions[0].clone()))),
-        Expr::Negative(_) => 
Ok(Expr::Negative(Box::new(expressions[0].clone()))),
-        Expr::InList { list, negated, .. } => Ok(Expr::InList {
-            expr: Box::new(expressions[0].clone()),
-            list: list.clone(),
-            negated: *negated,
-        }),
-        Expr::InSubquery {
-            subquery, negated, ..
-        } => Ok(Expr::InSubquery {
-            expr: Box::new(expressions[0].clone()),
-            subquery: subquery.clone(),
-            negated: *negated,
-        }),
-        Expr::Column(_)
-        | Expr::Literal(_)
-        | Expr::Exists { .. }
-        | Expr::ScalarSubquery(_)
-        | Expr::ScalarVariable(_, _) => Ok(expr.clone()),
-        Expr::Sort {
-            asc, nulls_first, ..
-        } => Ok(Expr::Sort {
-            expr: Box::new(expressions[0].clone()),
-            asc: *asc,
-            nulls_first: *nulls_first,
-        }),
-        Expr::Between { negated, .. } => {
-            let expr = Expr::BinaryExpr {
-                left: Box::new(Expr::BinaryExpr {
-                    left: Box::new(expressions[0].clone()),
-                    op: Operator::GtEq,
-                    right: Box::new(expressions[1].clone()),
-                }),
-                op: Operator::And,
-                right: Box::new(Expr::BinaryExpr {
-                    left: Box::new(expressions[0].clone()),
-                    op: Operator::LtEq,
-                    right: Box::new(expressions[2].clone()),
-                }),
-            };
-
-            if *negated {
-                Ok(Expr::Not(Box::new(expr)))
-            } else {
-                Ok(expr)
-            }
-        }
-        Expr::Wildcard => Err(DataFusionError::Internal(
-            "Wildcard expressions are not valid in a logical query 
plan".to_owned(),
-        )),
-        Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal(
-            "QualifiedWildcard expressions are not valid in a logical query 
plan"
-                .to_owned(),
-        )),
-        Expr::GetIndexedField { expr: _, key } => Ok(Expr::GetIndexedField {
-            expr: Box::new(expressions[0].clone()),
-            key: key.clone(),
-        }),
-    }
-}
-
 /// converts "A AND B AND C" => [A, B, C]
 pub fn split_conjunction<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>) {
     match predicate {

Reply via email to