This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 42b3a6c92 DataFrame owned SessionState (#4633)
42b3a6c92 is described below
commit 42b3a6c9224e06adf2170aafe2cfeadca6b3f2bb
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Sat Dec 17 10:25:47 2022 +0000
DataFrame owned SessionState (#4633)
* DataFrame owned SessionState (#4617)
* Fix deadlock
* Fix execution time
---
datafusion-examples/examples/custom_datasource.rs | 2 +-
datafusion/core/src/dataframe.rs | 74 +++++++++--------------
datafusion/core/src/execution/context.rs | 21 ++++---
3 files changed, 40 insertions(+), 57 deletions(-)
diff --git a/datafusion-examples/examples/custom_datasource.rs
b/datafusion-examples/examples/custom_datasource.rs
index db4fed494..68e8f5a54 100644
--- a/datafusion-examples/examples/custom_datasource.rs
+++ b/datafusion-examples/examples/custom_datasource.rs
@@ -69,7 +69,7 @@ async fn search_accounts(
)?
.build()?;
- let mut dataframe = DataFrame::new(ctx.state, logical_plan)
+ let mut dataframe = DataFrame::new(ctx.state(), logical_plan)
.select_columns(&["id", "bank_account"])?;
if let Some(f) = filter {
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 153c59c56..77f7e7615 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -21,7 +21,6 @@ use std::any::Any;
use std::sync::Arc;
use async_trait::async_trait;
-use parking_lot::RwLock;
use parquet::file::properties::WriterProperties;
use datafusion_common::{Column, DFSchema};
@@ -74,13 +73,13 @@ use crate::prelude::SessionContext;
/// ```
#[derive(Debug, Clone)]
pub struct DataFrame {
- session_state: Arc<RwLock<SessionState>>,
+ session_state: SessionState,
plan: LogicalPlan,
}
impl DataFrame {
/// Create a new Table based on an existing logical plan
- pub fn new(session_state: Arc<RwLock<SessionState>>, plan: LogicalPlan) ->
Self {
+ pub fn new(session_state: SessionState, plan: LogicalPlan) -> Self {
Self {
session_state,
plan,
@@ -88,26 +87,14 @@ impl DataFrame {
}
/// Create a physical plan
- pub async fn create_physical_plan(self) -> Result<Arc<dyn ExecutionPlan>> {
- // this function is copied from SessionContext function of the
- // same name
- let state_cloned = {
- let mut state = self.session_state.write();
- state.execution_props.start_execution();
-
- // We need to clone `state` to release the lock that is not
`Send`. We could
- // make the lock `Send` by using `tokio::sync::Mutex`, but that
would require to
- // propagate async even to the `LogicalPlan` building methods.
- // Cloning `state` here is fine as we then pass it as immutable
`&state`, which
- // means that we avoid write consistency issues as the cloned
version will not
- // be written to. As for eventual modifications that would be
applied to the
- // original state after it has been cloned, they will not be
picked up by the
- // clone but that is okay, as it is equivalent to postponing the
state update
- // by keeping the lock until the end of the function scope.
- state.clone()
- };
+ pub async fn create_physical_plan(mut self) -> Result<Arc<dyn
ExecutionPlan>> {
+ self.create_physical_plan_impl().await
+ }
- state_cloned.create_physical_plan(&self.plan).await
+ /// Temporary pending #4626
+ async fn create_physical_plan_impl(&mut self) -> Result<Arc<dyn
ExecutionPlan>> {
+ self.session_state.execution_props.start_execution();
+ self.session_state.create_physical_plan(&self.plan).await
}
/// Filter the DataFrame by column. Returns a new DataFrame only
containing the
@@ -437,8 +424,7 @@ impl DataFrame {
}
fn task_ctx(&self) -> TaskContext {
- let lock = self.session_state.read();
- TaskContext::from(&*lock)
+ TaskContext::from(&self.session_state)
}
/// Executes this DataFrame and returns a stream over a single partition
@@ -527,8 +513,7 @@ impl DataFrame {
/// Return the optimized logical plan represented by this DataFrame.
pub fn to_logical_plan(self) -> Result<LogicalPlan> {
// Optimize the plan first for better UX
- let state = self.session_state.read().clone();
- state.optimize(&self.plan)
+ self.session_state.optimize(&self.plan)
}
/// Return a DataFrame with the explanation of its plan so far.
@@ -567,9 +552,8 @@ impl DataFrame {
/// # Ok(())
/// # }
/// ```
- pub fn registry(&self) -> Arc<dyn FunctionRegistry> {
- let registry = self.session_state.read().clone();
- Arc::new(registry)
+ pub fn registry(&self) -> &dyn FunctionRegistry {
+ &self.session_state
}
/// Calculate the intersection of two [`DataFrame`]s. The two
[`DataFrame`]s must have exactly the same schema
@@ -620,28 +604,25 @@ impl DataFrame {
}
/// Write a `DataFrame` to a CSV file.
- pub async fn write_csv(self, path: &str) -> Result<()> {
- let state = self.session_state.read().clone();
- let plan = self.create_physical_plan().await?;
- plan_to_csv(&state, plan, path).await
+ pub async fn write_csv(mut self, path: &str) -> Result<()> {
+ let plan = self.create_physical_plan_impl().await?;
+ plan_to_csv(&self.session_state, plan, path).await
}
/// Write a `DataFrame` to a Parquet file.
pub async fn write_parquet(
- self,
+ mut self,
path: &str,
writer_properties: Option<WriterProperties>,
) -> Result<()> {
- let state = self.session_state.read().clone();
- let plan = self.create_physical_plan().await?;
- plan_to_parquet(&state, plan, path, writer_properties).await
+ let plan = self.create_physical_plan_impl().await?;
+ plan_to_parquet(&self.session_state, plan, path,
writer_properties).await
}
/// Executes a query and writes the results to a partitioned JSON file.
- pub async fn write_json(self, path: impl AsRef<str>) -> Result<()> {
- let state = self.session_state.read().clone();
- let plan = self.create_physical_plan().await?;
- plan_to_json(&state, plan, path).await
+ pub async fn write_json(mut self, path: impl AsRef<str>) -> Result<()> {
+ let plan = self.create_physical_plan_impl().await?;
+ plan_to_json(&self.session_state, plan, path).await
}
/// Add an additional column to the DataFrame.
@@ -747,7 +728,7 @@ impl DataFrame {
/// # }
/// ```
pub async fn cache(self) -> Result<DataFrame> {
- let context =
SessionContext::with_state(self.session_state.read().clone());
+ let context = SessionContext::with_state(self.session_state.clone());
let mem_table = MemTable::try_new(
SchemaRef::from(self.schema().clone()),
self.collect_partitioned().await?,
@@ -1029,9 +1010,8 @@ mod tests {
// build query with a UDF using DataFrame API
let df = ctx.table("aggregate_test_100")?;
- let f = df.registry();
-
- let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
+ let expr = df.registry().udf("my_fn")?.call(vec![col("c12")]);
+ let df = df.select(vec![expr])?;
// build query using SQL
let sql_plan =
@@ -1088,7 +1068,7 @@ mod tests {
async fn register_table() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c12"])?;
let ctx = SessionContext::new();
- let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone());
+ let df_impl = DataFrame::new(ctx.state(), df.plan.clone());
// register a dataframe as a table
ctx.register_table("test_table", Arc::new(df_impl.clone()))?;
@@ -1180,7 +1160,7 @@ mod tests {
async fn with_column() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
let ctx = SessionContext::new();
- let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone());
+ let df_impl = DataFrame::new(ctx.state(), df.plan.clone());
let df = df_impl
.filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
diff --git a/datafusion/core/src/execution/context.rs
b/datafusion/core/src/execution/context.rs
index 16f2a29f3..ce85f1821 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -273,7 +273,7 @@ impl SessionContext {
(false, true, Ok(_)) => {
self.deregister_table(&name)?;
let schema = Arc::new(input.schema().as_ref().into());
- let physical = DataFrame::new(self.state.clone(),
input);
+ let physical = DataFrame::new(self.state(), input);
let batches: Vec<_> =
physical.collect_partitioned().await?;
let table = Arc::new(MemTable::try_new(schema,
batches)?);
@@ -286,7 +286,7 @@ impl SessionContext {
)),
(_, _, Err(_)) => {
let schema = Arc::new(input.schema().as_ref().into());
- let physical = DataFrame::new(self.state.clone(),
input);
+ let physical = DataFrame::new(self.state(), input);
let batches: Vec<_> =
physical.collect_partitioned().await?;
let table = Arc::new(MemTable::try_new(schema,
batches)?);
@@ -363,7 +363,8 @@ impl SessionContext {
LogicalPlan::SetVariable(SetVariable {
variable, value, ..
}) => {
- let config_options = &self.state.write().config.config_options;
+ let state = self.state.write();
+ let config_options = &state.config.config_options;
let old_value =
config_options.read().get(&variable).ok_or_else(|| {
@@ -410,6 +411,8 @@ impl SessionContext {
))
}
}
+ drop(state);
+
self.return_empty_dataframe()
}
@@ -475,14 +478,14 @@ impl SessionContext {
}
}
- plan => Ok(DataFrame::new(self.state.clone(), plan)),
+ plan => Ok(DataFrame::new(self.state(), plan)),
}
}
// return an empty dataframe
fn return_empty_dataframe(&self) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::empty(false).build()?;
- Ok(DataFrame::new(self.state.clone(), plan))
+ Ok(DataFrame::new(self.state(), plan))
}
async fn create_external_table(
@@ -661,7 +664,7 @@ impl SessionContext {
/// Creates an empty DataFrame.
pub fn read_empty(&self) -> Result<DataFrame> {
Ok(DataFrame::new(
- self.state.clone(),
+ self.state(),
LogicalPlanBuilder::empty(true).build()?,
))
}
@@ -716,7 +719,7 @@ impl SessionContext {
/// Creates a [`DataFrame`] for reading a custom [`TableProvider`].
pub fn read_table(&self, provider: Arc<dyn TableProvider>) ->
Result<DataFrame> {
Ok(DataFrame::new(
- self.state.clone(),
+ self.state(),
LogicalPlanBuilder::scan(UNNAMED_TABLE,
provider_as_source(provider), None)?
.build()?,
))
@@ -726,7 +729,7 @@ impl SessionContext {
pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> {
let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
Ok(DataFrame::new(
- self.state.clone(),
+ self.state(),
LogicalPlanBuilder::scan(
UNNAMED_TABLE,
provider_as_source(Arc::new(provider)),
@@ -946,7 +949,7 @@ impl SessionContext {
None,
)?
.build()?;
- Ok(DataFrame::new(self.state.clone(), plan))
+ Ok(DataFrame::new(self.state(), plan))
}
/// Return a [`TabelProvider`] for the specified table.