This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 31abad98972b [SPARK-50471][PYTHON] Support Arrow-based Python Data Source Writer
31abad98972b is described below

commit 31abad98972bd4e421c34115e45a89f0d5c28896
Author: Allison Wang <[email protected]>
AuthorDate: Tue Dec 3 08:37:01 2024 +0900

    [SPARK-50471][PYTHON] Support Arrow-based Python Data Source Writer
    
    ### What changes were proposed in this pull request?
    
    This PR introduces a new Python Data Source Writer that leverages PyArrow’s RecordBatch format. Unlike the current DataSourceWriter, which operates on iterators of Spark Rows, this new writer takes in an iterator of PyArrow `RecordBatch` as input.
    
    ### Why are the changes needed?
    
    Make Python data source write more performant when interfacing with systems or libraries that natively support Arrow.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. This PR adds a new user-facing class `DataSourceArrowWriter`.
    
    ### How was this patch tested?
    
    New unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #49028 from allisonwang-db/spark-50471-arrow-writer.
    
    Authored-by: Allison Wang <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/datasource.py                   | 39 ++++++++++++++++++++++
 python/pyspark/sql/tests/test_python_datasource.py | 34 ++++++++++++++++++-
 .../pyspark/sql/worker/write_into_data_source.py   |  6 +++-
 3 files changed, 77 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index a51c96a9d178..06b853ce5b4e 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -32,6 +32,7 @@ __all__ = [
     "DataSourceStreamReader",
     "SimpleDataSourceStreamReader",
     "DataSourceWriter",
+    "DataSourceArrowWriter",
     "DataSourceStreamWriter",
     "DataSourceRegistration",
     "InputPartition",
@@ -666,6 +667,44 @@ class DataSourceWriter(ABC):
         ...
 
 
+class DataSourceArrowWriter(DataSourceWriter):
+    """
+    A base class for data source writers that process data using PyArrow’s `RecordBatch`.
+
+    Unlike :class:`DataSourceWriter`, which works with an iterator of Spark Rows, this class
+    is optimized for using the Arrow format when writing data. It can offer better performance
+    when interfacing with systems or libraries that natively support Arrow.
+
+    .. versionadded: 4.0.0
+    """
+
+    @abstractmethod
+    def write(self, iterator: Iterator["RecordBatch"]) -> "WriterCommitMessage":
+        """
+        Writes an iterator of PyArrow `RecordBatch` objects to the sink.
+
+        This method is called once on each executor to write data to the data source.
+        It accepts an iterator of PyArrow `RecordBatch`\\s and returns a single row
+        representing a commit message, or None if there is no commit message.
+
+        The driver collects commit messages, if any, from all executors and passes them
+        to the :class:`DataSourceWriter.commit` method if all tasks run successfully. If any
+        task fails, the :class:`DataSourceWriter.abort` method will be called with the
+        collected commit messages.
+
+        Parameters
+        ----------
+        iterator : iterator of :class:`RecordBatch`\\s
+            An iterator of PyArrow `RecordBatch` objects representing the input data.
+
+        Returns
+        -------
+        :class:`WriterCommitMessage`
+            a serializable commit message
+        """
+        ...
+
+
 class DataSourceStreamWriter(ABC):
     """
     A base class for data stream writers. Data stream writers are responsible for writing
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 140c7680b181..a636b852a1e5 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -25,6 +25,7 @@ from pyspark.sql.datasource import (
     DataSourceReader,
     InputPartition,
     DataSourceWriter,
+    DataSourceArrowWriter,
     WriterCommitMessage,
     CaseInsensitiveDict,
 )
@@ -277,7 +278,7 @@ class BasePythonDataSourceTestsMixin:
                 from pyspark import TaskContext
 
                 context = TaskContext.get()
-                output_path = os.path.join(self.path, f"{context.partitionId}.json")
+                output_path = os.path.join(self.path, f"{context.partitionId()}.json")
                 count = 0
                 with open(output_path, "w") as file:
                     for row in iterator:
@@ -436,6 +437,37 @@ class BasePythonDataSourceTestsMixin:
         ):
             self.spark.read.format("arrowbatch").schema("key int, dummy string").load().show()
 
+    def test_arrow_batch_sink(self):
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "arrow_sink"
+
+            def writer(self, schema, overwrite):
+                return TestArrowWriter(self.options["path"])
+
+        class TestArrowWriter(DataSourceArrowWriter):
+            def __init__(self, path):
+                self.path = path
+
+            def write(self, iterator):
+                from pyspark import TaskContext
+
+                context = TaskContext.get()
+                output_path = os.path.join(self.path, f"{context.partitionId()}.json")
+                with open(output_path, "w") as file:
+                    for batch in iterator:
+                        df = batch.to_pandas()
+                        df.to_json(file, orient="records", lines=True)
+                return WriterCommitMessage()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.range(3)
+        with tempfile.TemporaryDirectory(prefix="test_arrow_batch_sink") as d:
+            df.write.format("arrow_sink").mode("append").save(d)
+            df2 = self.spark.read.format("json").load(d)
+            assertDataFrameEqual(df2, df)
+
     def test_data_source_type_mismatch(self):
         class TestDataSource(DataSource):
             @classmethod
diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index a114a3facc46..91a1f4d3b1b3 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -32,6 +32,7 @@ from pyspark.sql import Row
 from pyspark.sql.datasource import (
     DataSource,
     DataSourceWriter,
+    DataSourceArrowWriter,
     WriterCommitMessage,
     CaseInsensitiveDict,
 )
@@ -194,7 +195,10 @@ def main(infile: IO, outfile: IO) -> None:
                         ]
                         yield _create_row(fields=fields, values=values)
 
-            res = writer.write(batch_to_rows())
+            if isinstance(writer, DataSourceArrowWriter):
+                res = writer.write(iterator)
+            else:
+                res = writer.write(batch_to_rows())
 
             # Check the commit message has the right type.
             if not isinstance(res, WriterCommitMessage):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to