alamb commented on pull request #9531:
URL: https://github.com/apache/arrow/pull/9531#issuecomment-784596915
@wqc200 here is a way to use `std::sync::Mutex` so that the database name can be
changed even after the `database()` UDF has been registered:
```rust
use std::sync::{Arc, Mutex};
use arrow::array::ArrayRef;
use arrow::array::StringArray;
use arrow::datatypes::DataType;
use arrow::util::display::array_value_to_string;
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use datafusion::logical_plan::create_udf;
use datafusion::physical_plan::functions::make_scalar_function;
/// Bundles a DataFusion `ExecutionContext` with the database name that the
/// `database()` UDF registered on it reports.
pub struct TestProvider {
    /// Current database name.
    ///
    /// Kept behind `Arc<Mutex<..>>` so the UDF closure can hold its own handle
    /// to it, and a rename made after registration is still observed.
    db_name: Arc<Mutex<String>>,

    /// Context that owns the registered UDF and executes SQL.
    ctx: ExecutionContext,
}
impl TestProvider {
pub fn try_new(db_name: &str) -> Result<Self> {
let ctx = ExecutionContext::new();
Ok(Self {
db_name: Arc::new(Mutex::new(db_name.to_string())),
ctx,
})
}
pub fn change_db_name(&mut self, db_name: &str) {
*self.db_name.lock().expect("mutex poisoned") = db_name.to_string();
}
pub fn register_udf(&mut self) {
// implementation of `database()` function that returns the
// current value of `self.db_name`
let captured_name = self.db_name.clone();
let database_function = move |_args: &[ArrayRef]| {
// Lock the mutex, and read current db_name
let captured_name = captured_name.lock().expect("mutex
posioned");
let db_name = captured_name.as_str();
let res = StringArray::from(vec![Some(db_name)]);
Ok(Arc::new(res) as ArrayRef)
};
self.ctx.register_udf(create_udf(
"database", // function name
vec![], // input argument types
Arc::new(DataType::Utf8), // output type
make_scalar_function(database_function)), // function
implementation
);
}
pub async fn get_db_name(&mut self) -> Result<String> {
let sql_results = self.ctx.sql("select
database()")?.collect().await?;
let sql_result_col = array_value_to_string(sql_results[0].column(0),
0).unwrap();
Ok(sql_result_col)
}
}
/// End-to-end check: a rename made after the UDF is registered is what
/// `select database()` reports.
#[tokio::test]
async fn test_func_provider_results() -> Result<()> {
    let mut provider = TestProvider::try_new("test1")?;
    provider.register_udf();
    provider.change_db_name("test");

    let actual = provider.get_db_name().await?;
    assert_eq!(actual, "test");
    Ok(())
}
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]