This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new deeaa5632 Minor: Move TableProviderFactories up out of RuntimeEnv and 
into SessionState (#5477)
deeaa5632 is described below

commit deeaa5632ed99a58b91767261570756db736d158
Author: Andrew Lamb <[email protected]>
AuthorDate: Tue Mar 7 23:09:33 2023 +0100

    Minor: Move TableProviderFactories up out of RuntimeEnv and into 
SessionState (#5477)
---
 datafusion/core/src/execution/context.rs     | 75 +++++++++++++++++++++++-----
 datafusion/core/src/execution/runtime_env.rs | 38 +-------------
 datafusion/core/tests/sql/create_drop.rs     | 21 +++++---
 datafusion/proto/src/logical_plan/mod.rs     | 13 +++--
 4 files changed, 84 insertions(+), 63 deletions(-)

diff --git a/datafusion/core/src/execution/context.rs 
b/datafusion/core/src/execution/context.rs
index ce20425a5..0340b4761 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -18,7 +18,11 @@
 //! SessionContext contains methods for registering data sources and executing 
queries
 use crate::{
     catalog::catalog::{CatalogList, MemoryCatalogList},
-    datasource::listing::{ListingOptions, ListingTable},
+    datasource::{
+        datasource::TableProviderFactory,
+        listing::{ListingOptions, ListingTable},
+        listing_table_factory::ListingTableFactory,
+    },
     datasource::{MemTable, ViewTable},
     logical_expr::{PlanType, ToStringifiedPlan},
     optimizer::optimizer::Optimizer,
@@ -278,6 +282,15 @@ impl SessionContext {
         self.session_id.clone()
     }
 
+    /// Return the [`TableProviderFactory`] that is registered for the
+    /// specified file type, if any.
+    pub fn table_factory(
+        &self,
+        file_type: &str,
+    ) -> Option<Arc<dyn TableProviderFactory>> {
+        self.state.read().table_factories().get(file_type).cloned()
+    }
+
     /// Return the `enable_ident_normalization` of this Session
     pub fn enable_ident_normalization(&self) -> bool {
         self.state
@@ -579,16 +592,16 @@ impl SessionContext {
     ) -> Result<Arc<dyn TableProvider>> {
         let state = self.state.read().clone();
         let file_type = cmd.file_type.to_uppercase();
-        let factory = &state
-            .runtime_env
-            .table_factories
-            .get(file_type.as_str())
-            .ok_or_else(|| {
-                DataFusionError::Execution(format!(
-                    "Unable to find factory for {}",
-                    cmd.file_type
-                ))
-            })?;
+        let factory =
+            &state
+                .table_factories
+                .get(file_type.as_str())
+                .ok_or_else(|| {
+                    DataFusionError::Execution(format!(
+                        "Unable to find factory for {}",
+                        cmd.file_type
+                    ))
+                })?;
         let table = (*factory).create(&state, cmd).await?;
         Ok(table)
     }
@@ -1507,6 +1520,14 @@ pub struct SessionState {
     config: SessionConfig,
     /// Execution properties
     execution_props: ExecutionProps,
+    /// TableProviderFactories for different file formats.
+    ///
+    /// Maps strings like "JSON" to an instance of  [`TableProviderFactory`]
+    ///
+    /// This is used to create [`TableProvider`] instances for the
+    /// `CREATE EXTERNAL TABLE ... STORED AS <FORMAT>` for custom file
+    /// formats other than those built into DataFusion
+    table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
     /// Runtime environment
     runtime_env: Arc<RuntimeEnv>,
 }
@@ -1540,6 +1561,15 @@ impl SessionState {
     ) -> Self {
         let session_id = Uuid::new_v4().to_string();
 
+        // Create table_factories for all default formats
+        let mut table_factories: HashMap<String, Arc<dyn 
TableProviderFactory>> =
+            HashMap::new();
+        table_factories.insert("PARQUET".into(), 
Arc::new(ListingTableFactory::new()));
+        table_factories.insert("CSV".into(), 
Arc::new(ListingTableFactory::new()));
+        table_factories.insert("JSON".into(), 
Arc::new(ListingTableFactory::new()));
+        table_factories.insert("NDJSON".into(), 
Arc::new(ListingTableFactory::new()));
+        table_factories.insert("AVRO".into(), 
Arc::new(ListingTableFactory::new()));
+
         if config.create_default_catalog_and_schema() {
             let default_catalog = MemoryCatalogProvider::new();
 
@@ -1550,7 +1580,12 @@ impl SessionState {
                 )
                 .expect("memory catalog provider can register schema");
 
-            Self::register_default_schema(&config, &runtime, &default_catalog);
+            Self::register_default_schema(
+                &config,
+                &table_factories,
+                &runtime,
+                &default_catalog,
+            );
 
             catalog_list.register_catalog(
                 config.config_options().catalog.default_catalog.clone(),
@@ -1619,11 +1654,13 @@ impl SessionState {
             config,
             execution_props: ExecutionProps::new(),
             runtime_env: runtime,
+            table_factories,
         }
     }
 
     fn register_default_schema(
         config: &SessionConfig,
+        table_factories: &HashMap<String, Arc<dyn TableProviderFactory>>,
         runtime: &Arc<RuntimeEnv>,
         default_catalog: &MemoryCatalogProvider,
     ) {
@@ -1650,7 +1687,7 @@ impl SessionState {
             Ok(store) => store,
             _ => return,
         };
-        let factory = match runtime.table_factories.get(format.as_str()) {
+        let factory = match table_factories.get(format.as_str()) {
             Some(factory) => factory,
             _ => return,
         };
@@ -1756,6 +1793,18 @@ impl SessionState {
         self
     }
 
+    /// Get the table factories
+    pub fn table_factories(&self) -> &HashMap<String, Arc<dyn 
TableProviderFactory>> {
+        &self.table_factories
+    }
+
+    /// Get the table factories, mutably
+    pub fn table_factories_mut(
+        &mut self,
+    ) -> &mut HashMap<String, Arc<dyn TableProviderFactory>> {
+        &mut self.table_factories
+    }
+
     /// Convert a SQL string into an AST Statement
     pub fn sql_to_statement(
         &self,
diff --git a/datafusion/core/src/execution/runtime_env.rs 
b/datafusion/core/src/execution/runtime_env.rs
index d559e7c7f..1a738cdd3 100644
--- a/datafusion/core/src/execution/runtime_env.rs
+++ b/datafusion/core/src/execution/runtime_env.rs
@@ -22,10 +22,7 @@ use crate::{
     error::Result,
     execution::disk_manager::{DiskManager, DiskManagerConfig},
 };
-use std::collections::HashMap;
 
-use crate::datasource::datasource::TableProviderFactory;
-use crate::datasource::listing_table_factory::ListingTableFactory;
 use crate::datasource::object_store::ObjectStoreRegistry;
 use crate::execution::memory_pool::{GreedyMemoryPool, MemoryPool, 
UnboundedMemoryPool};
 use datafusion_common::DataFusionError;
@@ -44,8 +41,6 @@ pub struct RuntimeEnv {
     pub disk_manager: Arc<DiskManager>,
     /// Object Store Registry
     pub object_store_registry: Arc<ObjectStoreRegistry>,
-    /// TableProviderFactories
-    pub table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
 }
 
 impl Debug for RuntimeEnv {
@@ -61,7 +56,6 @@ impl RuntimeEnv {
             memory_pool,
             disk_manager,
             object_store_registry,
-            table_factories,
         } = config;
 
         let memory_pool =
@@ -71,7 +65,6 @@ impl RuntimeEnv {
             memory_pool,
             disk_manager: DiskManager::try_new(disk_manager)?,
             object_store_registry,
-            table_factories,
         })
     }
 
@@ -94,14 +87,6 @@ impl RuntimeEnv {
             .register_store(scheme, host, object_store)
     }
 
-    /// Registers TableFactories
-    pub fn register_table_factories(
-        &mut self,
-        table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
-    ) {
-        self.table_factories.extend(table_factories)
-    }
-
     /// Retrieves a `ObjectStore` instance for a url by consulting the
+    /// registry. See [`ObjectStoreRegistry::get_by_url`] for more
     /// details.
@@ -129,24 +114,12 @@ pub struct RuntimeConfig {
     pub memory_pool: Option<Arc<dyn MemoryPool>>,
     /// ObjectStoreRegistry to get object store based on url
     pub object_store_registry: Arc<ObjectStoreRegistry>,
-    /// Custom table factories for things like deltalake that are not part of 
core datafusion
-    pub table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
 }
 
 impl RuntimeConfig {
     /// New with default values
     pub fn new() -> Self {
-        let mut table_factories: HashMap<String, Arc<dyn 
TableProviderFactory>> =
-            HashMap::new();
-        table_factories.insert("PARQUET".into(), 
Arc::new(ListingTableFactory::new()));
-        table_factories.insert("CSV".into(), 
Arc::new(ListingTableFactory::new()));
-        table_factories.insert("JSON".into(), 
Arc::new(ListingTableFactory::new()));
-        table_factories.insert("NDJSON".into(), 
Arc::new(ListingTableFactory::new()));
-        table_factories.insert("AVRO".into(), 
Arc::new(ListingTableFactory::new()));
-        Self {
-            table_factories,
-            ..Default::default()
-        }
+        Default::default()
     }
 
     /// Customize disk manager
@@ -170,15 +143,6 @@ impl RuntimeConfig {
         self
     }
 
-    /// Customize object store registry
-    pub fn with_table_factories(
-        mut self,
-        table_factories: HashMap<String, Arc<dyn TableProviderFactory>>,
-    ) -> Self {
-        self.table_factories = table_factories;
-        self
-    }
-
     /// Specify the total memory to use while running the DataFusion
     /// plan to `max_memory * memory_fraction` in bytes.
     ///
diff --git a/datafusion/core/tests/sql/create_drop.rs 
b/datafusion/core/tests/sql/create_drop.rs
index 9326a9a5c..5bb8abb14 100644
--- a/datafusion/core/tests/sql/create_drop.rs
+++ b/datafusion/core/tests/sql/create_drop.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use datafusion::execution::context::SessionState;
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
 use datafusion::test_util::TestTableFactory;
 
@@ -106,12 +107,14 @@ async fn sql_create_table_exists() -> Result<()> {
 
 #[tokio::test]
 async fn create_custom_table() -> Result<()> {
-    let mut cfg = RuntimeConfig::new();
-    cfg.table_factories
-        .insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {}));
+    let cfg = RuntimeConfig::new();
     let env = RuntimeEnv::new(cfg).unwrap();
     let ses = SessionConfig::new();
-    let ctx = SessionContext::with_config_rt(ses, Arc::new(env));
+    let mut state = SessionState::with_config_rt(ses, Arc::new(env));
+    state
+        .table_factories_mut()
+        .insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {}));
+    let ctx = SessionContext::with_state(state);
 
     let sql = "CREATE EXTERNAL TABLE dt STORED AS DELTATABLE LOCATION 
's3://bucket/schema/table';";
     ctx.sql(sql).await.unwrap();
@@ -126,12 +129,14 @@ async fn create_custom_table() -> Result<()> {
 
 #[tokio::test]
 async fn create_external_table_with_ddl() -> Result<()> {
-    let mut cfg = RuntimeConfig::new();
-    cfg.table_factories
-        .insert("MOCKTABLE".to_string(), Arc::new(TestTableFactory {}));
+    let cfg = RuntimeConfig::new();
     let env = RuntimeEnv::new(cfg).unwrap();
     let ses = SessionConfig::new();
-    let ctx = SessionContext::with_config_rt(ses, Arc::new(env));
+    let mut state = SessionState::with_config_rt(ses, Arc::new(env));
+    state
+        .table_factories_mut()
+        .insert("MOCKTABLE".to_string(), Arc::new(TestTableFactory {}));
+    let ctx = SessionContext::with_state(state);
 
     let sql = "CREATE EXTERNAL TABLE dt (a_id integer, a_str string, a_bool 
boolean) STORED AS MOCKTABLE LOCATION 'mockprotocol://path/to/table';";
     ctx.sql(sql).await.unwrap();
diff --git a/datafusion/proto/src/logical_plan/mod.rs 
b/datafusion/proto/src/logical_plan/mod.rs
index 2fd7ed725..706128259 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -496,10 +496,9 @@ impl AsLogicalPlan for LogicalPlanNode {
                 };
 
                 let file_type = create_extern_table.file_type.as_str();
-                let env = ctx.runtime_env();
-                if !env.table_factories.contains_key(file_type) {
+                if ctx.table_factory(file_type).is_none() {
                     Err(DataFusionError::Internal(format!(
-                        "No TableProvider for file type: {file_type}"
+                        "No TableProviderFactory for file type: {file_type}"
                     )))?
                 }
 
@@ -1377,6 +1376,7 @@ mod roundtrip_tests {
     };
     use datafusion::datasource::datasource::TableProviderFactory;
     use datafusion::datasource::TableProvider;
+    use datafusion::execution::context::SessionState;
     use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use datafusion::physical_plan::functions::make_scalar_function;
     use datafusion::prelude::{
@@ -1523,10 +1523,13 @@ mod roundtrip_tests {
         let mut table_factories: HashMap<String, Arc<dyn 
TableProviderFactory>> =
             HashMap::new();
         table_factories.insert("TESTTABLE".to_string(), 
Arc::new(TestTableFactory {}));
-        let cfg = RuntimeConfig::new().with_table_factories(table_factories);
+        let cfg = RuntimeConfig::new();
         let env = RuntimeEnv::new(cfg).unwrap();
         let ses = SessionConfig::new();
-        let ctx = SessionContext::with_config_rt(ses, Arc::new(env));
+        let mut state = SessionState::with_config_rt(ses, Arc::new(env));
+        // replace factories
+        *state.table_factories_mut() = table_factories;
+        let ctx = SessionContext::with_state(state);
 
         let sql = "CREATE EXTERNAL TABLE t STORED AS testtable LOCATION 
's3://bucket/schema/table';";
         ctx.sql(sql).await.unwrap();

Reply via email to