This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 5664a1e  Introduce BaseSessionContext abstract class (#390)
5664a1e is described below

commit 5664a1e38f8b45af15afd60bcef841df01da655e
Author: Jeremy Dyer <[email protected]>
AuthorDate: Fri May 26 18:45:05 2023 -0400

    Introduce BaseSessionContext abstract class (#390)
---
 examples/sql-on-polars.py => datafusion/context.py | 35 +++++++++++++++++-----
 datafusion/cudf.py                                 | 14 +++++----
 datafusion/pandas.py                               | 14 +++++----
 datafusion/polars.py                               | 14 +++++----
 docs/mdbook/src/usage/create-table.md              |  2 +-
 examples/sql-on-cudf.py                            |  2 +-
 examples/sql-on-pandas.py                          |  2 +-
 examples/sql-on-polars.py                          |  2 +-
 8 files changed, 58 insertions(+), 27 deletions(-)

diff --git a/examples/sql-on-polars.py b/datafusion/context.py
similarity index 51%
copy from examples/sql-on-polars.py
copy to datafusion/context.py
index c208114..aa9c9a8 100644
--- a/examples/sql-on-polars.py
+++ b/datafusion/context.py
@@ -6,7 +6,7 @@
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+#   http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
@@ -15,12 +15,31 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from datafusion.polars import SessionContext
+from abc import ABC, abstractmethod
 
 
-ctx = SessionContext()
-ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
-df = ctx.sql(
-    "select passenger_count, count(*) from taxi group by passenger_count"
-)
-print(df)
+class BaseSessionContext(ABC):
+    """
+    Common interface for SessionContext implementations that use DataFusion
+    as their SQL parser/planner and run the plan on another DataFrame library
+    """
+
+    @abstractmethod
+    def register_table(
+        self,
+        table_name: str,
+        path: str,
+        **kwargs,
+    ) -> None:
+        """Register the file at ``path`` as table ``table_name``."""
+
+    # TODO: Remove abstraction, this functionality can be shared
+    # between all implementing classes since it just prints the
+    # logical plan from DataFusion
+    @abstractmethod
+    def explain(self, sql: str):
+        """Print DataFusion's plan for the SQL query ``sql``."""
+
+    @abstractmethod
+    def sql(self, sql: str):
+        """Run the SQL query ``sql`` and return the backend's result."""
diff --git a/datafusion/cudf.py b/datafusion/cudf.py
index d5f0215..594e5ef 100644
--- a/datafusion/cudf.py
+++ b/datafusion/cudf.py
@@ -17,18 +17,15 @@
 
 import cudf
 import datafusion
+from datafusion.context import BaseSessionContext
 from datafusion.expr import Projection, TableScan, Column
 
 
-class SessionContext:
+class SessionContext(BaseSessionContext):
     def __init__(self):
         self.datafusion_ctx = datafusion.SessionContext()
         self.parquet_tables = {}
 
-    def register_parquet(self, name, path):
-        self.parquet_tables[name] = path
-        self.datafusion_ctx.register_parquet(name, path)
-
     def to_cudf_expr(self, expr):
         # get Python wrapper for logical expression
         expr = expr.to_variant()
@@ -55,6 +52,13 @@ class SessionContext:
                 "unsupported logical operator: {}".format(type(node))
             )
 
+    def register_table(self, name, path, **kwargs):  # kwargs unused here
+        self.datafusion_ctx.register_parquet(name, path)
+        self.parquet_tables[name] = path  # only after registering
+
+    def explain(self, sql):  # prints DataFusion's plan for `sql`
+        return self.datafusion_ctx.sql(sql).explain()
+
     def sql(self, sql):
         datafusion_df = self.datafusion_ctx.sql(sql)
         plan = datafusion_df.logical_plan()
diff --git a/datafusion/pandas.py b/datafusion/pandas.py
index f8e5651..935d961 100644
--- a/datafusion/pandas.py
+++ b/datafusion/pandas.py
@@ -17,18 +17,15 @@
 
 import pandas as pd
 import datafusion
+from datafusion.context import BaseSessionContext
 from datafusion.expr import Projection, TableScan, Column
 
 
-class SessionContext:
+class SessionContext(BaseSessionContext):
     def __init__(self):
         self.datafusion_ctx = datafusion.SessionContext()
         self.parquet_tables = {}
 
-    def register_parquet(self, name, path):
-        self.parquet_tables[name] = path
-        self.datafusion_ctx.register_parquet(name, path)
-
     def to_pandas_expr(self, expr):
         # get Python wrapper for logical expression
         expr = expr.to_variant()
@@ -55,6 +52,13 @@ class SessionContext:
                 "unsupported logical operator: {}".format(type(node))
             )
 
