This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 31abad98972b [SPARK-50471][PYTHON] Support Arrow-based Python Data Source Writer
31abad98972b is described below
commit 31abad98972bd4e421c34115e45a89f0d5c28896
Author: Allison Wang <[email protected]>
AuthorDate: Tue Dec 3 08:37:01 2024 +0900
[SPARK-50471][PYTHON] Support Arrow-based Python Data Source Writer
### What changes were proposed in this pull request?
This PR introduces a new Python Data Source Writer that leverages PyArrow’s `RecordBatch` format. Unlike the current `DataSourceWriter`, which operates on iterators of Spark Rows, the new writer consumes an iterator of PyArrow `RecordBatch` objects.
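To make the new contract concrete, below is a minimal sketch of an Arrow-based writer. The class and format names (`MyDataSource`, `MyArrowWriter`, `my_arrow_sink`) are hypothetical; the API shape follows the `DataSourceArrowWriter` class added in this PR:

```python
# Minimal sketch (hypothetical names): `write` receives pyarrow.RecordBatch
# objects rather than Spark Rows, so each batch can be handed to an
# Arrow-native sink without per-row conversion.
from pyspark.sql.datasource import (
    DataSource,
    DataSourceArrowWriter,
    WriterCommitMessage,
)


class MyArrowWriter(DataSourceArrowWriter):
    def write(self, iterator):
        for batch in iterator:  # each item is a pyarrow.RecordBatch
            ...  # e.g. pass the batch to a library that accepts Arrow data
        return WriterCommitMessage()


class MyDataSource(DataSource):
    @classmethod
    def name(cls):
        return "my_arrow_sink"

    def writer(self, schema, overwrite):
        return MyArrowWriter()
```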
### Why are the changes needed?
Make Python data source writes more performant when interfacing with systems or libraries that natively support Arrow.
### Does this PR introduce _any_ user-facing change?
Yes. This PR adds a new user-facing class `DataSourceArrowWriter`.
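Registering and invoking the new writer works the same way as for the existing row-based writer. A usage sketch following the pattern in the new unit test (names and output path hypothetical):

```python
# Register the hypothetical data source above and write a DataFrame through it.
spark.dataSource.register(MyDataSource)
spark.range(3).write.format("my_arrow_sink").mode("append").save("/tmp/arrow_sink_out")
```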
### How was this patch tested?
New unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49028 from allisonwang-db/spark-50471-arrow-writer.
Authored-by: Allison Wang <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/datasource.py | 39 ++++++++++++++++++++++
python/pyspark/sql/tests/test_python_datasource.py | 34 ++++++++++++++++++-
.../pyspark/sql/worker/write_into_data_source.py | 6 +++-
3 files changed, 77 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index a51c96a9d178..06b853ce5b4e 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -32,6 +32,7 @@ __all__ = [
    "DataSourceStreamReader",
    "SimpleDataSourceStreamReader",
    "DataSourceWriter",
+    "DataSourceArrowWriter",
    "DataSourceStreamWriter",
    "DataSourceRegistration",
    "InputPartition",
@@ -666,6 +667,44 @@ class DataSourceWriter(ABC):
        ...


+class DataSourceArrowWriter(DataSourceWriter):
+    """
+    A base class for data source writers that process data using PyArrow’s `RecordBatch`.
+
+    Unlike :class:`DataSourceWriter`, which works with an iterator of Spark Rows, this class
+    is optimized for using the Arrow format when writing data. It can offer better performance
+    when interfacing with systems or libraries that natively support Arrow.
+
+    .. versionadded:: 4.0.0
+    """
+
+    @abstractmethod
+    def write(self, iterator: Iterator["RecordBatch"]) -> "WriterCommitMessage":
+        """
+        Writes an iterator of PyArrow `RecordBatch` objects to the sink.
+
+        This method is called once on each executor to write data to the data source.
+        It accepts an iterator of PyArrow `RecordBatch`\\s and returns a single
+        commit message, or None if there is no commit message.
+
+        The driver collects commit messages, if any, from all executors and passes them
+        to the :class:`DataSourceWriter.commit` method if all tasks run successfully. If any
+        task fails, the :class:`DataSourceWriter.abort` method will be called with the
+        collected commit messages.
+
+        Parameters
+        ----------
+        iterator : iterator of :class:`RecordBatch`\\s
+            An iterator of PyArrow `RecordBatch` objects representing the input data.
+
+        Returns
+        -------
+        :class:`WriterCommitMessage`
+            a serializable commit message
+        """
+        ...
+
+
class DataSourceStreamWriter(ABC):
    """
    A base class for data stream writers. Data stream writers are responsible for writing
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 140c7680b181..a636b852a1e5 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -25,6 +25,7 @@ from pyspark.sql.datasource import (
    DataSourceReader,
    InputPartition,
    DataSourceWriter,
+    DataSourceArrowWriter,
    WriterCommitMessage,
    CaseInsensitiveDict,
)
@@ -277,7 +278,7 @@ class BasePythonDataSourceTestsMixin:
                from pyspark import TaskContext
                context = TaskContext.get()
-                output_path = os.path.join(self.path, f"{context.partitionId}.json")
+                output_path = os.path.join(self.path, f"{context.partitionId()}.json")
                count = 0
                with open(output_path, "w") as file:
                    for row in iterator:
@@ -436,6 +437,37 @@ class BasePythonDataSourceTestsMixin:
):
self.spark.read.format("arrowbatch").schema("key int, dummy
string").load().show()
+    def test_arrow_batch_sink(self):
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "arrow_sink"
+
+            def writer(self, schema, overwrite):
+                return TestArrowWriter(self.options["path"])
+
+        class TestArrowWriter(DataSourceArrowWriter):
+            def __init__(self, path):
+                self.path = path
+
+            def write(self, iterator):
+                from pyspark import TaskContext
+
+                context = TaskContext.get()
+                output_path = os.path.join(self.path, f"{context.partitionId()}.json")
+                with open(output_path, "w") as file:
+                    for batch in iterator:
+                        df = batch.to_pandas()
+                        df.to_json(file, orient="records", lines=True)
+                return WriterCommitMessage()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.range(3)
+        with tempfile.TemporaryDirectory(prefix="test_arrow_batch_sink") as d:
+            df.write.format("arrow_sink").mode("append").save(d)
+            df2 = self.spark.read.format("json").load(d)
+            assertDataFrameEqual(df2, df)
+
def test_data_source_type_mismatch(self):
class TestDataSource(DataSource):
@classmethod
diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index a114a3facc46..91a1f4d3b1b3 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -32,6 +32,7 @@ from pyspark.sql import Row
from pyspark.sql.datasource import (
    DataSource,
    DataSourceWriter,
+    DataSourceArrowWriter,
    WriterCommitMessage,
    CaseInsensitiveDict,
)
@@ -194,7 +195,10 @@ def main(infile: IO, outfile: IO) -> None:
                ]
                yield _create_row(fields=fields, values=values)

-        res = writer.write(batch_to_rows())
+        if isinstance(writer, DataSourceArrowWriter):
+            res = writer.write(iterator)
+        else:
+            res = writer.write(batch_to_rows())

        # Check the commit message has the right type.
        if not isinstance(res, WriterCommitMessage):