This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 5664a1e Introduce BaseSessionContext abstract class (#390)
5664a1e is described below
commit 5664a1e38f8b45af15afd60bcef841df01da655e
Author: Jeremy Dyer <[email protected]>
AuthorDate: Fri May 26 18:45:05 2023 -0400
Introduce BaseSessionContext abstract class (#390)
---
examples/sql-on-polars.py => datafusion/context.py | 35 +++++++++++++++++-----
datafusion/cudf.py | 14 +++++----
datafusion/pandas.py | 14 +++++----
datafusion/polars.py | 14 +++++----
docs/mdbook/src/usage/create-table.md | 2 +-
examples/sql-on-cudf.py | 2 +-
examples/sql-on-pandas.py | 2 +-
examples/sql-on-polars.py | 2 +-
8 files changed, 58 insertions(+), 27 deletions(-)
diff --git a/examples/sql-on-polars.py b/datafusion/context.py
similarity index 51%
copy from examples/sql-on-polars.py
copy to datafusion/context.py
index c208114..aa9c9a8 100644
--- a/examples/sql-on-polars.py
+++ b/datafusion/context.py
@@ -6,7 +6,7 @@
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
@@ -15,12 +15,31 @@
# specific language governing permissions and limitations
# under the License.
-from datafusion.polars import SessionContext
+from abc import ABC, abstractmethod
-ctx = SessionContext()
-ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
-df = ctx.sql(
- "select passenger_count, count(*) from taxi group by passenger_count"
-)
-print(df)
+class BaseSessionContext(ABC):
+ """
+ Abstraction defining all methods, properties, and common functionality
+ shared amongst implementations using DataFusion as their SQL Parser/Engine
+ """
+
+ @abstractmethod
+ def register_table(
+ self,
+ table_name: str,
+ path: str,
+ **kwargs,
+ ):
+ pass
+
+ # TODO: Remove abstraction, this functionality can be shared
+ # between all implementing classes since it just prints the
+ # logical plan from DataFusion
+ @abstractmethod
+ def explain(self, sql):
+ pass
+
+ @abstractmethod
+ def sql(self, sql):
+ pass
diff --git a/datafusion/cudf.py b/datafusion/cudf.py
index d5f0215..594e5ef 100644
--- a/datafusion/cudf.py
+++ b/datafusion/cudf.py
@@ -17,18 +17,15 @@
import cudf
import datafusion
+from datafusion.context import BaseSessionContext
from datafusion.expr import Projection, TableScan, Column
-class SessionContext:
+class SessionContext(BaseSessionContext):
def __init__(self):
self.datafusion_ctx = datafusion.SessionContext()
self.parquet_tables = {}
- def register_parquet(self, name, path):
- self.parquet_tables[name] = path
- self.datafusion_ctx.register_parquet(name, path)
-
def to_cudf_expr(self, expr):
# get Python wrapper for logical expression
expr = expr.to_variant()
@@ -55,6 +52,13 @@ class SessionContext:
"unsupported logical operator: {}".format(type(node))
)
+ def register_table(self, name, path, **kwargs):
+ self.parquet_tables[name] = path
+ self.datafusion_ctx.register_parquet(name, path)
+
+ def explain(self, sql):
+ super.explain()
+
def sql(self, sql):
datafusion_df = self.datafusion_ctx.sql(sql)
plan = datafusion_df.logical_plan()
diff --git a/datafusion/pandas.py b/datafusion/pandas.py
index f8e5651..935d961 100644
--- a/datafusion/pandas.py
+++ b/datafusion/pandas.py
@@ -17,18 +17,15 @@
import pandas as pd
import datafusion
+from datafusion.context import BaseSessionContext
from datafusion.expr import Projection, TableScan, Column
-class SessionContext:
+class SessionContext(BaseSessionContext):
def __init__(self):
self.datafusion_ctx = datafusion.SessionContext()
self.parquet_tables = {}
- def register_parquet(self, name, path):
- self.parquet_tables[name] = path
- self.datafusion_ctx.register_parquet(name, path)
-
def to_pandas_expr(self, expr):
# get Python wrapper for logical expression
expr = expr.to_variant()
@@ -55,6 +52,13 @@ class SessionContext:
"unsupported logical operator: {}".format(type(node))
)
+ def register_table(self, name, path, **kwargs):
+ self.parquet_tables[name] = path
+ self.datafusion_ctx.register_parquet(name, path)
+
+ def explain(self, sql):
+ super.explain()
+
def sql(self, sql):
datafusion_df = self.datafusion_ctx.sql(sql)
plan = datafusion_df.logical_plan()
diff --git a/datafusion/polars.py b/datafusion/polars.py
index a1bafbe..bbc1fd7 100644
--- a/datafusion/polars.py
+++ b/datafusion/polars.py
@@ -17,19 +17,16 @@
import polars
import datafusion
+from datafusion.context import BaseSessionContext
from datafusion.expr import Projection, TableScan, Aggregate
from datafusion.expr import Column, AggregateFunction
-class SessionContext:
+class SessionContext(BaseSessionContext):
def __init__(self):
self.datafusion_ctx = datafusion.SessionContext()
self.parquet_tables = {}
- def register_parquet(self, name, path):
- self.parquet_tables[name] = path
- self.datafusion_ctx.register_parquet(name, path)
-
def to_polars_expr(self, expr):
# get Python wrapper for logical expression
expr = expr.to_variant()
@@ -78,6 +75,13 @@ class SessionContext:
"unsupported logical operator: {}".format(type(node))
)
+ def register_table(self, name, path, **kwargs):
+ self.parquet_tables[name] = path
+ self.datafusion_ctx.register_parquet(name, path)
+
+ def explain(self, sql):
+ super.explain()
+
def sql(self, sql):
datafusion_df = self.datafusion_ctx.sql(sql)
plan = datafusion_df.logical_plan()
diff --git a/docs/mdbook/src/usage/create-table.md b/docs/mdbook/src/usage/create-table.md
index 332863a..98870fa 100644
--- a/docs/mdbook/src/usage/create-table.md
+++ b/docs/mdbook/src/usage/create-table.md
@@ -55,5 +55,5 @@ ctx.register_csv("csv_1e8", "G1_1e8_1e2_0_0.csv")
You can read a Parquet file into a DataFusion DataFrame. Here's how to read
the `yellow_tripdata_2021-01.parquet` file into a table named `taxi`.
```python
-ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
+ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet")
```
diff --git a/examples/sql-on-cudf.py b/examples/sql-on-cudf.py
index 999756f..b64d8f0 100644
--- a/examples/sql-on-cudf.py
+++ b/examples/sql-on-cudf.py
@@ -19,6 +19,6 @@ from datafusion.cudf import SessionContext
ctx = SessionContext()
-ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
+ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet")
df = ctx.sql("select passenger_count from taxi")
print(df)
diff --git a/examples/sql-on-pandas.py b/examples/sql-on-pandas.py
index 0efd776..e3312a2 100644
--- a/examples/sql-on-pandas.py
+++ b/examples/sql-on-pandas.py
@@ -19,6 +19,6 @@ from datafusion.pandas import SessionContext
ctx = SessionContext()
-ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
+ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet")
df = ctx.sql("select passenger_count from taxi")
print(df)
diff --git a/examples/sql-on-polars.py b/examples/sql-on-polars.py
index c208114..dd7a9e0 100644
--- a/examples/sql-on-polars.py
+++ b/examples/sql-on-polars.py
@@ -19,7 +19,7 @@ from datafusion.polars import SessionContext
ctx = SessionContext()
-ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
+ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet")
df = ctx.sql(
"select passenger_count, count(*) from taxi group by passenger_count"
)