This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new db68958fac6 [SPARK-42011][CONNECT][PYTHON] Implement DataFrameReader.csv
db68958fac6 is described below

commit db68958fac692113ece8d647f0e0c37b1f0e312b
Author: Sandeep Singh <[email protected]>
AuthorDate: Sun Jan 15 20:08:08 2023 +0900

    [SPARK-42011][CONNECT][PYTHON] Implement DataFrameReader.csv
    
    ### What changes were proposed in this pull request?
    This PR implements the `DataFrameReader.csv` method in Spark Connect.
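
    For example, once this change is in, reading CSV through Spark Connect
    mirrors the regular PySpark API. A minimal usage sketch (the remote URL
    and the file path are placeholders, not part of this change):

        from pyspark.sql import SparkSession

        # Connect to a Spark Connect server instead of a local JVM.
        spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

        # Same signature and options as the regular DataFrameReader.csv.
        df = spark.read.csv("/tmp/people.csv", header=True, inferSchema=True)
        df.show()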
    
    ### Why are the changes needed?
    For API feature parity.
    
    ### Does this PR introduce _any_ user-facing change?
    This PR adds a user-facing API, but Spark Connect has not been released yet.
    
    ### How was this patch tested?
    A unit test was added.
    
    Closes #39559 from techaddict/SPARK-42011.
    
    Authored-by: Sandeep Singh <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/readwriter.py           | 78 +++++++++++++++++++++-
 python/pyspark/sql/readwriter.py                   |  3 +
 .../sql/tests/connect/test_connect_basic.py        | 12 +++-
 3 files changed, 90 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index c26dd828485..8e8f4476799 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -237,8 +237,82 @@ class DataFrameReader(OptionUtils):
 
     text.__doc__ = PySparkDataFrameReader.text.__doc__
 
-    def csv(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("csv() is not implemented.")
+    def csv(
+        self,
+        path: PathOrPaths,
+        schema: Optional[Union[StructType, str]] = None,
+        sep: Optional[str] = None,
+        encoding: Optional[str] = None,
+        quote: Optional[str] = None,
+        escape: Optional[str] = None,
+        comment: Optional[str] = None,
+        header: Optional[Union[bool, str]] = None,
+        inferSchema: Optional[Union[bool, str]] = None,
+        ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None,
+        ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None,
+        nullValue: Optional[str] = None,
+        nanValue: Optional[str] = None,
+        positiveInf: Optional[str] = None,
+        negativeInf: Optional[str] = None,
+        dateFormat: Optional[str] = None,
+        timestampFormat: Optional[str] = None,
+        maxColumns: Optional[Union[int, str]] = None,
+        maxCharsPerColumn: Optional[Union[int, str]] = None,
+        maxMalformedLogPerPartition: Optional[Union[int, str]] = None,
+        mode: Optional[str] = None,
+        columnNameOfCorruptRecord: Optional[str] = None,
+        multiLine: Optional[Union[bool, str]] = None,
+        charToEscapeQuoteEscaping: Optional[str] = None,
+        samplingRatio: Optional[Union[float, str]] = None,
+        enforceSchema: Optional[Union[bool, str]] = None,
+        emptyValue: Optional[str] = None,
+        locale: Optional[str] = None,
+        lineSep: Optional[str] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+        modifiedBefore: Optional[Union[bool, str]] = None,
+        modifiedAfter: Optional[Union[bool, str]] = None,
+        unescapedQuoteHandling: Optional[str] = None,
+    ) -> "DataFrame":
+        self._set_opts(
+            sep=sep,
+            encoding=encoding,
+            quote=quote,
+            escape=escape,
+            comment=comment,
+            header=header,
+            inferSchema=inferSchema,
+            ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
+            ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace,
+            nullValue=nullValue,
+            nanValue=nanValue,
+            positiveInf=positiveInf,
+            negativeInf=negativeInf,
+            dateFormat=dateFormat,
+            timestampFormat=timestampFormat,
+            maxColumns=maxColumns,
+            maxCharsPerColumn=maxCharsPerColumn,
+            maxMalformedLogPerPartition=maxMalformedLogPerPartition,
+            mode=mode,
+            columnNameOfCorruptRecord=columnNameOfCorruptRecord,
+            multiLine=multiLine,
+            charToEscapeQuoteEscaping=charToEscapeQuoteEscaping,
+            samplingRatio=samplingRatio,
+            enforceSchema=enforceSchema,
+            emptyValue=emptyValue,
+            locale=locale,
+            lineSep=lineSep,
+            pathGlobFilter=pathGlobFilter,
+            recursiveFileLookup=recursiveFileLookup,
+            modifiedBefore=modifiedBefore,
+            modifiedAfter=modifiedAfter,
+            unescapedQuoteHandling=unescapedQuoteHandling,
+        )
+        if isinstance(path, str):
+            path = [path]
+        return self.load(path=path, format="csv", schema=schema)
+
+    csv.__doc__ = PySparkDataFrameReader.csv.__doc__
 
     def orc(
         self,
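
The new method follows the same delegation pattern as the other readers in this
file: `_set_opts` records only the options the caller actually supplied, and
`load` issues the read with format="csv". A minimal, self-contained sketch of
that pattern (an illustration only; `ReaderSketch` is a toy stand-in, not the
real OptionUtils or plan-building code):

    class ReaderSketch:
        """Toy reader illustrating the _set_opts/load delegation above."""

        def __init__(self):
            self._options = {}

        def option(self, key, value):
            self._options[key] = value
            return self

        def _set_opts(self, **options):
            # Forward only options that were explicitly supplied (non-None),
            # so server-side defaults stay in effect for everything else.
            for key, value in options.items():
                if value is not None:
                    self.option(key, value)

        def load(self, path, format, schema=None):
            # The real reader builds a Spark Connect relation here; we echo.
            return {"path": path, "format": format, "schema": schema,
                    "options": dict(self._options)}

        def csv(self, path, schema=None, sep=None, header=None):
            self._set_opts(sep=sep, header=header)
            if isinstance(path, str):
                path = [path]
            return self.load(path=path, format="csv", schema=schema)

    print(ReaderSketch().csv("/tmp/data.csv", header=True))
    # {'path': ['/tmp/data.csv'], 'format': 'csv', 'schema': None,
    #  'options': {'header': True}}
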
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index aa27c559a0d..8b083ae9054 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -648,6 +648,9 @@ class DataFrameReader(OptionUtils):
 
         .. versionadded:: 2.0.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         path : str or list
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index e0e3cc6d1e0..8d28e713694 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -256,6 +256,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             # Read the text file as a DataFrame.
             self.assert_eq(self.connect.read.text(d).toPandas(), self.spark.read.text(d).toPandas())
 
+    def test_csv(self):
+        # SPARK-42011: Implement DataFrameReader.csv
+        with tempfile.TemporaryDirectory() as d:
+            # Write a DataFrame into a CSV file
+            self.spark.createDataFrame(
+                [{"name": "Sandeep Singh"}, {"name": "Hyukjin Kwon"}]
+            ).write.mode("overwrite").format("csv").save(d)
+            # Read the CSV file back as a DataFrame.
+            self.assert_eq(self.connect.read.csv(d).toPandas(), self.spark.read.csv(d).toPandas())
+
     def test_multi_paths(self):
         # SPARK-42041: DataFrameReader should support list of paths
 
@@ -2557,7 +2567,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         # DataFrameWriterV2 is also not implemented yet
         df = self.connect.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
 
-        for f in ("csv", "jdbc"):
+        for f in ("jdbc",):
             with self.assertRaises(NotImplementedError):
                 getattr(self.connect.read, f)()
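
With the change applied, the Connect reader accepts the same CSV options as the
regular PySpark reader. A small round-trip sketch in the spirit of the new test
(assumes `spark` is a Spark Connect session as above, and that client and
server share a filesystem, as in the local test setup):

    import tempfile

    with tempfile.TemporaryDirectory() as d:
        # Write two rows out as CSV files.
        spark.createDataFrame([("Sandeep",), ("Hyukjin",)], ["name"]).write.mode(
            "overwrite"
        ).format("csv").save(d)
        # Read them back; options such as inferSchema are forwarded via _set_opts.
        spark.read.csv(d, inferSchema=True).show()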
 

