This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ea30b93194 feat : Support for deregistering user defined functions (#9239)
ea30b93194 is described below

commit ea30b93194cfdfe4148a45fc0e33549884ba81b1
Author: Eddy Oyieko <[email protected]>
AuthorDate: Wed Feb 28 01:37:46 2024 +0300

    feat : Support for deregistering user defined functions (#9239)
    
    * Initial commit
    
    * Updated mod.rs - Docstrings, Initial test
    
    * Updated mod.rs - Fixed udf test
    
    * Added udaf test, Updated udf test
    
    * Added test for udwf
    
    * Linting with rustfmt
    
    * Update datafusion/core/src/execution/context/mod.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Moved tests to core/tests/user_defined
    
    * fix fmt
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/execution/context/mod.rs       | 33 ++++++++++++++++++++++
 .../tests/user_defined/user_defined_aggregates.rs  | 23 +++++++++++++++
 .../user_defined/user_defined_scalar_functions.rs  | 16 +++++++++++
 .../user_defined/user_defined_window_functions.rs  | 15 ++++++++++
 datafusion/execution/src/registry.rs               | 27 ++++++++++++++++++
 5 files changed, 114 insertions(+)

diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index 453a00a1a5..3aa4edfe3a 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -849,6 +849,21 @@ impl SessionContext {
         self.state.write().register_udwf(Arc::new(f)).ok();
     }
 
+    /// Deregisters a UDF within this context.
+    pub fn deregister_udf(&self, name: &str) {
+        self.state.write().deregister_udf(name).ok();
+    }
+
+    /// Deregisters a UDAF within this context.
+    pub fn deregister_udaf(&self, name: &str) {
+        self.state.write().deregister_udaf(name).ok();
+    }
+
+    /// Deregisters a UDWF within this context.
+    pub fn deregister_udwf(&self, name: &str) {
+        self.state.write().deregister_udwf(name).ok();
+    }
+
     /// Creates a [`DataFrame`] for reading a data source.
     ///
     /// For more control such as reading multiple files, you can use
@@ -2026,6 +2041,24 @@ impl FunctionRegistry for SessionState {
     fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
         Ok(self.window_functions.insert(udwf.name().into(), udwf))
     }
+
+    fn deregister_udf(&mut self, name: &str) -> Result<Option<Arc<ScalarUDF>>> {
+        let udf = self.scalar_functions.remove(name);
+        if let Some(udf) = &udf {
+            for alias in udf.aliases() {
+                self.scalar_functions.remove(alias);
+            }
+        }
+        Ok(udf)
+    }
+
+    fn deregister_udaf(&mut self, name: &str) -> Result<Option<Arc<AggregateUDF>>> {
+        Ok(self.aggregate_functions.remove(name))
+    }
+
+    fn deregister_udwf(&mut self, name: &str) -> Result<Option<Arc<WindowUDF>>> {
+        Ok(self.window_functions.remove(name))
+    }
 }
 
 impl OptimizerConfig for SessionState {
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index 0b29ad10d6..8daeefd236 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -255,6 +255,29 @@ async fn simple_udaf() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn deregister_udaf() -> Result<()> {
+    let ctx = SessionContext::new();
+    let my_avg = create_udaf(
+        "my_avg",
+        vec![DataType::Float64],
+        Arc::new(DataType::Float64),
+        Volatility::Immutable,
+        Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
+        Arc::new(vec![DataType::UInt64, DataType::Float64]),
+    );
+
+    ctx.register_udaf(my_avg.clone());
+
+    assert!(ctx.state().aggregate_functions().contains_key("my_avg"));
+
+    ctx.deregister_udaf("my_avg");
+
+    assert!(!ctx.state().aggregate_functions().contains_key("my_avg"));
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
     let ctx = SessionContext::new();
diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 9812789740..a255498eb5 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -498,6 +498,22 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn deregister_udf() -> Result<()> {
+    let random_normal_udf = ScalarUDF::from(RandomUDF::new());
+    let ctx = SessionContext::new();
+
+    ctx.register_udf(random_normal_udf.clone());
+
+    assert!(ctx.udfs().contains("random_udf"));
+
+    ctx.deregister_udf("random_udf");
+
+    assert!(!ctx.udfs().contains("random_udf"));
+
+    Ok(())
+}
+
 #[derive(Debug)]
 struct TakeUDF {
     signature: Signature,
diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs
index 54eab4315a..cfd74f8861 100644
--- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs
@@ -103,6 +103,21 @@ async fn test_udwf() {
     assert_eq!(test_state.evaluate_all_called(), 2);
 }
 
+#[tokio::test]
+async fn test_deregister_udwf() -> Result<()> {
+    let test_state = Arc::new(TestState::new());
+    let mut ctx = SessionContext::new();
+    OddCounter::register(&mut ctx, Arc::clone(&test_state));
+
+    assert!(ctx.state().window_functions().contains_key("odd_counter"));
+
+    ctx.deregister_udwf("odd_counter");
+
+    assert!(!ctx.state().window_functions().contains_key("odd_counter"));
+
+    Ok(())
+}
+
 /// Basic user defined window function with bounded window
 #[tokio::test]
 async fn test_udwf_bounded_window_ignores_frame() {
diff --git a/datafusion/execution/src/registry.rs b/datafusion/execution/src/registry.rs
index 4569967acb..6e0a932f0b 100644
--- a/datafusion/execution/src/registry.rs
+++ b/datafusion/execution/src/registry.rs
@@ -66,6 +66,33 @@ pub trait FunctionRegistry {
     fn register_udwf(&mut self, _udaf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
         not_impl_err!("Registering WindowUDF")
     }
+
+    /// Deregisters a [`ScalarUDF`], returning the implementation that was
+    /// deregistered.
+    ///
+    /// Returns an error (the default) if the function can not be deregistered,
+    /// for example if the registry is read only.
+    fn deregister_udf(&mut self, _name: &str) -> Result<Option<Arc<ScalarUDF>>> {
+        not_impl_err!("Deregistering ScalarUDF")
+    }
+
+    /// Deregisters a [`AggregateUDF`], returning the implementation that was
+    /// deregistered.
+    ///
+    /// Returns an error (the default) if the function can not be deregistered,
+    /// for example if the registry is read only.
+    fn deregister_udaf(&mut self, _name: &str) -> Result<Option<Arc<AggregateUDF>>> {
+        not_impl_err!("Deregistering AggregateUDF")
+    }
+
+    /// Deregisters a [`WindowUDF`], returning the implementation that was
+    /// deregistered.
+    ///
+    /// Returns an error (the default) if the function can not be deregistered,
+    /// for example if the registry is read only.
+    fn deregister_udwf(&mut self, _name: &str) -> Result<Option<Arc<WindowUDF>>> {
+        not_impl_err!("Deregistering WindowUDF")
+    }
 }
 
 /// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].

Reply via email to