This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new f1f0965331 feat: function name hints for UDFs (#9407)
f1f0965331 is described below

commit f1f09653319aea3186c2b1f9ca103ef7030c2da1
Author: SteveLauC <[email protected]>
AuthorDate: Sun Mar 10 19:31:56 2024 +0800

    feat: function name hints for UDFs (#9407)
    
    * feat: function name hints for UDFs
    
    * refactor: rebase fn to xxx_names()
    
    * style: fix clippy
    
    * style: fix clippy
    
    * Add test
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion-cli/Cargo.lock                          |  1 +
 datafusion-examples/examples/rewrite_expr.rs       | 12 +++++
 datafusion/core/src/execution/context/mod.rs       | 12 +++++
 datafusion/expr/src/function.rs                    | 37 ++------------
 .../optimizer/tests/optimizer_integration.rs       | 12 +++++
 datafusion/sql/Cargo.toml                          |  1 +
 datafusion/sql/examples/sql.rs                     | 12 +++++
 datafusion/sql/src/expr/function.rs                | 58 ++++++++++++++++++++--
 datafusion/sql/src/expr/mod.rs                     | 12 +++++
 datafusion/sql/src/planner.rs                      |  4 ++
 datafusion/sql/tests/sql_integration.rs            | 12 +++++
 datafusion/sqllogictest/test_files/functions.slt   |  2 +-
 12 files changed, 135 insertions(+), 40 deletions(-)

diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 5e3c8648fc..b4af789682 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1363,6 +1363,7 @@ dependencies = [
  "datafusion-expr",
  "log",
  "sqlparser",
+ "strum 0.26.1",
 ]
 
 [[package]]
diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs
index cc1396f770..541448ebf1 100644
--- a/datafusion-examples/examples/rewrite_expr.rs
+++ b/datafusion-examples/examples/rewrite_expr.rs
@@ -226,6 +226,18 @@ impl ContextProvider for MyContextProvider {
     fn options(&self) -> &ConfigOptions {
         &self.options
     }
+
+    fn udfs_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+
+    fn udafs_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+
+    fn udwfs_names(&self) -> Vec<String> {
+        Vec::new()
+    }
 }
 
 struct MyTableSource {
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index 7b37e4914c..49d1b12e66 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -2098,6 +2098,18 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
     fn options(&self) -> &ConfigOptions {
         self.state.config_options()
     }
+
+    fn udfs_names(&self) -> Vec<String> {
+        self.state.scalar_functions().keys().cloned().collect()
+    }
+
+    fn udafs_names(&self) -> Vec<String> {
+        self.state.aggregate_functions().keys().cloned().collect()
+    }
+
+    fn udwfs_names(&self) -> Vec<String> {
+        self.state.window_functions().keys().cloned().collect()
+    }
 }
 
 impl FunctionRegistry for SessionState {
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 3e30a5574b..a3760eeb35 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -17,13 +17,12 @@
 
 //! Function module contains typing and signature for built-in and user defined functions.
 
-use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature};
-use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue};
+use crate::{
+    Accumulator, BuiltinScalarFunction, ColumnarValue, PartitionEvaluator, Signature,
+};
 use arrow::datatypes::DataType;
-use datafusion_common::utils::datafusion_strsim;
 use datafusion_common::Result;
 use std::sync::Arc;
-use strum::IntoEnumIterator;
 
 /// Scalar function
 ///
@@ -75,33 +74,3 @@ pub fn return_type(
 pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
     fun.signature()
 }
-
-/// Suggest a valid function based on an invalid input function name
-pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String {
-    let valid_funcs = if is_window_func {
-        // All aggregate functions and builtin window functions
-        AggregateFunction::iter()
-            .map(|func| func.to_string())
-            .chain(BuiltInWindowFunction::iter().map(|func| func.to_string()))
-            .collect()
-    } else {
-        // All scalar functions and aggregate functions
-        BuiltinScalarFunction::iter()
-            .map(|func| func.to_string())
-            .chain(AggregateFunction::iter().map(|func| func.to_string()))
-            .collect()
-    };
-    find_closest_match(valid_funcs, input_function_name)
-}
-
-/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve)
-/// Input `candidates` must not be empty otherwise it will panic
-fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
-    let target = target.to_lowercase();
-    candidates
-        .into_iter()
-        .min_by_key(|candidate| {
-            datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target)
-        })
-        .expect("No candidates provided.") // Panic if `candidates` argument is empty
-}
diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs
index db7bfa8b3b..b02623854b 100644
--- a/datafusion/optimizer/tests/optimizer_integration.rs
+++ b/datafusion/optimizer/tests/optimizer_integration.rs
@@ -417,6 +417,18 @@ impl ContextProvider for MyContextProvider {
     fn options(&self) -> &ConfigOptions {
         &self.options
     }
+
+    fn udfs_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+
+    fn udafs_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+
+    fn udwfs_names(&self) -> Vec<String> {
+        Vec::new()
+    }
 }
 
 struct MyTableSource {
diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml
index fb300e2c87..7739058a5c 100644
--- a/datafusion/sql/Cargo.toml
+++ b/datafusion/sql/Cargo.toml
@@ -43,6 +43,7 @@ datafusion-common = { workspace = true, default-features = true }
 datafusion-expr = { workspace = true }
 log = { workspace = true }
 sqlparser = { workspace = true }
+strum = { version = "0.26.1", features = ["derive"] }
 
 [dev-dependencies]
 ctor = { workspace = true }
diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs
index 8744a90548..5bab2f19cf 100644
--- a/datafusion/sql/examples/sql.rs
+++ b/datafusion/sql/examples/sql.rs
@@ -131,4 +131,16 @@ impl ContextProvider for MyContextProvider {
     fn options(&self) -> &ConfigOptions {
         &self.options
     }
+
+    fn udfs_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+
+    fn udafs_names(&self) -> Vec<String> {
+        Vec::new()
+    }
+
+    fn udwfs_names(&self) -> Vec<String> {
+        Vec::new()
+    }
 }
diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs
index bcf641e4b5..ffc951a6fa 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -20,20 +20,67 @@ use arrow_schema::DataType;
 use datafusion_common::{
     not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result,
 };
-use datafusion_expr::expr::{ScalarFunction, Unnest};
-use datafusion_expr::function::suggest_valid_function;
 use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
 use datafusion_expr::{
-    expr, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, WindowFrame,
-    WindowFunctionDefinition,
+    expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition,
+};
+use datafusion_expr::{
+    expr::{ScalarFunction, Unnest},
+    BuiltInWindowFunction, BuiltinScalarFunction,
 };
 use sqlparser::ast::{
     Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType,
 };
 use std::str::FromStr;
+use strum::IntoEnumIterator;
 
 use super::arrow_cast::ARROW_CAST_NAME;
 
+/// Suggest a valid function based on an invalid input function name
+pub fn suggest_valid_function(
+    input_function_name: &str,
+    is_window_func: bool,
+    ctx: &dyn ContextProvider,
+) -> String {
+    let valid_funcs = if is_window_func {
+        // All aggregate functions and builtin window functions
+        let mut funcs = Vec::new();
+
+        funcs.extend(AggregateFunction::iter().map(|func| func.to_string()));
+        funcs.extend(ctx.udafs_names());
+        funcs.extend(BuiltInWindowFunction::iter().map(|func| func.to_string()));
+        funcs.extend(ctx.udwfs_names());
+
+        funcs
+    } else {
+        // All scalar functions and aggregate functions
+        let mut funcs = Vec::new();
+
+        funcs.extend(BuiltinScalarFunction::iter().map(|func| func.to_string()));
+        funcs.extend(ctx.udfs_names());
+        funcs.extend(AggregateFunction::iter().map(|func| func.to_string()));
+        funcs.extend(ctx.udafs_names());
+
+        funcs
+    };
+    find_closest_match(valid_funcs, input_function_name)
+}
+
+/// Find the closest matching string to the target string in the candidates list, using edit distance (case insensitive)
+/// Input `candidates` must not be empty otherwise it will panic
+fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
+    let target = target.to_lowercase();
+    candidates
+        .into_iter()
+        .min_by_key(|candidate| {
+            datafusion_common::utils::datafusion_strsim::levenshtein(
+                &candidate.to_lowercase(),
+                &target,
+            )
+        })
+        .expect("No candidates provided.") // Panic if `candidates` argument is empty
+}
+
 impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     pub(super) fn sql_function_to_expr(
         &self,
@@ -211,7 +258,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         }
 
         // Could not find the relevant function, so return an error