+    def register_table(self, name, path, **kwargs):  # kwargs unused here
+        self.datafusion_ctx.register_parquet(name, path)
+        self.parquet_tables[name] = path  # only after registering
+
+    def explain(self, sql):  # prints DataFusion's plan for `sql`
+        return self.datafusion_ctx.sql(sql).explain()
+
     def sql(self, sql):
         datafusion_df = self.datafusion_ctx.sql(sql)
         plan = datafusion_df.logical_plan()
diff --git a/datafusion/polars.py b/datafusion/polars.py
index a1bafbe..bbc1fd7 100644
--- a/datafusion/polars.py
+++ b/datafusion/polars.py
@@ -17,19 +17,16 @@
 
 import polars
 import datafusion
+from datafusion.context import BaseSessionContext
 from datafusion.expr import Projection, TableScan, Aggregate
 from datafusion.expr import Column, AggregateFunction
 
 
-class SessionContext:
+class SessionContext(BaseSessionContext):
     def __init__(self):
         self.datafusion_ctx = datafusion.SessionContext()
         self.parquet_tables = {}
 
-    def register_parquet(self, name, path):
-        self.parquet_tables[name] = path
-        self.datafusion_ctx.register_parquet(name, path)
-
     def to_polars_expr(self, expr):
         # get Python wrapper for logical expression
         expr = expr.to_variant()
@@ -78,6 +75,13 @@ class SessionContext:
                 "unsupported logical operator: {}".format(type(node))
             )
 
+    def register_table(self, name, path, **kwargs):  # kwargs unused here
+        self.datafusion_ctx.register_parquet(name, path)
+        self.parquet_tables[name] = path  # only after registering
+
+    def explain(self, sql):  # prints DataFusion's plan for `sql`
+        return self.datafusion_ctx.sql(sql).explain()
+
     def sql(self, sql):
         datafusion_df = self.datafusion_ctx.sql(sql)
         plan = datafusion_df.logical_plan()
diff --git a/docs/mdbook/src/usage/create-table.md b/docs/mdbook/src/usage/create-table.md
index 332863a..98870fa 100644
--- a/docs/mdbook/src/usage/create-table.md
+++ b/docs/mdbook/src/usage/create-table.md
@@ -55,5 +55,5 @@ ctx.register_csv("csv_1e8", "G1_1e8_1e2_0_0.csv")
 You can read a Parquet file into a DataFusion DataFrame.  Here's how to read the `yellow_tripdata_2021-01.parquet` file into a table named `taxi`.
 
 ```python
-ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
+ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet")
 ```
diff --git a/examples/sql-on-cudf.py b/examples/sql-on-cudf.py
index 999756f..b64d8f0 100644
--- a/examples/sql-on-cudf.py
+++ b/examples/sql-on-cudf.py
@@ -19,6 +19,6 @@ from datafusion.cudf import SessionContext
 
 
 ctx = SessionContext()
-ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
+ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet")
 df = ctx.sql("select passenger_count from taxi")
 print(df)
diff --git a/examples/sql-on-pandas.py b/examples/sql-on-pandas.py
index 0efd776..e3312a2 100644
--- a/examples/sql-on-pandas.py
+++ b/examples/sql-on-pandas.py
@@ -19,6 +19,6 @@ from datafusion.pandas import SessionContext
 
 
 ctx = SessionContext()
-ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
+ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet")
 df = ctx.sql("select passenger_count from taxi")
 print(df)
diff --git a/examples/sql-on-polars.py b/examples/sql-on-polars.py
index c208114..dd7a9e0 100644
--- a/examples/sql-on-polars.py
+++ b/examples/sql-on-polars.py
@@ -19,7 +19,7 @@ from datafusion.polars import SessionContext
 
 
 ctx = SessionContext()
-ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet")
+ctx.register_table("taxi", "yellow_tripdata_2021-01.parquet")
 df = ctx.sql(
     "select passenger_count, count(*) from taxi group by passenger_count"
 )

Reply via email to