This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 931cabc  CRUD Schema support for `BaseSessionContext` (#392)
931cabc is described below

commit 931cabc629a28bb47c030e7e17de59a19feb38dd
Author: Jeremy Dyer <[email protected]>
AuthorDate: Tue May 30 14:51:27 2023 -0400

    CRUD Schema support for `BaseSessionContext` (#392)
    
    * checkpoint commit
    
    * Introduce BaseSessionContext abstract class
    
    * Introduce abstract methods for CRUD schema operations
    
    * Clean up schema.rs file
---
 .gitignore             |   3 +
 datafusion/context.py  |  52 ++++++++++++
 datafusion/cudf.py     |  42 +++++++++-
 datafusion/pandas.py   |  32 ++++++-
 datafusion/polars.py   |  22 ++++-
 src/common.rs          |   7 ++
 src/common/function.rs |  55 ++++++++++++
 src/common/schema.rs   | 222 +++++++++++++++++++++++++++++++++++++++++++++++++
 8 files changed, 428 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index 365b89d..0030b90 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,6 +16,9 @@ dist
 # C extensions
 *.so
 
+# Python dist
+dist
+
 # pyenv
 #   For a library or package, you might want to ignore these files since the code is
 #   intended to run in multiple environments; otherwise, check them in:
diff --git a/datafusion/context.py b/datafusion/context.py
index aa9c9a8..a364621 100644
--- a/datafusion/context.py
+++ b/datafusion/context.py
@@ -16,6 +16,9 @@
 # under the License.
 
 from abc import ABC, abstractmethod
+from typing import Dict
+
+from datafusion.common import SqlSchema
 
 
 class BaseSessionContext(ABC):
@@ -24,6 +27,55 @@ class BaseSessionContext(ABC):
     shared amongst implementations using DataFusion as their SQL Parser/Engine
     """
 
+    DEFAULT_CATALOG_NAME = "root"
+    DEFAULT_SCHEMA_NAME = "datafusion"
+
+    @abstractmethod
+    def create_schema(
+        self,
+        schema_name: str,
+        **kwargs,
+    ):
+        """
+        Creates/Registers a logical container that holds database
+        objects such as tables, views, indexes, and other
+        related objects. It provides a way to group related database
+        objects together. A schema can be owned by a database
+        user and can be used to separate objects in different
+        logical groups for easy management.
+        """
+        pass
+
+    @abstractmethod
+    def update_schema(
+        self,
+        schema_name: str,
+        new_schema: SqlSchema,
+        **kwargs,
+    ):
+        """
+        Updates an existing schema in the SessionContext
+        """
+        pass
+
+    @abstractmethod
+    def drop_schema(
+        self,
+        schema_name: str,
+        **kwargs,
+    ):
+        """
+        Drops the specified Schema, based on name, from the current context
+        """
+        pass
+
+    @abstractmethod
+    def show_schemas(self, **kwargs) -> Dict[str, SqlSchema]:
+        """
+        Return all schemas in the current SessionContext impl.
+        """
+        pass
+
     @abstractmethod
     def register_table(
         self,
diff --git a/datafusion/cudf.py b/datafusion/cudf.py
index 594e5ef..e39daea 100644
--- a/datafusion/cudf.py
+++ b/datafusion/cudf.py
@@ -15,16 +15,36 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import logging
 import cudf
-import datafusion
 from datafusion.context import BaseSessionContext
 from datafusion.expr import Projection, TableScan, Column
 
+from datafusion.common import SqlSchema
+
+logger = logging.getLogger(__name__)
+
 
 class SessionContext(BaseSessionContext):
-    def __init__(self):
-        self.datafusion_ctx = datafusion.SessionContext()
-        self.parquet_tables = {}
+    def __init__(self, context, logging_level=logging.INFO):
+        """
+        Create a new Session.
+        """
+        # Cudf requires a provided context
+        self.context = context
+
+        # Set the logging level for this SQL context
+        logging.basicConfig(level=logging_level)
+
+        # Name of the root catalog
+        self.catalog_name = self.DEFAULT_CATALOG_NAME
+        # Name of the root schema
+        self.schema_name = self.DEFAULT_SCHEMA_NAME
+        # Add the schema to the context
+        sch = SqlSchema(self.schema_name)
+        self.schemas = {}
+        self.schemas[self.schema_name] = sch
+        self.context.register_schema(self.schema_name, sch)
 
     def to_cudf_expr(self, expr):
         # get Python wrapper for logical expression
@@ -52,6 +72,20 @@ class SessionContext(BaseSessionContext):
                 "unsupported logical operator: {}".format(type(node))
             )
 
+    def create_schema(self, schema_name: str, **kwargs):
+        logger.debug(f"Creating schema: {schema_name}")
+        self.schemas[schema_name] = SqlSchema(schema_name)
+        self.context.register_schema(schema_name, SqlSchema(schema_name))
+
+    def update_schema(self, schema_name: str, new_schema: SqlSchema, **kwargs):
+        self.schemas[schema_name] = new_schema
+
+    def drop_schema(self, schema_name, **kwargs):
+        del self.schemas[schema_name]
+
+    def show_schemas(self, **kwargs):
+        return self.schemas
+
     def register_table(self, name, path, **kwargs):
         self.parquet_tables[name] = path
         self.datafusion_ctx.register_parquet(name, path)
diff --git a/datafusion/pandas.py b/datafusion/pandas.py
index 935d961..c2da83f 100644
--- a/datafusion/pandas.py
+++ b/datafusion/pandas.py
@@ -15,17 +15,33 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import logging
 import pandas as pd
 import datafusion
+from datafusion.common import SqlSchema
 from datafusion.context import BaseSessionContext
 from datafusion.expr import Projection, TableScan, Column
 
+logger = logging.getLogger(__name__)
+
 
 class SessionContext(BaseSessionContext):
-    def __init__(self):
+    def __init__(self, logging_level=logging.INFO):
         self.datafusion_ctx = datafusion.SessionContext()
         self.parquet_tables = {}
 
+        # Set the logging level for this SQL context
+        logging.basicConfig(level=logging_level)
+
+        # Name of the root catalog
+        self.catalog_name = self.DEFAULT_CATALOG_NAME
+        # Name of the root schema
+        self.schema_name = self.DEFAULT_SCHEMA_NAME
+        # Add the schema to the context
+        sch = SqlSchema(self.schema_name)
+        self.schemas[self.schema_name] = sch
+        self.context.register_schema(self.schema_name, sch)
+
     def to_pandas_expr(self, expr):
         # get Python wrapper for logical expression
         expr = expr.to_variant()
@@ -52,6 +68,20 @@ class SessionContext(BaseSessionContext):
                 "unsupported logical operator: {}".format(type(node))
             )
 
+    def create_schema(self, schema_name: str, **kwargs):
+        logger.debug(f"Creating schema: {schema_name}")
+        self.schemas[schema_name] = SqlSchema(schema_name)
+        self.context.register_schema(schema_name, SqlSchema(schema_name))
+
+    def update_schema(self, schema_name: str, new_schema: SqlSchema, **kwargs):
+        self.schemas[schema_name] = new_schema
+
+    def drop_schema(self, schema_name, **kwargs):
+        del self.schemas[schema_name]
+
+    def show_schemas(self, **kwargs):
+        return self.schemas
+
     def register_table(self, name, path, **kwargs):
         self.parquet_tables[name] = path
         self.datafusion_ctx.register_parquet(name, path)
diff --git a/datafusion/polars.py b/datafusion/polars.py
index bbc1fd7..e4eb966 100644
--- a/datafusion/polars.py
+++ b/datafusion/polars.py
@@ -14,16 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+import logging
 import polars
 import datafusion
 from datafusion.context import BaseSessionContext
 from datafusion.expr import Projection, TableScan, Aggregate
 from datafusion.expr import Column, AggregateFunction
 
+from datafusion.common import SqlSchema
+
+logger = logging.getLogger(__name__)
+
 
 class SessionContext(BaseSessionContext):
-    def __init__(self):
+    def __init__(self, logging_level=logging.INFO):
         self.datafusion_ctx = datafusion.SessionContext()
         self.parquet_tables = {}
 
@@ -75,6 +79,20 @@ class SessionContext(BaseSessionContext):
                 "unsupported logical operator: {}".format(type(node))
             )
 
+    def create_schema(self, schema_name: str, **kwargs):
+        logger.debug(f"Creating schema: {schema_name}")
+        self.schemas[schema_name] = SqlSchema(schema_name)
+        self.context.register_schema(schema_name, SqlSchema(schema_name))
+
+    def update_schema(self, schema_name: str, new_schema: SqlSchema, **kwargs):
+        self.schemas[schema_name] = new_schema
+
+    def drop_schema(self, schema_name, **kwargs):
+        del self.schemas[schema_name]
+
+    def show_schemas(self, **kwargs):
+        return self.schemas
+
     def register_table(self, name, path, **kwargs):
         self.parquet_tables[name] = path
         self.datafusion_ctx.register_parquet(name, path)
diff --git a/src/common.rs b/src/common.rs
index 8a8e2ad..4552317 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -20,6 +20,8 @@ use pyo3::prelude::*;
 pub mod data_type;
 pub mod df_field;
 pub mod df_schema;
+pub mod function;
+pub mod schema;
 
 /// Initializes the `common` module to match the pattern of `datafusion-common` https://docs.rs/datafusion-common/18.0.0/datafusion_common/index.html
 pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
@@ -29,5 +31,10 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
     m.add_class::<data_type::DataTypeMap>()?;
     m.add_class::<data_type::PythonType>()?;
     m.add_class::<data_type::SqlType>()?;
+    m.add_class::<schema::SqlTable>()?;
+    m.add_class::<schema::SqlSchema>()?;
+    m.add_class::<schema::SqlView>()?;
+    m.add_class::<schema::SqlStatistics>()?;
+    m.add_class::<function::SqlFunction>()?;
     Ok(())
 }
diff --git a/src/common/function.rs b/src/common/function.rs
new file mode 100644
index 0000000..a8d752f
--- /dev/null
+++ b/src/common/function.rs
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+
+use datafusion::arrow::datatypes::DataType;
+use pyo3::prelude::*;
+
+use super::data_type::PyDataType;
+
+#[pyclass(name = "SqlFunction", module = "datafusion.common", subclass)]
+#[derive(Debug, Clone)]
+pub struct SqlFunction {
+    pub name: String,
+    pub return_types: HashMap<Vec<DataType>, DataType>,
+    pub aggregation: bool,
+}
+
+impl SqlFunction {
+    pub fn new(
+        function_name: String,
+        input_types: Vec<PyDataType>,
+        return_type: PyDataType,
+        aggregation_bool: bool,
+    ) -> Self {
+        let mut func = Self {
+            name: function_name,
+            return_types: HashMap::new(),
+            aggregation: aggregation_bool,
+        };
+        func.add_type_mapping(input_types, return_type);
+        func
+    }
+
+    pub fn add_type_mapping(&mut self, input_types: Vec<PyDataType>, return_type: PyDataType) {
+        self.return_types.insert(
+            input_types.iter().map(|t| t.clone().into()).collect(),
+            return_type.into(),
+        );
+    }
+}
diff --git a/src/common/schema.rs b/src/common/schema.rs
new file mode 100644
index 0000000..3043193
--- /dev/null
+++ b/src/common/schema.rs
@@ -0,0 +1,222 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource};
+use pyo3::prelude::*;
+
+use datafusion_optimizer::utils::split_conjunction;
+
+use super::{data_type::DataTypeMap, function::SqlFunction};
+
+#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)]
+#[derive(Debug, Clone)]
+pub struct SqlSchema {
+    #[pyo3(get, set)]
+    pub name: String,
+    #[pyo3(get, set)]
+    pub tables: Vec<SqlTable>,
+    #[pyo3(get, set)]
+    pub views: Vec<SqlView>,
+    #[pyo3(get, set)]
+    pub functions: Vec<SqlFunction>,
+}
+
+#[pyclass(name = "SqlTable", module = "datafusion.common", subclass)]
+#[derive(Debug, Clone)]
+pub struct SqlTable {
+    #[pyo3(get, set)]
+    pub name: String,
+    #[pyo3(get, set)]
+    pub columns: Vec<(String, DataTypeMap)>,
+    #[pyo3(get, set)]
+    pub primary_key: Option<String>,
+    #[pyo3(get, set)]
+    pub foreign_keys: Vec<String>,
+    #[pyo3(get, set)]
+    pub indexes: Vec<String>,
+    #[pyo3(get, set)]
+    pub constraints: Vec<String>,
+    #[pyo3(get, set)]
+    pub statistics: SqlStatistics,
+    #[pyo3(get, set)]
+    pub filepath: Option<String>,
+}
+
+#[pymethods]
+impl SqlTable {
+    #[new]
+    pub fn new(
+        _schema_name: String,
+        table_name: String,
+        columns: Vec<(String, DataTypeMap)>,
+        row_count: f64,
+        filepath: Option<String>,
+    ) -> Self {
+        Self {
+            name: table_name,
+            columns,
+            primary_key: None,
+            foreign_keys: Vec::new(),
+            indexes: Vec::new(),
+            constraints: Vec::new(),
+            statistics: SqlStatistics::new(row_count),
+            filepath,
+        }
+    }
+}
+
+#[pyclass(name = "SqlView", module = "datafusion.common", subclass)]
+#[derive(Debug, Clone)]
+pub struct SqlView {
+    #[pyo3(get, set)]
+    pub name: String,
+    #[pyo3(get, set)]
+    pub definition: String, // SQL code that defines the view
+}
+
+#[pymethods]
+impl SqlSchema {
+    #[new]
+    pub fn new(schema_name: &str) -> Self {
+        Self {
+            name: schema_name.to_owned(),
+            tables: Vec::new(),
+            views: Vec::new(),
+            functions: Vec::new(),
+        }
+    }
+
+    pub fn table_by_name(&self, table_name: &str) -> Option<SqlTable> {
+        for tbl in &self.tables {
+            if tbl.name.eq(table_name) {
+                return Some(tbl.clone());
+            }
+        }
+        None
+    }
+
+    pub fn add_table(&mut self, table: SqlTable) {
+        self.tables.push(table);
+    }
+}
+
+/// SqlTable wrapper that is compatible with DataFusion logical query plans
+pub struct SqlTableSource {
+    schema: SchemaRef,
+    statistics: Option<SqlStatistics>,
+    filepath: Option<String>,
+}
+
+impl SqlTableSource {
+    /// Initialize a new `SqlTableSource` from a schema
+    pub fn new(
+        schema: SchemaRef,
+        statistics: Option<SqlStatistics>,
+        filepath: Option<String>,
+    ) -> Self {
+        Self {
+            schema,
+            statistics,
+            filepath,
+        }
+    }
+
+    /// Access optional statistics associated with this table source
+    pub fn statistics(&self) -> Option<&SqlStatistics> {
+        self.statistics.as_ref()
+    }
+
+    /// Access optional filepath associated with this table source
+    #[allow(dead_code)]
+    pub fn filepath(&self) -> Option<&String> {
+        self.filepath.as_ref()
+    }
+}
+
+/// Implement TableSource, used in the logical query plan and in logical query optimizations
+impl TableSource for SqlTableSource {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn supports_filter_pushdown(
+        &self,
+        filter: &Expr,
+    ) -> datafusion_common::Result<TableProviderFilterPushDown> {
+        let filters = split_conjunction(filter);
+        if filters.iter().all(|f| is_supported_push_down_expr(f)) {
+            // Push down filters to the tablescan operation if all are supported
+            Ok(TableProviderFilterPushDown::Exact)
+        } else if filters.iter().any(|f| is_supported_push_down_expr(f)) {
+            // Partially apply the filter in the TableScan but retain
+            // the Filter operator in the plan as well
+            Ok(TableProviderFilterPushDown::Inexact)
+        } else {
+            Ok(TableProviderFilterPushDown::Unsupported)
+        }
+    }
+
+    fn table_type(&self) -> datafusion_expr::TableType {
+        datafusion_expr::TableType::Base
+    }
+
+    #[allow(deprecated)]
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> datafusion_common::Result<Vec<TableProviderFilterPushDown>> {
+        filters
+            .iter()
+            .map(|f| self.supports_filter_pushdown(f))
+            .collect()
+    }
+
+    fn get_logical_plan(&self) -> Option<&datafusion_expr::LogicalPlan> {
+        None
+    }
+}
+
+fn is_supported_push_down_expr(_expr: &Expr) -> bool {
+    // For now we support all kinds of expr's at this level
+    true
+}
+
+#[pyclass(name = "SqlStatistics", module = "datafusion.common", subclass)]
+#[derive(Debug, Clone)]
+pub struct SqlStatistics {
+    row_count: f64,
+}
+
+#[pymethods]
+impl SqlStatistics {
+    #[new]
+    pub fn new(row_count: f64) -> Self {
+        Self { row_count }
+    }
+
+    #[pyo3(name = "getRowCount")]
+    pub fn get_row_count(&self) -> f64 {
+        self.row_count
+    }
+}

Reply via email to