This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 648294f91 Move expression utils from sql module to expr crate (#2553)
648294f91 is described below

commit 648294f91f0caf41bc19cab1dd62c3d7bfa8fa2b
Author: Andy Grove <[email protected]>
AuthorDate: Tue May 17 19:16:48 2022 -0600

    Move expression utils from sql module to expr crate (#2553)
---
 datafusion/core/src/dataframe.rs                   |   3 +-
 datafusion/core/src/logical_plan/builder.rs        |  10 +-
 .../core/src/optimizer/projection_push_down.rs     |   3 +-
 datafusion/core/src/physical_plan/planner.rs       |   3 +-
 datafusion/core/src/sql/planner.rs                 |   8 +-
 datafusion/core/src/sql/utils.rs                   | 330 ---------------------
 datafusion/expr/src/utils.rs                       | 325 ++++++++++++++++++++
 7 files changed, 339 insertions(+), 343 deletions(-)

diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 369c2ae93..d0670bb28 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -34,12 +34,11 @@ use crate::arrow::datatypes::SchemaRef;
 use crate::arrow::util::pretty;
 use crate::datasource::TableProvider;
 use crate::execution::context::{SessionState, TaskContext};
-use crate::logical_expr::TableType;
+use crate::logical_expr::{utils::find_window_exprs, TableType};
 use crate::physical_plan::file_format::{plan_to_csv, plan_to_json, 
plan_to_parquet};
 use crate::physical_plan::{collect, collect_partitioned};
 use crate::physical_plan::{execute_stream, execute_stream_partitioned, 
ExecutionPlan};
 use crate::scalar::ScalarValue;
-use crate::sql::utils::find_window_exprs;
 use parking_lot::RwLock;
 use std::any::Any;
 
diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs
index 9e8f17815..739d19494 100644
--- a/datafusion/core/src/logical_plan/builder.rs
+++ b/datafusion/core/src/logical_plan/builder.rs
@@ -29,6 +29,7 @@ use crate::scalar::ScalarValue;
 use arrow::datatypes::{DataType, Schema};
 use datafusion_expr::utils::{
     expand_qualified_wildcard, expand_wildcard, expr_to_columns,
+    group_window_expr_by_sort_keys,
 };
 use std::convert::TryFrom;
 use std::iter;
@@ -38,13 +39,12 @@ use std::{
 };
 
 use super::{Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
-use crate::logical_plan::expr::exprlist_to_fields;
 use crate::logical_plan::{
-    columnize_expr, normalize_col, normalize_cols, provider_as_source,
-    rewrite_sort_cols_by_aggs, Column, CrossJoin, DFField, DFSchema, 
DFSchemaRef, Limit,
-    Offset, Partitioning, Repartition, Values,
+    columnize_expr, exprlist_to_fields, normalize_col, normalize_cols,
+    provider_as_source, rewrite_sort_cols_by_aggs, Column, CrossJoin, DFField, 
DFSchema,
+    DFSchemaRef, Limit, Offset, Partitioning, Repartition, Values,
 };
-use crate::sql::utils::group_window_expr_by_sort_keys;
+
 use datafusion_common::ToDFSchema;
 
 /// Default table name for unnamed table
diff --git a/datafusion/core/src/optimizer/projection_push_down.rs b/datafusion/core/src/optimizer/projection_push_down.rs
index a1bd6efe1..20b8f683d 100644
--- a/datafusion/core/src/optimizer/projection_push_down.rs
+++ b/datafusion/core/src/optimizer/projection_push_down.rs
@@ -29,10 +29,9 @@ use crate::logical_plan::{
 };
 use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::utils;
-use crate::sql::utils::find_sort_exprs;
 use arrow::datatypes::{Field, Schema};
 use arrow::error::Result as ArrowResult;
-use datafusion_expr::utils::{expr_to_columns, exprlist_to_columns};
+use datafusion_expr::utils::{expr_to_columns, exprlist_to_columns, 
find_sort_exprs};
 use datafusion_expr::Expr;
 use std::{
     collections::{BTreeSet, HashSet},
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index af844cd1e..911f9f67f 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -23,6 +23,7 @@ use super::{
     hash_join::PartitionMode, udaf, union::UnionExec, values::ValuesExec, 
windows,
 };
 use crate::execution::context::{ExecutionProps, SessionState};
+use crate::logical_expr::utils::generate_sort_key;
 use crate::logical_plan::plan::{
     source_as_provider, Aggregate, EmptyRelation, Filter, Join, Projection, 
Sort,
     SubqueryAlias, TableScan, Window,
@@ -52,7 +53,7 @@ use crate::physical_plan::windows::WindowAggExec;
 use crate::physical_plan::{join_utils, Partitioning};
 use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, 
WindowExpr};
 use crate::scalar::ScalarValue;
-use crate::sql::utils::{generate_sort_key, window_expr_common_partition_keys};
+use crate::sql::utils::window_expr_common_partition_keys;
 use crate::variable::VarType;
 use crate::{
     error::{DataFusionError, Result},
diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs
index 2518de78a..85c2d8f0c 100644
--- a/datafusion/core/src/sql/planner.rs
+++ b/datafusion/core/src/sql/planner.rs
@@ -45,7 +45,9 @@ use crate::{
     sql::parser::{CreateExternalTable, Statement as DFStatement},
 };
 use arrow::datatypes::*;
-use datafusion_expr::utils::exprlist_to_columns;
+use datafusion_expr::utils::{
+    exprlist_to_columns, find_aggregate_exprs, find_window_exprs,
+};
 use datafusion_expr::{window_function::WindowFunction, BuiltinScalarFunction};
 use hashbrown::HashMap;
 
@@ -67,8 +69,8 @@ use super::{
     parser::DFParser,
     utils::{
         check_columns_satisfy_exprs, expr_as_column_expr, extract_aliases,
-        find_aggregate_exprs, find_column_exprs, find_window_exprs, 
rebase_expr,
-        resolve_aliases_to_exprs, resolve_positions_to_exprs,
+        find_column_exprs, rebase_expr, resolve_aliases_to_exprs,
+        resolve_positions_to_exprs,
     },
 };
 use crate::logical_plan::builder::project_with_alias;
diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs
index b2cf1f698..0b8e8d3a6 100644
--- a/datafusion/core/src/sql/utils.rs
+++ b/datafusion/core/src/sql/utils.rs
@@ -30,34 +30,6 @@ use crate::{
 use datafusion_expr::expr::GroupingSet;
 use std::collections::HashMap;
 
-/// Collect all deeply nested `Expr::AggregateFunction` and
-/// `Expr::AggregateUDF`. They are returned in order of occurrence (depth
-/// first), with duplicates omitted.
-pub(crate) fn find_aggregate_exprs(exprs: &[Expr]) -> Vec<Expr> {
-    find_exprs_in_exprs(exprs, &|nested_expr| {
-        matches!(
-            nested_expr,
-            Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. }
-        )
-    })
-}
-
-/// Collect all deeply nested `Expr::Sort`. They are returned in order of 
occurrence
-/// (depth first), with duplicates omitted.
-pub(crate) fn find_sort_exprs(exprs: &[Expr]) -> Vec<Expr> {
-    find_exprs_in_exprs(exprs, &|nested_expr| {
-        matches!(nested_expr, Expr::Sort { .. })
-    })
-}
-
-/// Collect all deeply nested `Expr::WindowFunction`. They are returned in 
order of occurrence
-/// (depth first), with duplicates omitted.
-pub(crate) fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
-    find_exprs_in_exprs(exprs, &|nested_expr| {
-        matches!(nested_expr, Expr::WindowFunction { .. })
-    })
-}
-
 /// Collect all deeply nested `Expr::Column`'s. They are returned in order of
 /// appearance (depth first), and may contain duplicates.
 pub(crate) fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
@@ -68,24 +40,6 @@ pub(crate) fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
         .collect()
 }
 
-/// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
-/// pass the provided test. The returned `Expr`'s are deduplicated and returned
-/// in order of appearance (depth first).
-fn find_exprs_in_exprs<F>(exprs: &[Expr], test_fn: &F) -> Vec<Expr>
-where
-    F: Fn(&Expr) -> bool,
-{
-    exprs
-        .iter()
-        .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
-        .fold(vec![], |mut acc, expr| {
-            if !acc.contains(&expr) {
-                acc.push(expr)
-            }
-            acc
-        })
-}
-
 /// Recursively find all columns referenced by an expression
 #[derive(Debug, Default)]
 struct ColumnCollector {
@@ -110,59 +64,6 @@ pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
     exprs
 }
 
-// Visitor that find expressions that match a particular predicate
-struct Finder<'a, F>
-where
-    F: Fn(&Expr) -> bool,
-{
-    test_fn: &'a F,
-    exprs: Vec<Expr>,
-}
-
-impl<'a, F> Finder<'a, F>
-where
-    F: Fn(&Expr) -> bool,
-{
-    /// Create a new finder with the `test_fn`
-    fn new(test_fn: &'a F) -> Self {
-        Self {
-            test_fn,
-            exprs: Vec::new(),
-        }
-    }
-}
-
-impl<'a, F> ExpressionVisitor for Finder<'a, F>
-where
-    F: Fn(&Expr) -> bool,
-{
-    fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
-        if (self.test_fn)(expr) {
-            if !(self.exprs.contains(expr)) {
-                self.exprs.push(expr.clone())
-            }
-            // stop recursing down this expr once we find a match
-            return Ok(Recursion::Stop(self));
-        }
-
-        Ok(Recursion::Continue(self))
-    }
-}
-
-/// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the
-/// provided test. The returned `Expr`'s are deduplicated and returned in order
-/// of appearance (depth first).
-fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
-where
-    F: Fn(&Expr) -> bool,
-{
-    let Finder { exprs, .. } = expr
-        .accept(Finder::new(test_fn))
-        // pre_visit always returns OK, so this will always too
-        .expect("no way to return error during recursion");
-    exprs
-}
-
 /// Convert any `Expr` to an `Expr::Column`.
 pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> 
Result<Expr> {
     match expr {
@@ -572,28 +473,6 @@ pub(crate) fn resolve_aliases_to_exprs(
     })
 }
 
-type WindowSortKey = Vec<Expr>;
-
-/// Generate a sort key for a given window expr's partition_by and order_bu 
expr
-pub(crate) fn generate_sort_key(
-    partition_by: &[Expr],
-    order_by: &[Expr],
-) -> WindowSortKey {
-    let mut sort_key = vec![];
-    partition_by.iter().for_each(|e| {
-        let e = e.clone().sort(true, true);
-        if !sort_key.contains(&e) {
-            sort_key.push(e);
-        }
-    });
-    order_by.iter().for_each(|e| {
-        if !sort_key.contains(e) {
-            sort_key.push(e.clone());
-        }
-    });
-    sort_key
-}
-
 /// given a slice of window expressions sharing the same sort key, find their 
common partition
 /// keys.
 pub(crate) fn window_expr_common_partition_keys(
@@ -618,31 +497,6 @@ pub(crate) fn window_expr_common_partition_keys(
     Ok(result)
 }
 
-/// group a slice of window expression expr by their order by expressions
-pub(crate) fn group_window_expr_by_sort_keys(
-    window_expr: &[Expr],
-) -> Result<Vec<(WindowSortKey, Vec<&Expr>)>> {
-    let mut result = vec![];
-    window_expr.iter().try_for_each(|expr| match expr {
-        Expr::WindowFunction { partition_by, order_by, .. } => {
-            let sort_key = generate_sort_key(partition_by, order_by);
-            if let Some((_, values)) = result.iter_mut().find(
-                |group: &&mut (WindowSortKey, Vec<&Expr>)| matches!(group, 
(key, _) if *key == sort_key),
-            ) {
-                values.push(expr);
-            } else {
-                result.push((sort_key, vec![expr]))
-            }
-            Ok(())
-        }
-        other => Err(DataFusionError::Internal(format!(
-            "Impossibly got non-window expr {:?}",
-            other,
-        ))),
-    })?;
-    Ok(result)
-}
-
 /// Returns a validated `DataType` for the specified precision and
 /// scale
 pub(crate) fn make_decimal_type(
@@ -677,187 +531,3 @@ pub(crate) fn normalize_ident(id: &Ident) -> String {
         None => id.value.to_ascii_lowercase(),
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::logical_plan::col;
-    use crate::physical_plan::aggregates::AggregateFunction;
-    use datafusion_expr::window_function::WindowFunction;
-
-    #[test]
-    fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
-        let result = group_window_expr_by_sort_keys(&[])?;
-        let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![];
-        assert_eq!(expected, result);
-        Ok(())
-    }
-
-    #[test]
-    fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
-        let max1 = Expr::WindowFunction {
-            fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
-            args: vec![col("name")],
-            partition_by: vec![],
-            order_by: vec![],
-            window_frame: None,
-        };
-        let max2 = Expr::WindowFunction {
-            fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
-            args: vec![col("name")],
-            partition_by: vec![],
-            order_by: vec![],
-            window_frame: None,
-        };
-        let min3 = Expr::WindowFunction {
-            fun: WindowFunction::AggregateFunction(AggregateFunction::Min),
-            args: vec![col("name")],
-            partition_by: vec![],
-            order_by: vec![],
-            window_frame: None,
-        };
-        let sum4 = Expr::WindowFunction {
-            fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
-            args: vec![col("age")],
-            partition_by: vec![],
-            order_by: vec![],
-            window_frame: None,
-        };
-        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
-        let result = group_window_expr_by_sort_keys(exprs)?;
-        let key = vec![];
-        let expected: Vec<(WindowSortKey, Vec<&Expr>)> =
-            vec![(key, vec![&max1, &max2, &min3, &sum4])];
-        assert_eq!(expected, result);
-        Ok(())
-    }
-
-    #[test]
-    fn test_group_window_expr_by_sort_keys() -> Result<()> {
-        let age_asc = Expr::Sort {
-            expr: Box::new(col("age")),
-            asc: true,
-            nulls_first: true,
-        };
-        let name_desc = Expr::Sort {
-            expr: Box::new(col("name")),
-            asc: false,
-            nulls_first: true,
-        };
-        let created_at_desc = Expr::Sort {
-            expr: Box::new(col("created_at")),
-            asc: false,
-            nulls_first: true,
-        };
-        let max1 = Expr::WindowFunction {
-            fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
-            args: vec![col("name")],
-            partition_by: vec![],
-            order_by: vec![age_asc.clone(), name_desc.clone()],
-            window_frame: None,
-        };
-        let max2 = Expr::WindowFunction {
-            fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
-            args: vec![col("name")],
-            partition_by: vec![],
-            order_by: vec![],
-            window_frame: None,
-        };
-        let min3 = Expr::WindowFunction {
-            fun: WindowFunction::AggregateFunction(AggregateFunction::Min),
-            args: vec![col("name")],
-            partition_by: vec![],
-            order_by: vec![age_asc.clone(), name_desc.clone()],
-            window_frame: None,
-        };
-        let sum4 = Expr::WindowFunction {
-            fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
-            args: vec![col("age")],
-            partition_by: vec![],
-            order_by: vec![name_desc.clone(), age_asc.clone(), 
created_at_desc.clone()],
-            window_frame: None,
-        };
-        // FIXME use as_ref
-        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
-        let result = group_window_expr_by_sort_keys(exprs)?;
-
-        let key1 = vec![age_asc.clone(), name_desc.clone()];
-        let key2 = vec![];
-        let key3 = vec![name_desc, age_asc, created_at_desc];
-
-        let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![
-            (key1, vec![&max1, &min3]),
-            (key2, vec![&max2]),
-            (key3, vec![&sum4]),
-        ];
-        assert_eq!(expected, result);
-        Ok(())
-    }
-
-    #[test]
-    fn test_find_sort_exprs() -> Result<()> {
-        let exprs = &[
-            Expr::WindowFunction {
-                fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
-                args: vec![col("name")],
-                partition_by: vec![],
-                order_by: vec![
-                    Expr::Sort {
-                        expr: Box::new(col("age")),
-                        asc: true,
-                        nulls_first: true,
-                    },
-                    Expr::Sort {
-                        expr: Box::new(col("name")),
-                        asc: false,
-                        nulls_first: true,
-                    },
-                ],
-                window_frame: None,
-            },
-            Expr::WindowFunction {
-                fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
-                args: vec![col("age")],
-                partition_by: vec![],
-                order_by: vec![
-                    Expr::Sort {
-                        expr: Box::new(col("name")),
-                        asc: false,
-                        nulls_first: true,
-                    },
-                    Expr::Sort {
-                        expr: Box::new(col("age")),
-                        asc: true,
-                        nulls_first: true,
-                    },
-                    Expr::Sort {
-                        expr: Box::new(col("created_at")),
-                        asc: false,
-                        nulls_first: true,
-                    },
-                ],
-                window_frame: None,
-            },
-        ];
-        let expected = vec![
-            Expr::Sort {
-                expr: Box::new(col("age")),
-                asc: true,
-                nulls_first: true,
-            },
-            Expr::Sort {
-                expr: Box::new(col("name")),
-                asc: false,
-                nulls_first: true,
-            },
-            Expr::Sort {
-                expr: Box::new(col("created_at")),
-                asc: false,
-                nulls_first: true,
-            },
-        ];
-        let result = find_sort_exprs(exprs);
-        assert_eq!(expected, result);
-        Ok(())
-    }
-}
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 8f3a1a53c..709a3eee2 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -141,3 +141,328 @@ pub fn expand_qualified_wildcard(
         DFSchema::new_with_metadata(qualified_fields, 
schema.metadata().clone())?;
     expand_wildcard(&qualifier_schema, plan)
 }
+
+type WindowSortKey = Vec<Expr>;
+
+/// Generate a sort key for a given window expr's partition_by and order_bu 
expr
+pub fn generate_sort_key(partition_by: &[Expr], order_by: &[Expr]) -> 
WindowSortKey {
+    let mut sort_key = vec![];
+    partition_by.iter().for_each(|e| {
+        let e = e.clone().sort(true, true);
+        if !sort_key.contains(&e) {
+            sort_key.push(e);
+        }
+    });
+    order_by.iter().for_each(|e| {
+        if !sort_key.contains(e) {
+            sort_key.push(e.clone());
+        }
+    });
+    sort_key
+}
+
+/// group a slice of window expression expr by their order by expressions
+pub fn group_window_expr_by_sort_keys(
+    window_expr: &[Expr],
+) -> Result<Vec<(WindowSortKey, Vec<&Expr>)>> {
+    let mut result = vec![];
+    window_expr.iter().try_for_each(|expr| match expr {
+        Expr::WindowFunction { partition_by, order_by, .. } => {
+            let sort_key = generate_sort_key(partition_by, order_by);
+            if let Some((_, values)) = result.iter_mut().find(
+                |group: &&mut (WindowSortKey, Vec<&Expr>)| matches!(group, 
(key, _) if *key == sort_key),
+            ) {
+                values.push(expr);
+            } else {
+                result.push((sort_key, vec![expr]))
+            }
+            Ok(())
+        }
+        other => Err(DataFusionError::Internal(format!(
+            "Impossibly got non-window expr {:?}",
+            other,
+        ))),
+    })?;
+    Ok(result)
+}
+
+/// Collect all deeply nested `Expr::AggregateFunction` and
+/// `Expr::AggregateUDF`. They are returned in order of occurrence (depth
+/// first), with duplicates omitted.
+pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec<Expr> {
+    find_exprs_in_exprs(exprs, &|nested_expr| {
+        matches!(
+            nested_expr,
+            Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. }
+        )
+    })
+}
+
+/// Collect all deeply nested `Expr::Sort`. They are returned in order of 
occurrence
+/// (depth first), with duplicates omitted.
+pub fn find_sort_exprs(exprs: &[Expr]) -> Vec<Expr> {
+    find_exprs_in_exprs(exprs, &|nested_expr| {
+        matches!(nested_expr, Expr::Sort { .. })
+    })
+}
+
+/// Collect all deeply nested `Expr::WindowFunction`. They are returned in 
order of occurrence
+/// (depth first), with duplicates omitted.
+pub fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
+    find_exprs_in_exprs(exprs, &|nested_expr| {
+        matches!(nested_expr, Expr::WindowFunction { .. })
+    })
+}
+
+/// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
+/// pass the provided test. The returned `Expr`'s are deduplicated and returned
+/// in order of appearance (depth first).
+fn find_exprs_in_exprs<F>(exprs: &[Expr], test_fn: &F) -> Vec<Expr>
+where
+    F: Fn(&Expr) -> bool,
+{
+    exprs
+        .iter()
+        .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
+        .fold(vec![], |mut acc, expr| {
+            if !acc.contains(&expr) {
+                acc.push(expr)
+            }
+            acc
+        })
+}
+
+/// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the
+/// provided test. The returned `Expr`'s are deduplicated and returned in order
+/// of appearance (depth first).
+fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
+where
+    F: Fn(&Expr) -> bool,
+{
+    let Finder { exprs, .. } = expr
+        .accept(Finder::new(test_fn))
+        // pre_visit always returns OK, so this will always too
+        .expect("no way to return error during recursion");
+    exprs
+}
+
+// Visitor that find expressions that match a particular predicate
+struct Finder<'a, F>
+where
+    F: Fn(&Expr) -> bool,
+{
+    test_fn: &'a F,
+    exprs: Vec<Expr>,
+}
+
+impl<'a, F> Finder<'a, F>
+where
+    F: Fn(&Expr) -> bool,
+{
+    /// Create a new finder with the `test_fn`
+    fn new(test_fn: &'a F) -> Self {
+        Self {
+            test_fn,
+            exprs: Vec::new(),
+        }
+    }
+}
+
+impl<'a, F> ExpressionVisitor for Finder<'a, F>
+where
+    F: Fn(&Expr) -> bool,
+{
+    fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
+        if (self.test_fn)(expr) {
+            if !(self.exprs.contains(expr)) {
+                self.exprs.push(expr.clone())
+            }
+            // stop recursing down this expr once we find a match
+            return Ok(Recursion::Stop(self));
+        }
+
+        Ok(Recursion::Continue(self))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{col, AggregateFunction, WindowFunction};
+
+    #[test]
+    fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
+        let result = group_window_expr_by_sort_keys(&[])?;
+        let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![];
+        assert_eq!(expected, result);
+        Ok(())
+    }
+
+    #[test]
+    fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
+        let max1 = Expr::WindowFunction {
+            fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
+            args: vec![col("name")],
+            partition_by: vec![],
+            order_by: vec![],
+            window_frame: None,
+        };
+        let max2 = Expr::WindowFunction {
+            fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
+            args: vec![col("name")],
+            partition_by: vec![],
+            order_by: vec![],
+            window_frame: None,
+        };
+        let min3 = Expr::WindowFunction {
+            fun: WindowFunction::AggregateFunction(AggregateFunction::Min),
+            args: vec![col("name")],
+            partition_by: vec![],
+            order_by: vec![],
+            window_frame: None,
+        };
+        let sum4 = Expr::WindowFunction {
+            fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
+            args: vec![col("age")],
+            partition_by: vec![],
+            order_by: vec![],
+            window_frame: None,
+        };
+        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
+        let result = group_window_expr_by_sort_keys(exprs)?;
+        let key = vec![];
+        let expected: Vec<(WindowSortKey, Vec<&Expr>)> =
+            vec![(key, vec![&max1, &max2, &min3, &sum4])];
+        assert_eq!(expected, result);
+        Ok(())
+    }
+
+    #[test]
+    fn test_group_window_expr_by_sort_keys() -> Result<()> {
+        let age_asc = Expr::Sort {
+            expr: Box::new(col("age")),
+            asc: true,
+            nulls_first: true,
+        };
+        let name_desc = Expr::Sort {
+            expr: Box::new(col("name")),
+            asc: false,
+            nulls_first: true,
+        };
+        let created_at_desc = Expr::Sort {
+            expr: Box::new(col("created_at")),
+            asc: false,
+            nulls_first: true,
+        };
+        let max1 = Expr::WindowFunction {
+            fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
+            args: vec![col("name")],
+            partition_by: vec![],
+            order_by: vec![age_asc.clone(), name_desc.clone()],
+            window_frame: None,
+        };
+        let max2 = Expr::WindowFunction {
+            fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
+            args: vec![col("name")],
+            partition_by: vec![],
+            order_by: vec![],
+            window_frame: None,
+        };
+        let min3 = Expr::WindowFunction {
+            fun: WindowFunction::AggregateFunction(AggregateFunction::Min),
+            args: vec![col("name")],
+            partition_by: vec![],
+            order_by: vec![age_asc.clone(), name_desc.clone()],
+            window_frame: None,
+        };
+        let sum4 = Expr::WindowFunction {
+            fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
+            args: vec![col("age")],
+            partition_by: vec![],
+            order_by: vec![name_desc.clone(), age_asc.clone(), 
created_at_desc.clone()],
+            window_frame: None,
+        };
+        // FIXME use as_ref
+        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
+        let result = group_window_expr_by_sort_keys(exprs)?;
+
+        let key1 = vec![age_asc.clone(), name_desc.clone()];
+        let key2 = vec![];
+        let key3 = vec![name_desc, age_asc, created_at_desc];
+
+        let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![
+            (key1, vec![&max1, &min3]),
+            (key2, vec![&max2]),
+            (key3, vec![&sum4]),
+        ];
+        assert_eq!(expected, result);
+        Ok(())
+    }
+
+    #[test]
+    fn test_find_sort_exprs() -> Result<()> {
+        let exprs = &[
+            Expr::WindowFunction {
+                fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
+                args: vec![col("name")],
+                partition_by: vec![],
+                order_by: vec![
+                    Expr::Sort {
+                        expr: Box::new(col("age")),
+                        asc: true,
+                        nulls_first: true,
+                    },
+                    Expr::Sort {
+                        expr: Box::new(col("name")),
+                        asc: false,
+                        nulls_first: true,
+                    },
+                ],
+                window_frame: None,
+            },
+            Expr::WindowFunction {
+                fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
+                args: vec![col("age")],
+                partition_by: vec![],
+                order_by: vec![
+                    Expr::Sort {
+                        expr: Box::new(col("name")),
+                        asc: false,
+                        nulls_first: true,
+                    },
+                    Expr::Sort {
+                        expr: Box::new(col("age")),
+                        asc: true,
+                        nulls_first: true,
+                    },
+                    Expr::Sort {
+                        expr: Box::new(col("created_at")),
+                        asc: false,
+                        nulls_first: true,
+                    },
+                ],
+                window_frame: None,
+            },
+        ];
+        let expected = vec![
+            Expr::Sort {
+                expr: Box::new(col("age")),
+                asc: true,
+                nulls_first: true,
+            },
+            Expr::Sort {
+                expr: Box::new(col("name")),
+                asc: false,
+                nulls_first: true,
+            },
+            Expr::Sort {
+                expr: Box::new(col("created_at")),
+                asc: false,
+                nulls_first: true,
+            },
+        ];
+        let result = find_sort_exprs(exprs);
+        assert_eq!(expected, result);
+        Ok(())
+    }
+}

Reply via email to