-        let suggested_func_name = suggest_valid_function(&name, is_function_window);
+        let suggested_func_name =
+            suggest_valid_function(&name, is_function_window, self.context_provider);
         plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?")
     }
 
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index d6aa006ec3..e838a4cafb 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -983,6 +983,18 @@ mod tests {
         fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
             None
         }
+
+        fn udfs_names(&self) -> Vec<String> {
+            Vec::new()
+        }
+
+        fn udafs_names(&self) -> Vec<String> {
+            Vec::new()
+        }
+
+        fn udwfs_names(&self) -> Vec<String> {
+            Vec::new()
+        }
     }
 
     fn create_table_source(fields: Vec<Field>) -> Arc<dyn TableSource> {
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 2db2c01c5e..f94c6ec4e8 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -85,6 +85,10 @@ pub trait ContextProvider {
 
     /// Get configuration options
     fn options(&self) -> &ConfigOptions;
+
+    fn udfs_names(&self) -> Vec<String>;
+    fn udafs_names(&self) -> Vec<String>;
+    fn udwfs_names(&self) -> Vec<String>;
 }
 
 /// SQL parser options
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index 655eb63cc3..6681c3d025 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -2901,6 +2901,18 @@ impl ContextProvider for MockContextProvider {
     ) -> Result<Arc<dyn TableSource>> {
         Ok(Arc::new(EmptyTable::new(schema)))
     }
+
+    fn udfs_names(&self) -> Vec<String> {
+        self.udfs.keys().cloned().collect()
+    }
+
+    fn udafs_names(&self) -> Vec<String> {
+        self.udafs.keys().cloned().collect()
+    }
+
+    fn udwfs_names(&self) -> Vec<String> {
+        Vec::new()
+    }
 }
 
 #[test]
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index 96aa3e2752..21433ba168 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -483,7 +483,7 @@ statement error Did you mean 'arrow_typeof'?
 SELECT arrowtypeof(v1) from test;
 
 # Scalar function
-statement error Invalid function 'to_timestamps_second'
+statement error Did you mean 'to_timestamp_seconds'?
 SELECT to_TIMESTAMPS_second(v2) from test;
 
 # Aggregate function

Reply via email to