alamb commented on pull request #9531:
URL: https://github.com/apache/arrow/pull/9531#issuecomment-784596915


   @wqc200 here is a way to use `std::sync::Mutex` to change the database name after the UDF has been registered:
   
   ```
   use std::sync::{Arc, Mutex};
   
   use arrow::array::ArrayRef;
   use arrow::array::StringArray;
   use arrow::datatypes::DataType;
   use arrow::util::display::array_value_to_string;
   
   use datafusion::error::Result;
   use datafusion::execution::context::ExecutionContext;
   use datafusion::logical_plan::create_udf;
   use datafusion::physical_plan::functions::make_scalar_function;
   
   /// Test fixture that owns a DataFusion `ExecutionContext` together with the
   /// database name reported by the `database()` UDF registered on it.
   pub struct TestProvider {
       // Wrapped in `Arc<Mutex<..>>` so the same String can be shared with the
       // UDF closure and still be changed after the UDF has been registered.
       db_name: Arc<Mutex<String>>,
       // Context the `database()` UDF is registered in and queried through.
       ctx: ExecutionContext,
   }
   
   impl TestProvider {
       pub fn try_new(db_name: &str) -> Result<Self> {
           let ctx = ExecutionContext::new();
   
           Ok(Self {
               db_name: Arc::new(Mutex::new(db_name.to_string())),
               ctx,
           })
       }
   
       pub fn change_db_name(&mut self, db_name: &str) {
           *self.db_name.lock().expect("mutex poisoned") = db_name.to_string();
       }
   
       pub fn register_udf(&mut self) {
           // implementation of `database()` function that returns the
           // current value of `self.db_name`
           let captured_name = self.db_name.clone();
           let database_function = move |_args: &[ArrayRef]| {
               // Lock the mutex, and read current db_name
               let captured_name = captured_name.lock().expect("mutex 
posioned");
               let db_name = captured_name.as_str();
               let res = StringArray::from(vec![Some(db_name)]);
               Ok(Arc::new(res) as ArrayRef)
           };
   
           self.ctx.register_udf(create_udf(
               "database", // function name
               vec![], // input argument types
               Arc::new(DataType::Utf8), // output type
               make_scalar_function(database_function)), // function 
implementation
           );
       }
   
       pub async fn get_db_name(&mut self) -> Result<String> {
           let sql_results = self.ctx.sql("select 
database()")?.collect().await?;
   
           let sql_result_col = array_value_to_string(sql_results[0].column(0), 
0).unwrap();
   
           Ok(sql_result_col)
       }
   }
   
   /// Renaming after the UDF is registered must still be observed by
   /// `database()`, since the closure shares the provider's name storage.
   #[tokio::test]
   async fn test_func_provider_results() -> Result<()> {
       let mut provider = TestProvider::try_new("test1")?;
       provider.register_udf();
       provider.change_db_name("test");

       assert_eq!(provider.get_db_name().await?, "test");

       Ok(())
   }
   
   ```


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to