This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new de7f15bf1 Minor: Use ExprVisitor to find columns referenced by expr (#2471)
de7f15bf1 is described below

commit de7f15bf14360508873e5f3e9c851b2efca2b78f
Author: Andrew Lamb <[email protected]>
AuthorDate: Fri May 6 14:06:45 2022 -0400

    Minor: Use ExprVisitor to find columns referenced by expr (#2471)
---
 datafusion/core/src/sql/utils.rs | 25 +++++++++++++-
 datafusion/expr/src/expr.rs      | 71 ----------------------------------------
 2 files changed, 24 insertions(+), 72 deletions(-)

diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs
index 4acaa21ef..0293e2410 100644
--- a/datafusion/core/src/sql/utils.rs
+++ b/datafusion/core/src/sql/utils.rs
@@ -27,7 +27,6 @@ use crate::{
     error::{DataFusionError, Result},
     logical_plan::{Column, ExpressionVisitor, Recursion},
 };
-use datafusion_expr::expr::find_columns_referenced_by_expr;
 use std::collections::HashMap;
 
 /// Collect all deeply nested `Expr::AggregateFunction` and
@@ -86,6 +85,30 @@ where
         })
 }
 
+/// Recursively find all columns referenced by an expression
+#[derive(Debug, Default)]
+struct ColumnCollector {
+    exprs: Vec<Column>,
+}
+
+impl ExpressionVisitor for ColumnCollector {
+    fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
+        if let Expr::Column(c) = expr {
+            self.exprs.push(c.clone())
+        }
+        Ok(Recursion::Continue(self))
+    }
+}
+
+fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
+    // As the `ExpressionVisitor` impl above always returns Ok, this
+    // "can't" error
+    let ColumnCollector { exprs } = e
+        .accept(ColumnCollector::default())
+        .expect("Unexpected error");
+    exprs
+}
+
 // Visitor that find expressions that match a particular predicate
 struct Finder<'a, F>
 where
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 7e1adac43..4d88ed815 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -251,77 +251,6 @@ pub enum Expr {
     QualifiedWildcard { qualifier: String },
 }
 
-/// Recursively find all columns referenced by an expression
-pub fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
-    match e {
-        Expr::Alias(expr, _)
-        | Expr::Negative(expr)
-        | Expr::Cast { expr, .. }
-        | Expr::TryCast { expr, .. }
-        | Expr::Sort { expr, .. }
-        | Expr::InList { expr, .. }
-        | Expr::InSubquery { expr, .. }
-        | Expr::GetIndexedField { expr, .. }
-        | Expr::Not(expr)
-        | Expr::IsNotNull(expr)
-        | Expr::IsNull(expr) => find_columns_referenced_by_expr(expr),
-        Expr::Column(c) => vec![c.clone()],
-        Expr::BinaryExpr { left, right, .. } => {
-            let mut cols = vec![];
-            cols.extend(find_columns_referenced_by_expr(left.as_ref()));
-            cols.extend(find_columns_referenced_by_expr(right.as_ref()));
-            cols
-        }
-        Expr::Case {
-            expr,
-            when_then_expr,
-            else_expr,
-        } => {
-            let mut cols = vec![];
-            if let Some(expr) = expr {
-                cols.extend(find_columns_referenced_by_expr(expr.as_ref()));
-            }
-            for (w, t) in when_then_expr {
-                cols.extend(find_columns_referenced_by_expr(w.as_ref()));
-                cols.extend(find_columns_referenced_by_expr(t.as_ref()));
-            }
-            if let Some(else_expr) = else_expr {
-                cols.extend(find_columns_referenced_by_expr(else_expr.as_ref()));
-            }
-            cols
-        }
-        Expr::ScalarFunction { args, .. } => args
-            .iter()
-            .flat_map(find_columns_referenced_by_expr)
-            .collect(),
-        Expr::AggregateFunction { args, .. } => args
-            .iter()
-            .flat_map(find_columns_referenced_by_expr)
-            .collect(),
-        Expr::ScalarVariable(_, _)
-        | Expr::Exists { .. }
-        | Expr::Wildcard
-        | Expr::QualifiedWildcard { .. }
-        | Expr::ScalarSubquery(_)
-        | Expr::Literal(_) => vec![],
-        Expr::Between {
-            expr, low, high, ..
-        } => {
-            let mut cols = vec![];
-            cols.extend(find_columns_referenced_by_expr(expr.as_ref()));
-            cols.extend(find_columns_referenced_by_expr(low.as_ref()));
-            cols.extend(find_columns_referenced_by_expr(high.as_ref()));
-            cols
-        }
-        Expr::ScalarUDF { args, .. }
-        | Expr::WindowFunction { args, .. }
-        | Expr::AggregateUDF { args, .. } => args
-            .iter()
-            .flat_map(find_columns_referenced_by_expr)
-            .collect(),
-    }
-}
-
 /// Fixed seed for the hashing so that Ords are consistent across runs
 const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);
 

Reply via email to