This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4fcd5bfe003 [SPARK-45525][SQL][PYTHON] Support for Python data source 
write using DSv2
4fcd5bfe003 is described below

commit 4fcd5bfe003bb546ca888efaf1d39c15c9685673
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Fri Dec 22 09:28:47 2023 +0800

    [SPARK-45525][SQL][PYTHON] Support for Python data source write using DSv2
    
    ### What changes were proposed in this pull request?
    
    This PR adds initial support for Python data source write by implementing 
the DSv2 `SupportsWrite` interface for `PythonTableProvider`.
    
    Note this PR only supports the `def write(self, iterator)` API. `commit` 
and `abort` will be supported in 
[SPARK-45914](https://issues.apache.org/jira/browse/SPARK-45914).
    
    ### Why are the changes needed?
    
    To support Python data source APIs. For instance:
    
    ```python
    class SimpleWriter(DataSourceWriter):
        def write(self, iterator: Iterator[Row]) -> WriterCommitMessage:
            for row in iterator:
                print(row)
            return WriterCommitMessage()
    
    class SimpleDataSource(DataSource):
        def writer(self, schema, overwrite):
            return SimpleWriter()
    
    # Register the Python data source
    spark.dataSource.register(SimpleDataSource)
    
    spark.range(10).write.format("SimpleDataSource").mode("append").save()
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this PR supports writing data into a Python data source.
    
    ### How was this patch tested?
    
    New unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43791 from allisonwang-db/spark-45525-data-source-write.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../src/main/resources/error/error-classes.json    |   6 +
 docs/sql-error-conditions.md                       |   6 +
 python/pyspark/errors/error_classes.py             |   5 +
 python/pyspark/sql/tests/test_python_datasource.py |  36 ++-
 .../pyspark/sql/worker/write_into_data_source.py   | 233 ++++++++++++++++++
 .../spark/sql/errors/QueryExecutionErrors.scala    |   6 +
 .../python/UserDefinedPythonDataSource.scala       | 269 +++++++++++++++++----
 .../execution/python/PythonDataSourceSuite.scala   |  95 ++++++++
 8 files changed, 612 insertions(+), 44 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json 
b/common/utils/src/main/resources/error/error-classes.json
index df223f3298e..8970045d4ab 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -2513,6 +2513,12 @@
     ],
     "sqlState" : "42601"
   },
+  "INVALID_WRITER_COMMIT_MESSAGE" : {
+    "message" : [
+      "The data source writer has generated an invalid number of commit 
messages. Expected exactly one writer commit message from each task, but 
received <detail>."
+    ],
+    "sqlState" : "42KDE"
+  },
   "INVALID_WRITE_DISTRIBUTION" : {
     "message" : [
       "The requested write distribution is invalid."
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index a1af6863913..0722cae5815 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -1398,6 +1398,12 @@ Rewrite the query to avoid window functions, aggregate 
functions, and generator
 
 Cannot specify ORDER BY or a window frame for `<aggFunc>`.
 
+### INVALID_WRITER_COMMIT_MESSAGE
+
+[SQLSTATE: 
42KDE](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+The data source writer has generated an invalid number of commit messages. 
Expected exactly one writer commit message from each task, but received 
`<detail>`.
+
 ### 
[INVALID_WRITE_DISTRIBUTION](sql-error-conditions-invalid-write-distribution-error-class.html)
 
 [SQLSTATE: 
42000](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/python/pyspark/errors/error_classes.py 
b/python/pyspark/errors/error_classes.py
index bb278481262..2200b73dffc 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -772,6 +772,11 @@ ERROR_CLASSES_JSON = """
       "Expected <expected>, but got <actual>."
     ]
   },
+  "PYTHON_DATA_SOURCE_WRITE_ERROR" : {
+    "message" : [
+      "Unable to write to the Python data source: <error>."
+    ]
+  },
   "PYTHON_HASH_SEED_NOT_SET" : {
     "message" : [
       "Randomness of hash of string should be disabled via PYTHONHASHSEED."
diff --git a/python/pyspark/sql/tests/test_python_datasource.py 
b/python/pyspark/sql/tests/test_python_datasource.py
index 74ef6a87458..b1bba584d85 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -15,11 +15,18 @@
 # limitations under the License.
 #
 import os
+import tempfile
 import unittest
 from typing import Callable, Union
 
 from pyspark.errors import PythonException
-from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceReader,
+    InputPartition,
+    DataSourceWriter,
+    WriterCommitMessage,
+)
 from pyspark.sql.types import Row, StructType
 from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -235,6 +242,17 @@ class BasePythonDataSourceTestsMixin:
                             data = json.loads(line)
                             yield data.get("name"), data.get("age")
 
+        class JsonDataSourceWriter(DataSourceWriter):
+            def __init__(self, options):
+                self.options = options
+
+            def write(self, iterator):
+                path = self.options.get("path")
+                with open(path, "w") as file:
+                    for row in iterator:
+                        file.write(json.dumps(row.asDict()) + "\n")
+                return WriterCommitMessage()
+
         class JsonDataSource(DataSource):
             @classmethod
             def name(cls):
@@ -246,7 +264,11 @@ class BasePythonDataSourceTestsMixin:
             def reader(self, schema) -> "DataSourceReader":
                 return JsonDataSourceReader(self.options)
 
+            def writer(self, schema, overwrite):
+                return JsonDataSourceWriter(self.options)
+
         self.spark.dataSource.register(JsonDataSource)
+        # Test data source read.
         path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
         path2 = os.path.join(SPARK_HOME, 
"python/test_support/sql/people1.json")
         assertDataFrameEqual(
@@ -257,6 +279,18 @@ class BasePythonDataSourceTestsMixin:
             self.spark.read.format("my-json").load(path2),
             [Row(name="Jonathan", age=None)],
         )
+        # Test data source write.
+        df = self.spark.read.json(path1)
+        with tempfile.TemporaryDirectory() as d:
+            path = os.path.join(d, "res.json")
+            df.write.format("my-json").mode("append").save(path)
+            with open(path, "r") as file:
+                text = file.read()
+            assert text == (
+                '{"age": null, "name": "Michael"}\n'
+                '{"age": 30, "name": "Andy"}\n'
+                '{"age": 19, "name": "Justin"}\n'
+            )
 
 
 class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/worker/write_into_data_source.py 
b/python/pyspark/sql/worker/write_into_data_source.py
new file mode 100644
index 00000000000..9c311dad033
--- /dev/null
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -0,0 +1,233 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import inspect
+import os
+import sys
+from typing import IO, Iterable, Iterator
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.sql.connect.conversion import ArrowTableToRowsConversion
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, 
PySparkTypeError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_int,
+    write_int,
+    SpecialLengths,
+)
+from pyspark.sql import Row
+from pyspark.sql.datasource import DataSource, WriterCommitMessage
+from pyspark.sql.types import (
+    _parse_datatype_json_string,
+    StructType,
+    BinaryType,
+    _create_row,
+)
+from pyspark.util import handle_worker_exception
+from pyspark.worker_util import (
+    check_python_version,
+    read_command,
+    pickleSer,
+    send_accumulator_updates,
+    setup_broadcasts,
+    setup_memory_limits,
+    setup_spark_files,
+    utf8_deserializer,
+)
+
+
+def main(infile: IO, outfile: IO) -> None:
+    """
+    Main method for saving into a Python data source.
+
+    This process is invoked from the JVM by
+    `UserDefinedPythonDataSourceWriteRunner.runInPython`. It is responsible
+    for creating a `DataSource` object and a `DataSourceWriter` instance, and
+    for sending the pickled write function back to the JVM.
+
+    The JVM sends the following information to this process:
+    - a `DataSource` class representing the data source to be created.
+    - a provider name in string.
+    - the input schema and the return type, each in json string.
+    - the options, as a count followed by key/value string pairs.
+    - the save mode in string.
+
+    The function sent back wraps `DataSourceWriter.write` so that it can be
+    used in mapInArrow: it yields one batch holding the pickled commit message.
+    """
+    try:
+        check_python_version(infile)
+
+        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
+        setup_memory_limits(memory_limit_mb)
+
+        setup_spark_files(infile)
+        setup_broadcasts(infile)
+
+        _accumulatorRegistry.clear()
+
+        # Receive the data source class.
+        data_source_cls = read_command(pickleSer, infile)
+        if not (isinstance(data_source_cls, type) and 
issubclass(data_source_cls, DataSource)):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "a subclass of DataSource",
+                    "actual": f"'{type(data_source_cls).__name__}'",
+                },
+            )
+
+        # Check the name method is a class method.
+        if not inspect.ismethod(data_source_cls.name):
+            raise PySparkTypeError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "'name()' method to be a classmethod",
+                    "actual": f"'{type(data_source_cls.name).__name__}'",
+                },
+            )
+
+        # Receive the provider name.
+        provider = utf8_deserializer.loads(infile)
+
+        # Check if the provider name matches the data source's name.
+        if provider.lower() != data_source_cls.name().lower():
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": f"provider with name {data_source_cls.name()}",
+                    "actual": f"'{provider}'",
+                },
+            )
+
+        # Receive the input schema
+        schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+        if not isinstance(schema, StructType):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "the schema to be a 'StructType'",
+                    # Report the type of the offending schema, not of the class.
+                    "actual": f"'{type(schema).__name__}'",
+                },
+            )
+
+        # Receive the return type
+        return_type = 
_parse_datatype_json_string(utf8_deserializer.loads(infile))
+        if not isinstance(return_type, StructType):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "a return type of type 'StructType'",
+                    "actual": f"'{type(return_type).__name__}'",
+                },
+            )
+        assert len(return_type) == 1 and isinstance(return_type[0].dataType, 
BinaryType), (
+            "The output schema of Python data source write should contain only 
one column of type "
+            f"'BinaryType', but got '{return_type}'"
+        )
+        return_col_name = return_type[0].name
+
+        # Receive the options.
+        options = dict()
+        num_options = read_int(infile)
+        for _ in range(num_options):
+            key = utf8_deserializer.loads(infile)
+            value = utf8_deserializer.loads(infile)
+            options[key] = value
+
+        # Receive the save mode.
+        save_mode = utf8_deserializer.loads(infile)
+
+        # Instantiate a data source.
+        try:
+            data_source = data_source_cls(options=options)
+        except Exception as e:
+            raise PySparkRuntimeError(
+                error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
+                message_parameters={"type": "instance", "error": str(e)},
+            )
+
+        # Instantiate the data source writer.
+        try:
+            writer = data_source.writer(schema, save_mode)
+        except Exception as e:
+            raise PySparkRuntimeError(
+                error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
+                message_parameters={"type": "writer", "error": str(e)},
+            )
+
+        # Create a function that can be used in mapInArrow.
+        import pyarrow as pa
+
+        converters = [
+            ArrowTableToRowsConversion._create_converter(f.dataType) for f in 
schema.fields
+        ]
+        fields = schema.fieldNames()
+
+        def data_source_write_func(iterator: Iterable[pa.RecordBatch]) -> 
Iterable[pa.RecordBatch]:
+            def batch_to_rows() -> Iterator[Row]:
+                for batch in iterator:
+                    columns = [column.to_pylist() for column in batch.columns]
+                    for row in range(0, batch.num_rows):
+                        values = [
+                            converters[col](columns[col][row]) for col in 
range(batch.num_columns)
+                        ]
+                        yield _create_row(fields=fields, values=values)
+
+            # Rows are produced lazily from the incoming Arrow batches.
+            res = writer.write(batch_to_rows())
+
+            # Check the commit message has the right type.
+            if not isinstance(res, WriterCommitMessage):
+                raise PySparkRuntimeError(
+                    error_class="PYTHON_DATA_SOURCE_WRITE_ERROR",
+                    message_parameters={
+                        "error": f"return type of the `write` method must be "
+                        f"an instance of WriterCommitMessage, but got 
{type(res)}"
+                    },
+                )
+
+            # Serialize the commit message and return it.
+            pickled = pickleSer.dumps(res)
+
+            # Return the commit message.
+            messages = pa.array([pickled])
+            yield pa.record_batch([messages], names=[return_col_name])
+
+        # Return the pickled write UDF.
+        command = (data_source_write_func, return_type)
+        pickleSer._write_with_length(command, outfile)
+
+    except BaseException as e:
+        handle_worker_exception(e, outfile)
+        sys.exit(-1)
+
+    send_accumulator_updates(outfile)
+
+    # check end of stream
+    if read_int(infile) == SpecialLengths.END_OF_STREAM:
+        write_int(SpecialLengths.END_OF_STREAM, outfile)
+    else:
+        # write a different value to tell JVM to not reuse this worker
+        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    # Read information about how to connect back to the JVM from the 
environment.
+    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    main(sock_file, sock_file)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 113f995968a..b0eaf84fe6a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2771,4 +2771,10 @@ private[sql] object QueryExecutionErrors extends 
QueryErrorsBase with ExecutionE
         "parameter" -> toSQLId("charset"),
         "charset" -> charset))
   }
+
+  /**
+   * Error raised when a data source write task does not produce exactly one
+   * commit message. `details` describes what was received instead (the
+   * callers in this change pass "zero" or "more than one").
+   */
+  def invalidWriterCommitMessageError(details: String): Throwable = {
+    new SparkRuntimeException(
+      errorClass = "INVALID_WRITER_COMMIT_MESSAGE",
+      // The key must match the `<detail>` placeholder of the message template.
+      messageParameters = Map("detail" -> details))
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index d31b3135d65..00974a7e297 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -24,19 +24,20 @@ import scala.jdk.CollectionConverters._
 
 import net.razorvine.pickle.Pickler
 
-import org.apache.spark.JobArtifactSet
+import org.apache.spark.{JobArtifactSet, SparkException}
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, 
PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.PythonUDF
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, 
TableCapability, TableProvider}
-import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ
+import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, 
Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, 
BATCH_WRITE}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, 
PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, 
DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, Write, WriteBuilder, 
WriterCommitMessage}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BinaryType, DataType, StructType}
@@ -72,11 +73,11 @@ class PythonTableProvider extends TableProvider {
       partitioning: Array[Transform],
       properties: java.util.Map[String, String]): Table = {
     val outputSchema = schema
-    new Table with SupportsRead {
+    new Table with SupportsRead with SupportsWrite {
       override def name(): String = shortName
 
       override def capabilities(): java.util.Set[TableCapability] = 
java.util.EnumSet.of(
-        BATCH_READ)
+        BATCH_READ, BATCH_WRITE)
 
       override def newScanBuilder(options: CaseInsensitiveStringMap): 
ScanBuilder = {
         new ScanBuilder with Batch with Scan {
@@ -103,6 +104,7 @@ class PythonTableProvider extends TableProvider {
             new PythonPartitionReaderFactory(
               source, readerFunc, outputSchema, jobArtifactUUID)
           }
+
           override def description: String = "(Python)"
 
           override def supportedCustomMetrics(): Array[CustomMetric] =
@@ -111,6 +113,38 @@ class PythonTableProvider extends TableProvider {
       }
 
       override def schema(): StructType = outputSchema
+
+      // Builds the batch write path for this table: the write function is
+      // planned in Python when the writer factory is created, and the factory
+      // then carries the pickled function.
+      override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+        new WriteBuilder {
+          override def build(): Write = new Write {
+
+            override def toBatch: BatchWrite = new BatchWrite {
+
+              override def createBatchWriterFactory(
+                physicalInfo: PhysicalWriteInfo): DataWriterFactory = {
+
+                // NOTE(review): the save mode is hard-coded to Append here, so
+                // the Python side always receives "Append" -- confirm intended.
+                val writeInfo = source.createWriteInfoInPython(
+                  shortName,
+                  info.schema(),
+                  info.options(),
+                  SaveMode.Append)
+                PythonBatchWriterFactory(source, writeInfo.func, 
info.schema(), jobArtifactUUID)
+              }
+
+              // TODO(SPARK-45914): Support commit protocol
+              override def commit(messages: Array[WriterCommitMessage]): Unit 
= {}
+
+              // TODO(SPARK-45914): Support commit protocol
+              override def abort(messages: Array[WriterCommitMessage]): Unit = 
{}
+            }
+
+            override def description: String = "(Python)"
+
+            override def supportedCustomMetrics(): Array[CustomMetric] =
+              source.createPythonMetrics()
+          }
+        }
+      }
     }
   }
 
@@ -124,27 +158,26 @@ class PythonPartitionReaderFactory(
     pickledReadFunc: Array[Byte],
     outputSchema: StructType,
     jobArtifactUUID: Option[String])
-  extends PartitionReaderFactory {
+  extends PartitionReaderFactory with PythonDataSourceSQLMetrics {
 
   override def createReader(partition: InputPartition): 
PartitionReader[InternalRow] = {
     new PartitionReader[InternalRow] {
-      // Dummy SQLMetrics. The result is manually reported via DSv2 interface
-      // via passing the value to `CustomTaskMetric`. Note that 
`pythonOtherMetricsDesc`
-      // is not used when it is reported. It is to reuse existing Python 
runner.
-      // See also `UserDefinedPythonDataSource.createPythonMetrics`.
-      private[this] val metrics: Map[String, SQLMetric] = {
-        PythonSQLMetrics.pythonSizeMetricsDesc.keys
-          .map(_ -> new SQLMetric("size", -1)).toMap ++
-        PythonSQLMetrics.pythonOtherMetricsDesc.keys
-          .map(_ -> new SQLMetric("sum", -1)).toMap
-      }
 
-      private val outputIter = source.createPartitionReadIteratorInPython(
-        partition.asInstanceOf[PythonInputPartition],
-        pickledReadFunc,
-        outputSchema,
-        metrics,
-        jobArtifactUUID)
+      private[this] val metrics: Map[String, SQLMetric] = pythonMetrics
+
+      private val outputIter = {
+        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+          pickledReadFunc,
+          "read_from_data_source",
+          UserDefinedPythonDataSource.readInputSchema,
+          outputSchema,
+          metrics,
+          jobArtifactUUID)
+
+        val part = partition.asInstanceOf[PythonInputPartition]
+        evaluatorFactory.createEvaluator().eval(
+          part.index, Iterator.single(InternalRow(part.pickedPartition)))
+      }
 
       override def next(): Boolean = outputIter.hasNext
 
@@ -159,9 +192,75 @@ class PythonPartitionReaderFactory(
   }
 }
 
+case class PythonWriterCommitMessage(pickledMessage: Array[Byte]) extends 
WriterCommitMessage
+
+/**
+ * A [[DataWriterFactory]] whose writers run the pickled Python write function
+ * over all records passed to `writeAll` and keep the single commit message
+ * the function returns.
+ */
+private case class PythonBatchWriterFactory(
+    source: UserDefinedPythonDataSource,
+    pickledWriteFunc: Array[Byte],
+    inputSchema: StructType,
+    jobArtifactUUID: Option[String]) extends DataWriterFactory with 
PythonDataSourceSQLMetrics {
+  override def createWriter(partitionId: Int, taskId: Long): 
DataWriter[InternalRow] = {
+    new DataWriter[InternalRow] {
+
+      private[this] val metrics: Map[String, SQLMetric] = pythonMetrics
+
+      // Set by `writeAll`; stays null until the Python function has returned.
+      private var commitMessage: PythonWriterCommitMessage = _
+
+      // Runs the Python write function over all records. The function must
+      // produce exactly one output row, whose first column is the pickled
+      // commit message; zero or several rows are reported as an error.
+      override def writeAll(records: java.util.Iterator[InternalRow]): Unit = {
+        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+          pickledWriteFunc,
+          "write_to_data_source",
+          inputSchema,
+          UserDefinedPythonDataSource.writeOutputSchema,
+          metrics,
+          jobArtifactUUID)
+        val outputIter = evaluatorFactory.createEvaluator().eval(partitionId, 
records.asScala)
+        outputIter.foreach { row =>
+          if (commitMessage == null) {
+            commitMessage = PythonWriterCommitMessage(row.getBinary(0))
+          } else {
+            throw QueryExecutionErrors.invalidWriterCommitMessageError(details 
= "more than one")
+          }
+        }
+        if (commitMessage == null) {
+          throw QueryExecutionErrors.invalidWriterCommitMessageError(details = 
"zero")
+        }
+      }
+
+      // Records are only consumed through `writeAll`. `internalError` merely
+      // builds the exception, so it has to be thrown explicitly.
+      override def write(record: InternalRow): Unit =
+        throw SparkException.internalError("write method for Python data source 
should not be called.")
+
+      override def commit(): WriterCommitMessage = {
+        commitMessage
+      }
+
+      override def abort(): Unit = {}
+
+      override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> 
v.value })
+      }
+    }
+  }
+}
+
+trait PythonDataSourceSQLMetrics {
+  // Dummy SQLMetrics. The result is manually reported via DSv2 interface
+  // via passing the value to `CustomTaskMetric`. Note that 
`pythonOtherMetricsDesc`
+  // is not used when it is reported. It is to reuse existing Python runner.
+  // See also `UserDefinedPythonDataSource.createPythonMetrics`.
+  protected lazy val pythonMetrics: Map[String, SQLMetric] = {
+    PythonSQLMetrics.pythonSizeMetricsDesc.keys
+      .map(_ -> new SQLMetric("size", -1)).toMap ++
+      PythonSQLMetrics.pythonOtherMetricsDesc.keys
+        .map(_ -> new SQLMetric("sum", -1)).toMap
+  }
+}
+
 class PythonCustomMetric(
-    override val name: String,
-    override val description: String) extends CustomMetric {
+  override val name: String,
+  override val description: String) extends CustomMetric {
   // To allow the aggregation can be called. See 
`SQLAppStatusListener.aggregateMetrics`
   def this() = this(null, null)
 
@@ -182,8 +281,6 @@ class PythonCustomTaskMetric(
  */
 case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
 
-  private val inputSchema: StructType = new StructType().add("partition", 
BinaryType)
-
   /**
    * (Driver-side) Run Python process, and get the pickled Python Data Source
    * instance and its schema.
@@ -207,26 +304,44 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
       pythonResult: PythonDataSourceCreationResult,
       outputSchema: StructType): PythonDataSourceReadInfo = {
     new UserDefinedPythonDataSourceReadRunner(
-      createPythonFunction(
-        pythonResult.dataSource), inputSchema, outputSchema).runInPython()
+      createPythonFunction(pythonResult.dataSource),
+      UserDefinedPythonDataSource.readInputSchema,
+      outputSchema).runInPython()
+  }
+
+  /**
+   * (Driver-side) Run Python process and get pickled write function.
+   */
+  def createWriteInfoInPython(
+      provider: String,
+      inputSchema: StructType,
+      options: CaseInsensitiveStringMap,
+      mode: SaveMode): PythonDataSourceWriteInfo = {
+    new UserDefinedPythonDataSourceWriteRunner(
+      dataSourceCls,
+      provider,
+      inputSchema,
+      options.asCaseSensitiveMap().asScala.toMap,
+      mode).runInPython()
   }
 
   /**
-   * (Executor-side) Create an iterator that reads the input partitions.
+   * (Executor-side) Create an evaluator factory that executes the Python function.
    */
-  def createPartitionReadIteratorInPython(
-      partition: PythonInputPartition,
-      pickledReadFunc: Array[Byte],
+  def createMapInBatchEvaluatorFactory(
+      pickledFunc: Array[Byte],
+      funcName: String,
+      inputSchema: StructType,
       outputSchema: StructType,
       metrics: Map[String, SQLMetric],
-      jobArtifactUUID: Option[String]): Iterator[InternalRow] = {
-    val readerFunc = createPythonFunction(pickledReadFunc)
+      jobArtifactUUID: Option[String]): MapInBatchEvaluatorFactory = {
+    val pythonFunc = createPythonFunction(pickledFunc)
 
     val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
 
     val pythonUDF = PythonUDF(
-      name = "read_from_data_source",
-      func = readerFunc,
+      name = funcName,
+      func = pythonFunc,
       dataType = outputSchema,
       children = toAttributes(inputSchema),
       evalType = pythonEvalType,
@@ -235,7 +350,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
     val conf = SQLConf.get
 
     val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-    val evaluatorFactory = new MapInBatchEvaluatorFactory(
+    new MapInBatchEvaluatorFactory(
       toAttributes(outputSchema),
       Seq(ChainedPythonFunctions(Seq(pythonUDF.func))),
       inputSchema,
@@ -246,10 +361,6 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
       pythonRunnerConf,
       metrics,
       jobArtifactUUID)
-
-    val part = partition
-    evaluatorFactory.createEvaluator().eval(
-      part.index, Iterator.single(InternalRow(part.pickedPartition)))
   }
 
   def createPythonMetrics(): Array[CustomMetric] = {
@@ -275,6 +386,18 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
   }
 }
 
+object UserDefinedPythonDataSource {
+  /**
+   * The schema of the input to the Python data source read function:
+   * a single binary column named `partition`.
+   */
+  val readInputSchema: StructType = new StructType().add("partition", 
BinaryType)
+
+  /**
+   * The schema of the output of the Python data source write function:
+   * a single binary column named `message`.
+   */
+  val writeOutputSchema: StructType = new StructType().add("message", 
BinaryType)
+}
+
 /**
  * Used to store the result of creating a Python data source in the Python 
process.
  */
@@ -402,3 +525,63 @@ class UserDefinedPythonDataSourceReadRunner(
       partitions = pickledPartitions.toSeq)
   }
 }
+
+/**
+ * Hold the results of running [[UserDefinedPythonDataSourceWriteRunner]].
+ */
+case class PythonDataSourceWriteInfo(func: Array[Byte])
+
+/**
+ * A runner that creates a Python data source writer instance and returns a 
Python function
+ * to be used to write data into the data source.
+ */
+class UserDefinedPythonDataSourceWriteRunner(
+    dataSourceCls: PythonFunction,
+    provider: String,
+    inputSchema: StructType,
+    options: Map[String, String],
+    mode: SaveMode) extends 
PythonPlannerRunner[PythonDataSourceWriteInfo](dataSourceCls) {
+
+  override val workerModule: String = 
"pyspark.sql.worker.write_into_data_source"
+
+  override protected def writeToPython(dataOut: DataOutputStream, pickler: 
Pickler): Unit = {
+    // Send the Python data source class.
+    PythonWorkerUtils.writePythonFunction(dataSourceCls, dataOut)
+
+    // Send the provider name
+    PythonWorkerUtils.writeUTF(provider, dataOut)
+
+    // Send the input schema
+    PythonWorkerUtils.writeUTF(inputSchema.json, dataOut)
+
+    // Send the return type
+    
PythonWorkerUtils.writeUTF(UserDefinedPythonDataSource.writeOutputSchema.json, 
dataOut)
+
+    // Send the options
+    dataOut.writeInt(options.size)
+    options.iterator.foreach { case (key, value) =>
+      PythonWorkerUtils.writeUTF(key, dataOut)
+      PythonWorkerUtils.writeUTF(value, dataOut)
+    }
+
+    // Send the mode
+    PythonWorkerUtils.writeUTF(mode.toString, dataOut)
+  }
+
+  override protected def receiveFromPython(
+      dataIn: DataInputStream): PythonDataSourceWriteInfo = {
+
+    // Receive the pickled UDF or an exception raised in the Python worker.
+    val length = dataIn.readInt()
+    if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.failToPlanDataSourceError(
+        action = "plan", tpe = "write", msg = msg)
+    }
+
+    // Receive the pickled write function.
+    val writeUdf: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
+
+    PythonDataSourceWriteInfo(func = writeUdf)
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index e8a46449ac2..b04569ae554 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.python
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, 
QueryTest, Row}
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, 
DataSourceV2ScanRelation}
 import org.apache.spark.sql.test.SharedSparkSession
@@ -25,6 +26,8 @@ import org.apache.spark.sql.types.StructType
 class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   import IntegratedUDFTestUtils._
 
+  setupTestData()
+
   private def dataSourceName = "SimpleDataSource"
   private def simpleDataSourceReaderScript: String =
     """
@@ -453,4 +456,96 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
     assert(metrics.contains(pythonDataReceived.id))
     assert(metrics(pythonDataReceived.id).asInstanceOf[String].endsWith("B"))
   }
+
+  test("simple data source write") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |import json
+         |from pyspark import TaskContext
+         |from pyspark.sql.datasource import DataSource, DataSourceWriter, 
WriterCommitMessage
+         |
+         |class SimpleDataSourceWriter(DataSourceWriter):
+         |    def __init__(self, options):
+         |        self.options = options
+         |
+         |    def write(self, iterator):
+         |        context = TaskContext.get()
+         |        partition_id = context.partitionId()
+         |        path = self.options.get("path")
+         |        assert path is not None
+         |        output_path = f"{path}/{partition_id}.json"
+         |        cnt = 0
+         |        with open(output_path, "w") as file:
+         |            for row in iterator:
+         |                file.write(json.dumps(row.asDict()) + "\\n")
+         |                cnt += 1
+         |        return WriterCommitMessage()
+         |
+         |class SimpleDataSource(DataSource):
+         |    def writer(self, schema, saveMode):
+         |        return SimpleDataSourceWriter(self.options)
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    Seq(
+      "SELECT * FROM range(0, 5, 1, 3)",
+      "SELECT * FROM testData LIMIT 5",
+      "SELECT * FROM testData3",
+      "SELECT * FROM arrayData"
+    ).foreach { query =>
+      withTempDir { dir =>
+        val df = sql(query)
+        val path = dir.getAbsolutePath
+        df.write.format(dataSourceName).mode("append").save(path)
+        val df2 = spark.read.json(path)
+        checkAnswer(df, df2)
+      }
+    }
+  }
+
+  test("data source write - error cases") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceWriter
+         |
+         |class SimpleDataSourceWriter(DataSourceWriter):
+         |    def write(self, iterator):
+         |        num_rows = 0
+         |        for row in iterator:
+         |            num_rows += 1
+         |            if num_rows > 2:
+         |                raise Exception("something is wrong")
+         |
+         |class SimpleDataSource(DataSource):
+         |    def writer(self, schema, saveMode):
+         |        return SimpleDataSourceWriter()
+         |""".stripMargin
+    spark.dataSource.registerPython(dataSourceName,
+      createUserDefinedPythonDataSource(dataSourceName, dataSourceScript))
+
+    withClue("user error") {
+      val error = intercept[SparkException] {
+        spark.range(10).write.format(dataSourceName).mode("append").save()
+      }
+      assert(error.getMessage.contains("something is wrong"))
+    }
+
+    withClue("no commit message") {
+      val error = intercept[SparkException] {
+        spark.range(1).write.format(dataSourceName).mode("append").save()
+      }
+      assert(error.getMessage.contains("PYTHON_DATA_SOURCE_WRITE_ERROR"))
+    }
+
+    withClue("without mode") {
+      val error = intercept[AnalysisException] {
+        spark.range(1).write.format(dataSourceName).save()
+      }
+      // TODO: improve this error message.
+      assert(error.getMessage.contains("TableProvider implementation 
SimpleDataSource " +
+        "cannot be written with ErrorIfExists mode, please use Append or 
Overwrite modes instead."))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to