This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 931cabc CRUD Schema support for `BaseSessionContext` (#392)
931cabc is described below
commit 931cabc629a28bb47c030e7e17de59a19feb38dd
Author: Jeremy Dyer <[email protected]>
AuthorDate: Tue May 30 14:51:27 2023 -0400
CRUD Schema support for `BaseSessionContext` (#392)
* checkpoint commit
* Introduce BaseSessionContext abstract class
* Introduce abstract methods for CRUD schema operations
* Clean up schema.rs file
---
.gitignore | 3 +
datafusion/context.py | 52 ++++++++++++
datafusion/cudf.py | 42 +++++++++-
datafusion/pandas.py | 32 ++++++-
datafusion/polars.py | 22 ++++-
src/common.rs | 7 ++
src/common/function.rs | 55 ++++++++++++
src/common/schema.rs | 222 +++++++++++++++++++++++++++++++++++++++++++++++++
8 files changed, 428 insertions(+), 7 deletions(-)
diff --git a/.gitignore b/.gitignore
index 365b89d..0030b90 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,6 +16,9 @@ dist
# C extensions
*.so
+# Python dist
+dist
+
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
diff --git a/datafusion/context.py b/datafusion/context.py
index aa9c9a8..a364621 100644
--- a/datafusion/context.py
+++ b/datafusion/context.py
@@ -16,6 +16,9 @@
# under the License.
from abc import ABC, abstractmethod
+from typing import Dict
+
+from datafusion.common import SqlSchema
class BaseSessionContext(ABC):
@@ -24,6 +27,55 @@ class BaseSessionContext(ABC):
shared amongst implementations using DataFusion as their SQL Parser/Engine
"""
+ DEFAULT_CATALOG_NAME = "root"
+ DEFAULT_SCHEMA_NAME = "datafusion"
+
+ @abstractmethod
+ def create_schema(
+ self,
+ schema_name: str,
+ **kwargs,
+ ):
+ """
+ Creates/Registers a logical container that holds database
+ objects such as tables, views, indexes, and other
+ related objects. It provides a way to group related database
+ objects together. A schema can be owned by a database
+ user and can be used to separate objects in different
+ logical groups for easy management.
+ """
+ pass
+
+ @abstractmethod
+ def update_schema(
+ self,
+ schema_name: str,
+ new_schema: SqlSchema,
+ **kwargs,
+ ):
+ """
+ Updates an existing schema in the SessionContext
+ """
+ pass
+
+ @abstractmethod
+ def drop_schema(
+ self,
+ schema_name: str,
+ **kwargs,
+ ):
+ """
+ Drops the specified Schema, based on name, from the current context
+ """
+ pass
+
+ @abstractmethod
+ def show_schemas(self, **kwargs) -> Dict[str, SqlSchema]:
+ """
+ Return all schemas in the current SessionContext impl.
+ """
+ pass
+
@abstractmethod
def register_table(
self,
diff --git a/datafusion/cudf.py b/datafusion/cudf.py
index 594e5ef..e39daea 100644
--- a/datafusion/cudf.py
+++ b/datafusion/cudf.py
@@ -15,16 +15,36 @@
# specific language governing permissions and limitations
# under the License.
+import logging
import cudf
-import datafusion
from datafusion.context import BaseSessionContext
from datafusion.expr import Projection, TableScan, Column
+from datafusion.common import SqlSchema
+
+logger = logging.getLogger(__name__)
+
class SessionContext(BaseSessionContext):
- def __init__(self):
- self.datafusion_ctx = datafusion.SessionContext()
- self.parquet_tables = {}
+ def __init__(self, context, logging_level=logging.INFO):
+ """
+ Create a new Session.
+ """
+ # Cudf requires a provided context
+ self.context = context
+
+ # Set the logging level for this SQL context
+ logging.basicConfig(level=logging_level)
+
+ # Name of the root catalog
+ self.catalog_name = self.DEFAULT_CATALOG_NAME
+ # Name of the root schema
+ self.schema_name = self.DEFAULT_SCHEMA_NAME
+ # Add the schema to the context
+ sch = SqlSchema(self.schema_name)
+ self.schemas = {}
+ self.schemas[self.schema_name] = sch
+ self.context.register_schema(self.schema_name, sch)
def to_cudf_expr(self, expr):
# get Python wrapper for logical expression
@@ -52,6 +72,20 @@ class SessionContext(BaseSessionContext):
"unsupported logical operator: {}".format(type(node))
)
+ def create_schema(self, schema_name: str, **kwargs):
+ logger.debug(f"Creating schema: {schema_name}")
+ self.schemas[schema_name] = SqlSchema(schema_name)
+ self.context.register_schema(schema_name, SqlSchema(schema_name))
+
+ def update_schema(self, schema_name: str, new_schema: SqlSchema, **kwargs):
+ self.schemas[schema_name] = new_schema
+
+ def drop_schema(self, schema_name, **kwargs):
+ del self.schemas[schema_name]
+
+ def show_schemas(self, **kwargs):
+ return self.schemas
+
def register_table(self, name, path, **kwargs):
self.parquet_tables[name] = path
self.datafusion_ctx.register_parquet(name, path)
diff --git a/datafusion/pandas.py b/datafusion/pandas.py
index 935d961..c2da83f 100644
--- a/datafusion/pandas.py
+++ b/datafusion/pandas.py
@@ -15,17 +15,33 @@
# specific language governing permissions and limitations
# under the License.
+import logging
import pandas as pd
import datafusion
+from datafusion.common import SqlSchema
from datafusion.context import BaseSessionContext
from datafusion.expr import Projection, TableScan, Column
+logger = logging.getLogger(__name__)
+
class SessionContext(BaseSessionContext):
- def __init__(self):
+ def __init__(self, logging_level=logging.INFO):
self.datafusion_ctx = datafusion.SessionContext()
self.parquet_tables = {}
+ # Set the logging level for this SQL context
+ logging.basicConfig(level=logging_level)
+
+ # Name of the root catalog
+ self.catalog_name = self.DEFAULT_CATALOG_NAME
+ # Name of the root schema
+ self.schema_name = self.DEFAULT_SCHEMA_NAME
+ # Add the schema to the context
+ sch = SqlSchema(self.schema_name)
+ self.schemas[self.schema_name] = sch
+ self.context.register_schema(self.schema_name, sch)
+
def to_pandas_expr(self, expr):
# get Python wrapper for logical expression
expr = expr.to_variant()
@@ -52,6 +68,20 @@ class SessionContext(BaseSessionContext):
"unsupported logical operator: {}".format(type(node))
)
+ def create_schema(self, schema_name: str, **kwargs):
+ logger.debug(f"Creating schema: {schema_name}")
+ self.schemas[schema_name] = SqlSchema(schema_name)
+ self.context.register_schema(schema_name, SqlSchema(schema_name))
+
+ def update_schema(self, schema_name: str, new_schema: SqlSchema, **kwargs):
+ self.schemas[schema_name] = new_schema
+
+ def drop_schema(self, schema_name, **kwargs):
+ del self.schemas[schema_name]
+
+ def show_schemas(self, **kwargs):
+ return self.schemas
+
def register_table(self, name, path, **kwargs):
self.parquet_tables[name] = path
self.datafusion_ctx.register_parquet(name, path)
diff --git a/datafusion/polars.py b/datafusion/polars.py
index bbc1fd7..e4eb966 100644
--- a/datafusion/polars.py
+++ b/datafusion/polars.py
@@ -14,16 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+import logging
import polars
import datafusion
from datafusion.context import BaseSessionContext
from datafusion.expr import Projection, TableScan, Aggregate
from datafusion.expr import Column, AggregateFunction
+from datafusion.common import SqlSchema
+
+logger = logging.getLogger(__name__)
+
class SessionContext(BaseSessionContext):
- def __init__(self):
+ def __init__(self, logging_level=logging.INFO):
self.datafusion_ctx = datafusion.SessionContext()
self.parquet_tables = {}
@@ -75,6 +79,20 @@ class SessionContext(BaseSessionContext):
"unsupported logical operator: {}".format(type(node))
)
+ def create_schema(self, schema_name: str, **kwargs):
+ logger.debug(f"Creating schema: {schema_name}")
+ self.schemas[schema_name] = SqlSchema(schema_name)
+ self.context.register_schema(schema_name, SqlSchema(schema_name))
+
+ def update_schema(self, schema_name: str, new_schema: SqlSchema, **kwargs):
+ self.schemas[schema_name] = new_schema
+
+ def drop_schema(self, schema_name, **kwargs):
+ del self.schemas[schema_name]
+
+ def show_schemas(self, **kwargs):
+ return self.schemas
+
def register_table(self, name, path, **kwargs):
self.parquet_tables[name] = path
self.datafusion_ctx.register_parquet(name, path)
diff --git a/src/common.rs b/src/common.rs
index 8a8e2ad..4552317 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -20,6 +20,8 @@ use pyo3::prelude::*;
pub mod data_type;
pub mod df_field;
pub mod df_schema;
+pub mod function;
+pub mod schema;
/// Initializes the `common` module to match the pattern of `datafusion-common` https://docs.rs/datafusion-common/18.0.0/datafusion_common/index.html
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
@@ -29,5 +31,10 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_class::<data_type::DataTypeMap>()?;
m.add_class::<data_type::PythonType>()?;
m.add_class::<data_type::SqlType>()?;
+ m.add_class::<schema::SqlTable>()?;
+ m.add_class::<schema::SqlSchema>()?;
+ m.add_class::<schema::SqlView>()?;
+ m.add_class::<schema::SqlStatistics>()?;
+ m.add_class::<function::SqlFunction>()?;
Ok(())
}
diff --git a/src/common/function.rs b/src/common/function.rs
new file mode 100644
index 0000000..a8d752f
--- /dev/null
+++ b/src/common/function.rs
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+
+use datafusion::arrow::datatypes::DataType;
+use pyo3::prelude::*;
+
+use super::data_type::PyDataType;
+
+#[pyclass(name = "SqlFunction", module = "datafusion.common", subclass)]
+#[derive(Debug, Clone)]
+pub struct SqlFunction {
+ pub name: String,
+ pub return_types: HashMap<Vec<DataType>, DataType>,
+ pub aggregation: bool,
+}
+
+impl SqlFunction {
+ pub fn new(
+ function_name: String,
+ input_types: Vec<PyDataType>,
+ return_type: PyDataType,
+ aggregation_bool: bool,
+ ) -> Self {
+ let mut func = Self {
+ name: function_name,
+ return_types: HashMap::new(),
+ aggregation: aggregation_bool,
+ };
+ func.add_type_mapping(input_types, return_type);
+ func
+ }
+
+    pub fn add_type_mapping(&mut self, input_types: Vec<PyDataType>, return_type: PyDataType) {
+ self.return_types.insert(
+ input_types.iter().map(|t| t.clone().into()).collect(),
+ return_type.into(),
+ );
+ }
+}
diff --git a/src/common/schema.rs b/src/common/schema.rs
new file mode 100644
index 0000000..3043193
--- /dev/null
+++ b/src/common/schema.rs
@@ -0,0 +1,222 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource};
+use pyo3::prelude::*;
+
+use datafusion_optimizer::utils::split_conjunction;
+
+use super::{data_type::DataTypeMap, function::SqlFunction};
+
+#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)]
+#[derive(Debug, Clone)]
+pub struct SqlSchema {
+ #[pyo3(get, set)]
+ pub name: String,
+ #[pyo3(get, set)]
+ pub tables: Vec<SqlTable>,
+ #[pyo3(get, set)]
+ pub views: Vec<SqlView>,
+ #[pyo3(get, set)]
+ pub functions: Vec<SqlFunction>,
+}
+
+#[pyclass(name = "SqlTable", module = "datafusion.common", subclass)]
+#[derive(Debug, Clone)]
+pub struct SqlTable {
+ #[pyo3(get, set)]
+ pub name: String,
+ #[pyo3(get, set)]
+ pub columns: Vec<(String, DataTypeMap)>,
+ #[pyo3(get, set)]
+ pub primary_key: Option<String>,
+ #[pyo3(get, set)]
+ pub foreign_keys: Vec<String>,
+ #[pyo3(get, set)]
+ pub indexes: Vec<String>,
+ #[pyo3(get, set)]
+ pub constraints: Vec<String>,
+ #[pyo3(get, set)]
+ pub statistics: SqlStatistics,
+ #[pyo3(get, set)]
+ pub filepath: Option<String>,
+}
+
+#[pymethods]
+impl SqlTable {
+ #[new]
+ pub fn new(
+ _schema_name: String,
+ table_name: String,
+ columns: Vec<(String, DataTypeMap)>,
+ row_count: f64,
+ filepath: Option<String>,
+ ) -> Self {
+ Self {
+ name: table_name,
+ columns,
+ primary_key: None,
+ foreign_keys: Vec::new(),
+ indexes: Vec::new(),
+ constraints: Vec::new(),
+ statistics: SqlStatistics::new(row_count),
+ filepath,
+ }
+ }
+}
+
+#[pyclass(name = "SqlView", module = "datafusion.common", subclass)]
+#[derive(Debug, Clone)]
+pub struct SqlView {
+ #[pyo3(get, set)]
+ pub name: String,
+ #[pyo3(get, set)]
+ pub definition: String, // SQL code that defines the view
+}
+
+#[pymethods]
+impl SqlSchema {
+ #[new]
+ pub fn new(schema_name: &str) -> Self {
+ Self {
+ name: schema_name.to_owned(),
+ tables: Vec::new(),
+ views: Vec::new(),
+ functions: Vec::new(),
+ }
+ }
+
+ pub fn table_by_name(&self, table_name: &str) -> Option<SqlTable> {
+ for tbl in &self.tables {
+ if tbl.name.eq(table_name) {
+ return Some(tbl.clone());
+ }
+ }
+ None
+ }
+
+ pub fn add_table(&mut self, table: SqlTable) {
+ self.tables.push(table);
+ }
+}
+
+/// SqlTable wrapper that is compatible with DataFusion logical query plans
+pub struct SqlTableSource {
+ schema: SchemaRef,
+ statistics: Option<SqlStatistics>,
+ filepath: Option<String>,
+}
+
+impl SqlTableSource {
+ /// Initialize a new `EmptyTable` from a schema
+ pub fn new(
+ schema: SchemaRef,
+ statistics: Option<SqlStatistics>,
+ filepath: Option<String>,
+ ) -> Self {
+ Self {
+ schema,
+ statistics,
+ filepath,
+ }
+ }
+
+ /// Access optional statistics associated with this table source
+ pub fn statistics(&self) -> Option<&SqlStatistics> {
+ self.statistics.as_ref()
+ }
+
+ /// Access optional filepath associated with this table source
+ #[allow(dead_code)]
+ pub fn filepath(&self) -> Option<&String> {
+ self.filepath.as_ref()
+ }
+}
+
+/// Implement TableSource, used in the logical query plan and in logical query optimizations
+impl TableSource for SqlTableSource {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+
+ fn supports_filter_pushdown(
+ &self,
+ filter: &Expr,
+ ) -> datafusion_common::Result<TableProviderFilterPushDown> {
+ let filters = split_conjunction(filter);
+ if filters.iter().all(|f| is_supported_push_down_expr(f)) {
+            // Push down filters to the tablescan operation if all are supported
+ Ok(TableProviderFilterPushDown::Exact)
+ } else if filters.iter().any(|f| is_supported_push_down_expr(f)) {
+ // Partially apply the filter in the TableScan but retain
+ // the Filter operator in the plan as well
+ Ok(TableProviderFilterPushDown::Inexact)
+ } else {
+ Ok(TableProviderFilterPushDown::Unsupported)
+ }
+ }
+
+ fn table_type(&self) -> datafusion_expr::TableType {
+ datafusion_expr::TableType::Base
+ }
+
+ #[allow(deprecated)]
+ fn supports_filters_pushdown(
+ &self,
+ filters: &[&Expr],
+ ) -> datafusion_common::Result<Vec<TableProviderFilterPushDown>> {
+ filters
+ .iter()
+ .map(|f| self.supports_filter_pushdown(f))
+ .collect()
+ }
+
+ fn get_logical_plan(&self) -> Option<&datafusion_expr::LogicalPlan> {
+ None
+ }
+}
+
+fn is_supported_push_down_expr(_expr: &Expr) -> bool {
+ // For now we support all kinds of expr's at this level
+ true
+}
+
+#[pyclass(name = "SqlStatistics", module = "datafusion.common", subclass)]
+#[derive(Debug, Clone)]
+pub struct SqlStatistics {
+ row_count: f64,
+}
+
+#[pymethods]
+impl SqlStatistics {
+ #[new]
+ pub fn new(row_count: f64) -> Self {
+ Self { row_count }
+ }
+
+ #[pyo3(name = "getRowCount")]
+ pub fn get_row_count(&self) -> f64 {
+ self.row_count
+ }
+}