This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 42b3a6c92 DataFrame owned SessionState (#4633)
42b3a6c92 is described below

commit 42b3a6c9224e06adf2170aafe2cfeadca6b3f2bb
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Sat Dec 17 10:25:47 2022 +0000

    DataFrame owned SessionState (#4633)
    
    * DataFrame owned SessionState (#4617)
    
    * Fix deadlock
    
    * Fix execution time
---
 datafusion-examples/examples/custom_datasource.rs |  2 +-
 datafusion/core/src/dataframe.rs                  | 74 +++++++++--------------
 datafusion/core/src/execution/context.rs          | 21 ++++---
 3 files changed, 40 insertions(+), 57 deletions(-)

diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs
index db4fed494..68e8f5a54 100644
--- a/datafusion-examples/examples/custom_datasource.rs
+++ b/datafusion-examples/examples/custom_datasource.rs
@@ -69,7 +69,7 @@ async fn search_accounts(
     )?
     .build()?;
 
-    let mut dataframe = DataFrame::new(ctx.state, logical_plan)
+    let mut dataframe = DataFrame::new(ctx.state(), logical_plan)
         .select_columns(&["id", "bank_account"])?;
 
     if let Some(f) = filter {
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 153c59c56..77f7e7615 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -21,7 +21,6 @@ use std::any::Any;
 use std::sync::Arc;
 
 use async_trait::async_trait;
-use parking_lot::RwLock;
 use parquet::file::properties::WriterProperties;
 
 use datafusion_common::{Column, DFSchema};
@@ -74,13 +73,13 @@ use crate::prelude::SessionContext;
 /// ```
 #[derive(Debug, Clone)]
 pub struct DataFrame {
-    session_state: Arc<RwLock<SessionState>>,
+    session_state: SessionState,
     plan: LogicalPlan,
 }
 
 impl DataFrame {
     /// Create a new Table based on an existing logical plan
-    pub fn new(session_state: Arc<RwLock<SessionState>>, plan: LogicalPlan) -> Self {
+    pub fn new(session_state: SessionState, plan: LogicalPlan) -> Self {
         Self {
             session_state,
             plan,
@@ -88,26 +87,14 @@ impl DataFrame {
     }
 
     /// Create a physical plan
-    pub async fn create_physical_plan(self) -> Result<Arc<dyn ExecutionPlan>> {
-        // this function is copied from SessionContext function of the
-        // same name
-        let state_cloned = {
-            let mut state = self.session_state.write();
-            state.execution_props.start_execution();
-
-            // We need to clone `state` to release the lock that is not `Send`. We could
-            // make the lock `Send` by using `tokio::sync::Mutex`, but that would require to
-            // propagate async even to the `LogicalPlan` building methods.
-            // Cloning `state` here is fine as we then pass it as immutable `&state`, which
-            // means that we avoid write consistency issues as the cloned version will not
-            // be written to. As for eventual modifications that would be applied to the
-            // original state after it has been cloned, they will not be picked up by the
-            // clone but that is okay, as it is equivalent to postponing the state update
-            // by keeping the lock until the end of the function scope.
-            state.clone()
-        };
+    pub async fn create_physical_plan(mut self) -> Result<Arc<dyn ExecutionPlan>> {
+        self.create_physical_plan_impl().await
+    }
 
-        state_cloned.create_physical_plan(&self.plan).await
+    /// Temporary pending #4626
+    async fn create_physical_plan_impl(&mut self) -> Result<Arc<dyn ExecutionPlan>> {
+        self.session_state.execution_props.start_execution();
+        self.session_state.create_physical_plan(&self.plan).await
     }
 
     /// Filter the DataFrame by column. Returns a new DataFrame only containing the
@@ -437,8 +424,7 @@ impl DataFrame {
     }
 
     fn task_ctx(&self) -> TaskContext {
-        let lock = self.session_state.read();
-        TaskContext::from(&*lock)
+        TaskContext::from(&self.session_state)
     }
 
     /// Executes this DataFrame and returns a stream over a single partition
@@ -527,8 +513,7 @@ impl DataFrame {
     /// Return the optimized logical plan represented by this DataFrame.
     pub fn to_logical_plan(self) -> Result<LogicalPlan> {
         // Optimize the plan first for better UX
-        let state = self.session_state.read().clone();
-        state.optimize(&self.plan)
+        self.session_state.optimize(&self.plan)
     }
 
     /// Return a DataFrame with the explanation of its plan so far.
@@ -567,9 +552,8 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn registry(&self) -> Arc<dyn FunctionRegistry> {
-        let registry = self.session_state.read().clone();
-        Arc::new(registry)
+    pub fn registry(&self) -> &dyn FunctionRegistry {
+        &self.session_state
     }
 
     /// Calculate the intersection of two [`DataFrame`]s.  The two [`DataFrame`]s must have exactly the same schema
@@ -620,28 +604,25 @@ impl DataFrame {
     }
 
     /// Write a `DataFrame` to a CSV file.
-    pub async fn write_csv(self, path: &str) -> Result<()> {
-        let state = self.session_state.read().clone();
-        let plan = self.create_physical_plan().await?;
-        plan_to_csv(&state, plan, path).await
+    pub async fn write_csv(mut self, path: &str) -> Result<()> {
+        let plan = self.create_physical_plan_impl().await?;
+        plan_to_csv(&self.session_state, plan, path).await
     }
 
     /// Write a `DataFrame` to a Parquet file.
     pub async fn write_parquet(
-        self,
+        mut self,
         path: &str,
         writer_properties: Option<WriterProperties>,
     ) -> Result<()> {
-        let state = self.session_state.read().clone();
-        let plan = self.create_physical_plan().await?;
-        plan_to_parquet(&state, plan, path, writer_properties).await
+        let plan = self.create_physical_plan_impl().await?;
+        plan_to_parquet(&self.session_state, plan, path, writer_properties).await
     }
 
     /// Executes a query and writes the results to a partitioned JSON file.
-    pub async fn write_json(self, path: impl AsRef<str>) -> Result<()> {
-        let state = self.session_state.read().clone();
-        let plan = self.create_physical_plan().await?;
-        plan_to_json(&state, plan, path).await
+    pub async fn write_json(mut self, path: impl AsRef<str>) -> Result<()> {
+        let plan = self.create_physical_plan_impl().await?;
+        plan_to_json(&self.session_state, plan, path).await
     }
 
     /// Add an additional column to the DataFrame.
@@ -747,7 +728,7 @@ impl DataFrame {
     /// # }
     /// ```
     pub async fn cache(self) -> Result<DataFrame> {
-        let context = SessionContext::with_state(self.session_state.read().clone());
+        let context = SessionContext::with_state(self.session_state.clone());
         let mem_table = MemTable::try_new(
             SchemaRef::from(self.schema().clone()),
             self.collect_partitioned().await?,
@@ -1029,9 +1010,8 @@ mod tests {
         // build query with a UDF using DataFrame API
         let df = ctx.table("aggregate_test_100")?;
 
-        let f = df.registry();
-
-        let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
+        let expr = df.registry().udf("my_fn")?.call(vec![col("c12")]);
+        let df = df.select(vec![expr])?;
 
         // build query using SQL
         let sql_plan =
@@ -1088,7 +1068,7 @@ mod tests {
     async fn register_table() -> Result<()> {
         let df = test_table().await?.select_columns(&["c1", "c12"])?;
         let ctx = SessionContext::new();
-        let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone());
+        let df_impl = DataFrame::new(ctx.state(), df.plan.clone());
 
         // register a dataframe as a table
         ctx.register_table("test_table", Arc::new(df_impl.clone()))?;
@@ -1180,7 +1160,7 @@ mod tests {
     async fn with_column() -> Result<()> {
         let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
         let ctx = SessionContext::new();
-        let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone());
+        let df_impl = DataFrame::new(ctx.state(), df.plan.clone());
 
         let df = df_impl
             .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 16f2a29f3..ce85f1821 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -273,7 +273,7 @@ impl SessionContext {
                     (false, true, Ok(_)) => {
                         self.deregister_table(&name)?;
                         let schema = Arc::new(input.schema().as_ref().into());
-                        let physical = DataFrame::new(self.state.clone(), input);
+                        let physical = DataFrame::new(self.state(), input);
 
                         let batches: Vec<_> = physical.collect_partitioned().await?;
                         let table = Arc::new(MemTable::try_new(schema, batches)?);
@@ -286,7 +286,7 @@ impl SessionContext {
                     )),
                     (_, _, Err(_)) => {
                         let schema = Arc::new(input.schema().as_ref().into());
-                        let physical = DataFrame::new(self.state.clone(), input);
+                        let physical = DataFrame::new(self.state(), input);
 
                         let batches: Vec<_> = physical.collect_partitioned().await?;
                         let table = Arc::new(MemTable::try_new(schema, batches)?);
@@ -363,7 +363,8 @@ impl SessionContext {
             LogicalPlan::SetVariable(SetVariable {
                 variable, value, ..
             }) => {
-                let config_options = &self.state.write().config.config_options;
+                let state = self.state.write();
+                let config_options = &state.config.config_options;
 
                 let old_value =
                     config_options.read().get(&variable).ok_or_else(|| {
@@ -410,6 +411,8 @@ impl SessionContext {
                         ))
                     }
                 }
+                drop(state);
+
                 self.return_empty_dataframe()
             }
 
@@ -475,14 +478,14 @@ impl SessionContext {
                 }
             }
 
-            plan => Ok(DataFrame::new(self.state.clone(), plan)),
+            plan => Ok(DataFrame::new(self.state(), plan)),
         }
     }
 
     // return an empty dataframe
     fn return_empty_dataframe(&self) -> Result<DataFrame> {
         let plan = LogicalPlanBuilder::empty(false).build()?;
-        Ok(DataFrame::new(self.state.clone(), plan))
+        Ok(DataFrame::new(self.state(), plan))
     }
 
     async fn create_external_table(
@@ -661,7 +664,7 @@ impl SessionContext {
     /// Creates an empty DataFrame.
     pub fn read_empty(&self) -> Result<DataFrame> {
         Ok(DataFrame::new(
-            self.state.clone(),
+            self.state(),
             LogicalPlanBuilder::empty(true).build()?,
         ))
     }
@@ -716,7 +719,7 @@ impl SessionContext {
     /// Creates a [`DataFrame`] for reading a custom [`TableProvider`].
     pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> {
         Ok(DataFrame::new(
-            self.state.clone(),
+            self.state(),
             LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
                 .build()?,
         ))
@@ -726,7 +729,7 @@ impl SessionContext {
     pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> {
         let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
         Ok(DataFrame::new(
-            self.state.clone(),
+            self.state(),
             LogicalPlanBuilder::scan(
                 UNNAMED_TABLE,
                 provider_as_source(Arc::new(provider)),
@@ -946,7 +949,7 @@ impl SessionContext {
             None,
         )?
         .build()?;
-        Ok(DataFrame::new(self.state.clone(), plan))
+        Ok(DataFrame::new(self.state(), plan))
     }
 
     /// Return a [`TabelProvider`] for the specified table.

Reply via email to