This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 07324b84378 [SPARK-42818][CONNECT][PYTHON] Implement DataFrameReader/Writer.jdbc
07324b84378 is described below
commit 07324b84378984eff612aa5d395042e7207f9878
Author: Takuya UESHIN <[email protected]>
AuthorDate: Thu Mar 16 10:38:19 2023 +0900
[SPARK-42818][CONNECT][PYTHON] Implement DataFrameReader/Writer.jdbc
### What changes were proposed in this pull request?
Implements `DataFrameReader/Writer.jdbc`.
### Why are the changes needed?
Missing API.
### Does this PR introduce _any_ user-facing change?
Yes, `DataFrameReader/Writer.jdbc` will be available.
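For example, something like the following should work (a minimal sketch based on the tests added in this patch; it assumes a `spark` SparkSession and a reachable JDBC `url`, such as the in-memory Derby database the tests use):

```python
# Write a DataFrame to a JDBC table (hypothetical url/table names).
df = spark.range(10)
df.write.jdbc(url=url, table="test_table")

# Plain, partitioned, and predicate-based reads:
spark.read.jdbc(url=url, table="test_table")
spark.read.jdbc(
    url=url, table="test_table", column="id", lowerBound=3, upperBound=8, numPartitions=10
)
spark.read.jdbc(url=url, table="test_table", predicates=['"id" < 5'])
```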
### How was this patch tested?
Added related tests.
Closes #40450 from ueshin/issues/SPARK-42818/jdbc.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 8b2f28bd53d0eacbac7555c3a09af908bc682e41)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/plan.py | 8 ++
python/pyspark/sql/connect/readwriter.py | 96 ++++++++++++++++++++--
.../sql/tests/connect/test_connect_basic.py | 12 ---
python/pyspark/sql/tests/test_datasources.py | 24 ++++++
4 files changed, 123 insertions(+), 17 deletions(-)
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 4e31811a9e2..9807c9722a6 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -259,6 +259,7 @@ class DataSource(LogicalPlan):
schema: Optional[str] = None,
options: Optional[Mapping[str, str]] = None,
paths: Optional[List[str]] = None,
+ predicates: Optional[List[str]] = None,
) -> None:
super().__init__(None)
@@ -274,10 +275,15 @@ class DataSource(LogicalPlan):
assert isinstance(paths, list)
assert all(isinstance(path, str) for path in paths)
+ if predicates is not None:
+ assert isinstance(predicates, list)
+ assert all(isinstance(predicate, str) for predicate in predicates)
+
self._format = format
self._schema = schema
self._options = options
self._paths = paths
+ self._predicates = predicates
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
@@ -290,6 +296,8 @@ class DataSource(LogicalPlan):
plan.read.data_source.options[k] = v
if self._paths is not None and len(self._paths) > 0:
plan.read.data_source.paths.extend(self._paths)
+ if self._predicates is not None and len(self._predicates) > 0:
+ plan.read.data_source.predicates.extend(self._predicates)
return plan
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 52a7a6c8cf5..1b58c54b38e 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -19,7 +19,7 @@ from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
from typing import Dict
-from typing import Optional, Union, List, overload, Tuple, cast, Any
+from typing import Optional, Union, List, overload, Tuple, cast
from typing import TYPE_CHECKING
from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, WriteOperation, WriteOperationV2
@@ -339,8 +339,83 @@ class DataFrameReader(OptionUtils):
orc.__doc__ = PySparkDataFrameReader.orc.__doc__
- def jdbc(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("jdbc() not supported for DataFrameWriter")
+ @overload
+ def jdbc(
+ self, url: str, table: str, *, properties: Optional[Dict[str, str]] = None
+ ) -> "DataFrame":
+ ...
+
+ @overload
+ def jdbc(
+ self,
+ url: str,
+ table: str,
+ column: str,
+ lowerBound: Union[int, str],
+ upperBound: Union[int, str],
+ numPartitions: int,
+ *,
+ properties: Optional[Dict[str, str]] = None,
+ ) -> "DataFrame":
+ ...
+
+ @overload
+ def jdbc(
+ self,
+ url: str,
+ table: str,
+ *,
+ predicates: List[str],
+ properties: Optional[Dict[str, str]] = None,
+ ) -> "DataFrame":
+ ...
+
+ def jdbc(
+ self,
+ url: str,
+ table: str,
+ column: Optional[str] = None,
+ lowerBound: Optional[Union[int, str]] = None,
+ upperBound: Optional[Union[int, str]] = None,
+ numPartitions: Optional[int] = None,
+ predicates: Optional[List[str]] = None,
+ properties: Optional[Dict[str, str]] = None,
+ ) -> "DataFrame":
+ if properties is None:
+ properties = dict()
+
+ self.format("jdbc")
+
+ if column is not None:
+ assert lowerBound is not None, "lowerBound can not be None when ``column`` is specified"
+ assert upperBound is not None, "upperBound can not be None when ``column`` is specified"
+ assert (
+ numPartitions is not None
+ ), "numPartitions can not be None when ``column`` is specified"
+ self.options(
+ partitionColumn=column,
+ lowerBound=lowerBound,
+ upperBound=upperBound,
+ numPartitions=numPartitions,
+ )
+ self.options(**properties)
+ self.options(url=url, dbtable=table)
+ return self.load()
+ else:
+ self.options(**properties)
+ self.options(url=url, dbtable=table)
+ if predicates is not None:
+ plan = DataSource(
+ format=self._format,
+ schema=self._schema,
+ options=self._options,
+ predicates=predicates,
+ )
+ return self._df(plan)
+ else:
+ return self.load()
+
+ jdbc.__doc__ = PySparkDataFrameReader.jdbc.__doc__
DataFrameReader.__doc__ = PySparkDataFrameReader.__doc__
@@ -603,8 +678,19 @@ class DataFrameWriter(OptionUtils):
orc.__doc__ = PySparkDataFrameWriter.orc.__doc__
- def jdbc(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("jdbc() not supported for DataFrameWriter")
+ def jdbc(
+ self,
+ url: str,
+ table: str,
+ mode: Optional[str] = None,
+ properties: Optional[Dict[str, str]] = None,
+ ) -> None:
+ if properties is None:
+ properties = dict()
+
+ self.format("jdbc").mode(mode).options(**properties).options(url=url,
dbtable=table).save()
+
+ jdbc.__doc__ = PySparkDataFrameWriter.jdbc.__doc__
class DataFrameWriterV2(OptionUtils):
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index cd6890a630b..9da3285d07e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2871,18 +2871,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
with self.assertRaises(NotImplementedError):
getattr(self.connect, f)()
- def test_unsupported_io_functions(self):
- # SPARK-41964: Disable unsupported functions.
- df = self.connect.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
-
- for f in ("jdbc",):
- with self.assertRaises(NotImplementedError):
- getattr(self.connect.read, f)()
-
- for f in ("jdbc",):
- with self.assertRaises(NotImplementedError):
- getattr(df.write, f)()
-
def test_sql_with_command(self):
# SPARK-42705: spark.sql should return values from the command.
self.assertEqual(
diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py
index 9b6692a9d21..6418983b06a 100644
--- a/python/pyspark/sql/tests/test_datasources.py
+++ b/python/pyspark/sql/tests/test_datasources.py
@@ -198,6 +198,30 @@ class DataSourcesTestsMixin:
url = f"jdbc:derby:{db}"
dbtable = "test_table"
+ try:
+ df = self.spark.range(10)
+ df.write.jdbc(url=f"{url};create=true", table=dbtable)
+
+ readback = self.spark.read.jdbc(url=url, table=dbtable)
+ self.assertEqual(sorted(df.collect()), sorted(readback.collect()))
+
+ additional_arguments = dict(column="id", lowerBound=3, upperBound=8, numPartitions=10)
+ readback = self.spark.read.jdbc(url=url, table=dbtable, **additional_arguments)
+ self.assertEqual(sorted(df.collect()), sorted(readback.collect()))
+
+ additional_arguments = dict(predicates=['"id" < 5'])
+ readback = self.spark.read.jdbc(url=url, table=dbtable, **additional_arguments)
+ self.assertEqual(sorted(df.filter("id < 5").collect()), sorted(readback.collect()))
+ finally:
+ # Clean up.
with self.assertRaisesRegex(Exception, f"Database '{db}' dropped."):
+ self.spark.read.jdbc(url=f"{url};drop=true", table=dbtable).collect()
+
+ def test_jdbc_format(self):
+ db = f"memory:{uuid.uuid4()}"
+ url = f"jdbc:derby:{db}"
+ dbtable = "test_table"
+
try:
df = self.spark.range(10)
df.write.format("jdbc").options(url=f"{url};create=true",
dbtable=dbtable).save()