This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 07324b84378 [SPARK-42818][CONNECT][PYTHON] Implement DataFrameReader/Writer.jdbc
07324b84378 is described below

commit 07324b84378984eff612aa5d395042e7207f9878
Author: Takuya UESHIN <[email protected]>
AuthorDate: Thu Mar 16 10:38:19 2023 +0900

    [SPARK-42818][CONNECT][PYTHON] Implement DataFrameReader/Writer.jdbc
    
    ### What changes were proposed in this pull request?
    
    Implements `DataFrameReader.jdbc` and `DataFrameWriter.jdbc` for the Spark Connect Python client.
    
    ### Why are the changes needed?
    
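    Missing API: in Spark Connect, both `DataFrameReader.jdbc` and `DataFrameWriter.jdbc` previously raised `NotImplementedError`.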
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, `DataFrameReader.jdbc` and `DataFrameWriter.jdbc` will be available in Spark Connect.
    
    ### How was this patch tested?
    
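    Added JDBC round-trip tests (plain, partitioned by column, and predicate-based reads) to `test_datasources.py`, and removed the obsolete `test_unsupported_io_functions` test.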
    
    Closes #40450 from ueshin/issues/SPARK-42818/jdbc.
    
    Authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 8b2f28bd53d0eacbac7555c3a09af908bc682e41)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/plan.py                 |  8 ++
 python/pyspark/sql/connect/readwriter.py           | 96 ++++++++++++++++++++--
 .../sql/tests/connect/test_connect_basic.py        | 12 ---
 python/pyspark/sql/tests/test_datasources.py       | 24 ++++++
 4 files changed, 123 insertions(+), 17 deletions(-)

diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 4e31811a9e2..9807c9722a6 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -259,6 +259,7 @@ class DataSource(LogicalPlan):
         schema: Optional[str] = None,
         options: Optional[Mapping[str, str]] = None,
         paths: Optional[List[str]] = None,
+        predicates: Optional[List[str]] = None,
     ) -> None:
         super().__init__(None)
 
@@ -274,10 +275,15 @@ class DataSource(LogicalPlan):
             assert isinstance(paths, list)
             assert all(isinstance(path, str) for path in paths)
 
+        if predicates is not None:
+            assert isinstance(predicates, list)
+            assert all(isinstance(predicate, str) for predicate in predicates)
+
         self._format = format
         self._schema = schema
         self._options = options
         self._paths = paths
+        self._predicates = predicates
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         plan = self._create_proto_relation()
@@ -290,6 +296,8 @@ class DataSource(LogicalPlan):
                 plan.read.data_source.options[k] = v
         if self._paths is not None and len(self._paths) > 0:
             plan.read.data_source.paths.extend(self._paths)
+        if self._predicates is not None and len(self._predicates) > 0:
+            plan.read.data_source.predicates.extend(self._predicates)
         return plan
 
 
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 52a7a6c8cf5..1b58c54b38e 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -19,7 +19,7 @@ from pyspark.sql.connect.utils import check_dependencies
 check_dependencies(__name__)
 
 from typing import Dict
-from typing import Optional, Union, List, overload, Tuple, cast, Any
+from typing import Optional, Union, List, overload, Tuple, cast
 from typing import TYPE_CHECKING
 
 from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, 
WriteOperation, WriteOperationV2
@@ -339,8 +339,83 @@ class DataFrameReader(OptionUtils):
 
     orc.__doc__ = PySparkDataFrameReader.orc.__doc__
 
-    def jdbc(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("jdbc() not supported for DataFrameWriter")
+    @overload
+    def jdbc(
+        self, url: str, table: str, *, properties: Optional[Dict[str, str]] = 
None
+    ) -> "DataFrame":
+        ...
+
+    @overload
+    def jdbc(
+        self,
+        url: str,
+        table: str,
+        column: str,
+        lowerBound: Union[int, str],
+        upperBound: Union[int, str],
+        numPartitions: int,
+        *,
+        properties: Optional[Dict[str, str]] = None,
+    ) -> "DataFrame":
+        ...
+
+    @overload
+    def jdbc(
+        self,
+        url: str,
+        table: str,
+        *,
+        predicates: List[str],
+        properties: Optional[Dict[str, str]] = None,
+    ) -> "DataFrame":
+        ...
+
+    def jdbc(
+        self,
+        url: str,
+        table: str,
+        column: Optional[str] = None,
+        lowerBound: Optional[Union[int, str]] = None,
+        upperBound: Optional[Union[int, str]] = None,
+        numPartitions: Optional[int] = None,
+        predicates: Optional[List[str]] = None,
+        properties: Optional[Dict[str, str]] = None,
+    ) -> "DataFrame":
+        if properties is None:
+            properties = dict()
+
+        self.format("jdbc")
+
+        if column is not None:
+            assert lowerBound is not None, "lowerBound can not be None when 
``column`` is specified"
+            assert upperBound is not None, "upperBound can not be None when 
``column`` is specified"
+            assert (
+                numPartitions is not None
+            ), "numPartitions can not be None when ``column`` is specified"
+            self.options(
+                partitionColumn=column,
+                lowerBound=lowerBound,
+                upperBound=upperBound,
+                numPartitions=numPartitions,
+            )
+            self.options(**properties)
+            self.options(url=url, dbtable=table)
+            return self.load()
+        else:
+            self.options(**properties)
+            self.options(url=url, dbtable=table)
+            if predicates is not None:
+                plan = DataSource(
+                    format=self._format,
+                    schema=self._schema,
+                    options=self._options,
+                    predicates=predicates,
+                )
+                return self._df(plan)
+            else:
+                return self.load()
+
+    jdbc.__doc__ = PySparkDataFrameReader.jdbc.__doc__
 
 
 DataFrameReader.__doc__ = PySparkDataFrameReader.__doc__
@@ -603,8 +678,19 @@ class DataFrameWriter(OptionUtils):
 
     orc.__doc__ = PySparkDataFrameWriter.orc.__doc__
 
-    def jdbc(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("jdbc() not supported for DataFrameWriter")
+    def jdbc(
+        self,
+        url: str,
+        table: str,
+        mode: Optional[str] = None,
+        properties: Optional[Dict[str, str]] = None,
+    ) -> None:
+        if properties is None:
+            properties = dict()
+
+        self.format("jdbc").mode(mode).options(**properties).options(url=url, 
dbtable=table).save()
+
+    jdbc.__doc__ = PySparkDataFrameWriter.jdbc.__doc__
 
 
 class DataFrameWriterV2(OptionUtils):
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index cd6890a630b..9da3285d07e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2871,18 +2871,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             with self.assertRaises(NotImplementedError):
                 getattr(self.connect, f)()
 
-    def test_unsupported_io_functions(self):
-        # SPARK-41964: Disable unsupported functions.
-        df = self.connect.createDataFrame([(x, f"{x}") for x in range(100)], 
["id", "name"])
-
-        for f in ("jdbc",):
-            with self.assertRaises(NotImplementedError):
-                getattr(self.connect.read, f)()
-
-        for f in ("jdbc",):
-            with self.assertRaises(NotImplementedError):
-                getattr(df.write, f)()
-
     def test_sql_with_command(self):
         # SPARK-42705: spark.sql should return values from the command.
         self.assertEqual(
diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py
index 9b6692a9d21..6418983b06a 100644
--- a/python/pyspark/sql/tests/test_datasources.py
+++ b/python/pyspark/sql/tests/test_datasources.py
@@ -198,6 +198,30 @@ class DataSourcesTestsMixin:
         url = f"jdbc:derby:{db}"
         dbtable = "test_table"
 
+        try:
+            df = self.spark.range(10)
+            df.write.jdbc(url=f"{url};create=true", table=dbtable)
+
+            readback = self.spark.read.jdbc(url=url, table=dbtable)
+            self.assertEqual(sorted(df.collect()), sorted(readback.collect()))
+
+            additional_arguments = dict(column="id", lowerBound=3, 
upperBound=8, numPartitions=10)
+            readback = self.spark.read.jdbc(url=url, table=dbtable, 
**additional_arguments)
+            self.assertEqual(sorted(df.collect()), sorted(readback.collect()))
+
+            additional_arguments = dict(predicates=['"id" < 5'])
+            readback = self.spark.read.jdbc(url=url, table=dbtable, 
**additional_arguments)
+            self.assertEqual(sorted(df.filter("id < 5").collect()), 
sorted(readback.collect()))
+        finally:
+            # Clean up.
+            with self.assertRaisesRegex(Exception, f"Database '{db}' 
dropped."):
+                self.spark.read.jdbc(url=f"{url};drop=true", 
table=dbtable).collect()
+
+    def test_jdbc_format(self):
+        db = f"memory:{uuid.uuid4()}"
+        url = f"jdbc:derby:{db}"
+        dbtable = "test_table"
+
         try:
             df = self.spark.range(10)
             df.write.format("jdbc").options(url=f"{url};create=true", 
dbtable=dbtable).save()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to