This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4fcd5bfe003 [SPARK-45525][SQL][PYTHON] Support for Python data source write using DSv2
4fcd5bfe003 is described below
commit 4fcd5bfe003bb546ca888efaf1d39c15c9685673
Author: allisonwang-db <[email protected]>
AuthorDate: Fri Dec 22 09:28:47 2023 +0800
[SPARK-45525][SQL][PYTHON] Support for Python data source write using DSv2
### What changes were proposed in this pull request?
This PR adds initial support for Python data source write by implementing
the DSv2 `SupportsWrite` interface for `PythonTableProvider`.
Note that this PR only supports the `def write(self, iterator)` API; `commit` and `abort` will be supported in [SPARK-45914](https://issues.apache.org/jira/browse/SPARK-45914).
### Why are the changes needed?
To support Python data source APIs. For instance:
```python
class SimpleWriter(DataSourceWriter):
    def write(self, iterator: Iterator[Row]) -> WriterCommitMessage:
        for row in iterator:
            print(row)
        return WriterCommitMessage()

class SimpleDataSource(DataSource):
    def writer(self, schema, overwrite):
        return SimpleWriter()

# Register the Python data source
spark.dataSource.register(SimpleDataSource)

spark.range(10).write.format("SimpleDataSource").mode("append").save()
```
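Each write task calls `write` with an iterator over its own rows and returns a single `WriterCommitMessage`, which is pickled and sent back to the JVM. A slightly fuller sketch in the same spirit as the new tests, writing one JSON file per task (the option name and file layout are illustrative):
```python
import json

from pyspark import TaskContext
from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage


class PartitionedJsonWriter(DataSourceWriter):
    def __init__(self, options):
        self.options = options

    def write(self, iterator):
        # One output file per task, keyed on the task's partition id.
        partition_id = TaskContext.get().partitionId()
        path = self.options.get("path")
        with open(f"{path}/{partition_id}.json", "w") as file:
            for row in iterator:
                file.write(json.dumps(row.asDict()) + "\n")
        return WriterCommitMessage()


class PartitionedJsonDataSource(DataSource):
    def writer(self, schema, overwrite):
        return PartitionedJsonWriter(self.options)
```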
### Does this PR introduce _any_ user-facing change?
Yes, this PR supports writing data into a Python data source.
### How was this patch tested?
New unit tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43791 from allisonwang-db/spark-45525-data-source-write.
Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-classes.json | 6 +
docs/sql-error-conditions.md | 6 +
python/pyspark/errors/error_classes.py | 5 +
python/pyspark/sql/tests/test_python_datasource.py | 36 ++-
.../pyspark/sql/worker/write_into_data_source.py | 233 ++++++++++++++++++
.../spark/sql/errors/QueryExecutionErrors.scala | 6 +
.../python/UserDefinedPythonDataSource.scala | 269 +++++++++++++++++----
.../execution/python/PythonDataSourceSuite.scala | 95 ++++++++
8 files changed, 612 insertions(+), 44 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index df223f3298e..8970045d4ab 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -2513,6 +2513,12 @@
],
"sqlState" : "42601"
},
+ "INVALID_WRITER_COMMIT_MESSAGE" : {
+ "message" : [
+ "The data source writer has generated an invalid number of commit
messages. Expected exactly one writer commit message from each task, but
received <detail>."
+ ],
+ "sqlState" : "42KDE"
+ },
"INVALID_WRITE_DISTRIBUTION" : {
"message" : [
"The requested write distribution is invalid."
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index a1af6863913..0722cae5815 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -1398,6 +1398,12 @@ Rewrite the query to avoid window functions, aggregate functions, and generator
Cannot specify ORDER BY or a window frame for `<aggFunc>`.
+### INVALID_WRITER_COMMIT_MESSAGE
+
+[SQLSTATE: 42KDE](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+The data source writer has generated an invalid number of commit messages. Expected exactly one writer commit message from each task, but received `<detail>`.
+
### [INVALID_WRITE_DISTRIBUTION](sql-error-conditions-invalid-write-distribution-error-class.html)
[SQLSTATE: 42000](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index bb278481262..2200b73dffc 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -772,6 +772,11 @@ ERROR_CLASSES_JSON = """
"Expected <expected>, but got <actual>."
]
},
+ "PYTHON_DATA_SOURCE_WRITE_ERROR" : {
+ "message" : [
+ "Unable to write to the Python data source: <error>."
+ ]
+ },
"PYTHON_HASH_SEED_NOT_SET" : {
"message" : [
"Randomness of hash of string should be disabled via PYTHONHASHSEED."
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 74ef6a87458..b1bba584d85 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -15,11 +15,18 @@
# limitations under the License.
#
import os
+import tempfile
import unittest
from typing import Callable, Union
from pyspark.errors import PythonException
-from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
+from pyspark.sql.datasource import (
+ DataSource,
+ DataSourceReader,
+ InputPartition,
+ DataSourceWriter,
+ WriterCommitMessage,
+)
from pyspark.sql.types import Row, StructType
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -235,6 +242,17 @@ class BasePythonDataSourceTestsMixin:
data = json.loads(line)
yield data.get("name"), data.get("age")
+ class JsonDataSourceWriter(DataSourceWriter):
+ def __init__(self, options):
+ self.options = options
+
+ def write(self, iterator):
+ path = self.options.get("path")
+ with open(path, "w") as file:
+ for row in iterator:
+ file.write(json.dumps(row.asDict()) + "\n")
+ return WriterCommitMessage()
+
class JsonDataSource(DataSource):
@classmethod
def name(cls):
@@ -246,7 +264,11 @@ class BasePythonDataSourceTestsMixin:
def reader(self, schema) -> "DataSourceReader":
return JsonDataSourceReader(self.options)
+ def writer(self, schema, overwrite):
+ return JsonDataSourceWriter(self.options)
+
self.spark.dataSource.register(JsonDataSource)
+ # Test data source read.
path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
assertDataFrameEqual(
@@ -257,6 +279,18 @@ class BasePythonDataSourceTestsMixin:
self.spark.read.format("my-json").load(path2),
[Row(name="Jonathan", age=None)],
)
+ # Test data source write.
+ df = self.spark.read.json(path1)
+ with tempfile.TemporaryDirectory() as d:
+ path = os.path.join(d, "res.json")
+ df.write.format("my-json").mode("append").save(path)
+ with open(path, "r") as file:
+ text = file.read()
+ assert text == (
+ '{"age": null, "name": "Michael"}\n'
+ '{"age": 30, "name": "Andy"}\n'
+ '{"age": 19, "name": "Justin"}\n'
+ )
class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
new file mode 100644
index 00000000000..9c311dad033
--- /dev/null
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -0,0 +1,233 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import inspect
+import os
+import sys
+from typing import IO, Iterable, Iterator
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.sql.connect.conversion import ArrowTableToRowsConversion
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+ read_int,
+ write_int,
+ SpecialLengths,
+)
+from pyspark.sql import Row
+from pyspark.sql.datasource import DataSource, WriterCommitMessage
+from pyspark.sql.types import (
+ _parse_datatype_json_string,
+ StructType,
+ BinaryType,
+ _create_row,
+)
+from pyspark.util import handle_worker_exception
+from pyspark.worker_util import (
+ check_python_version,
+ read_command,
+ pickleSer,
+ send_accumulator_updates,
+ setup_broadcasts,
+ setup_memory_limits,
+ setup_spark_files,
+ utf8_deserializer,
+)
+
+
+def main(infile: IO, outfile: IO) -> None:
+ """
+ Main method for saving into a Python data source.
+
+ This process is invoked from the `SaveIntoPythonDataSourceRunner.runInPython` method
+ in the optimizer rule `PythonDataSourceWrites` in the JVM. It is responsible for creating
+ a `DataSource` object and a `DataSourceWriter` instance, and for sending the information
+ needed back to the JVM.
+
+ The JVM sends the following information to this process:
+ - a `DataSource` class representing the data source to be created.
+ - a provider name in string.
+ - a schema in json string.
+ - a dictionary of options in string.
+
+ This process first creates a `DataSource` instance and then a `DataSourceWriter`
+ instance, and sends back to the JVM a function, built from the writer instance,
+ that can be used in mapInPandas/mapInArrow.
+ """
+ try:
+ check_python_version(infile)
+
+ memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
+ setup_memory_limits(memory_limit_mb)
+
+ setup_spark_files(infile)
+ setup_broadcasts(infile)
+
+ _accumulatorRegistry.clear()
+
+ # Receive the data source class.
+ data_source_cls = read_command(pickleSer, infile)
+ if not (isinstance(data_source_cls, type) and issubclass(data_source_cls, DataSource)):
+ raise PySparkAssertionError(
+ error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": "a subclass of DataSource",
+ "actual": f"'{type(data_source_cls).__name__}'",
+ },
+ )
+
+ # Check the name method is a class method.
+ if not inspect.ismethod(data_source_cls.name):
+ raise PySparkTypeError(
+ error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": "'name()' method to be a classmethod",
+ "actual": f"'{type(data_source_cls.name).__name__}'",
+ },
+ )
+
+ # Receive the provider name.
+ provider = utf8_deserializer.loads(infile)
+
+ # Check if the provider name matches the data source's name.
+ if provider.lower() != data_source_cls.name().lower():
+ raise PySparkAssertionError(
+ error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": f"provider with name {data_source_cls.name()}",
+ "actual": f"'{provider}'",
+ },
+ )
+
+ # Receive the input schema
+ schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+ if not isinstance(schema, StructType):
+ raise PySparkAssertionError(
+ error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": "the schema to be a 'StructType'",
+ "actual": f"'{type(data_source_cls).__name__}'",
+ },
+ )
+
+ # Receive the return type
+ return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+ if not isinstance(return_type, StructType):
+ raise PySparkAssertionError(
+ error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": "a return type of type 'StructType'",
+ "actual": f"'{type(return_type).__name__}'",
+ },
+ )
+ assert len(return_type) == 1 and isinstance(return_type[0].dataType, BinaryType), (
+ "The output schema of Python data source write should contain only one column of type "
+ f"'BinaryType', but got '{return_type}'"
+ )
+ return_col_name = return_type[0].name
+
+ # Receive the options.
+ options = dict()
+ num_options = read_int(infile)
+ for _ in range(num_options):
+ key = utf8_deserializer.loads(infile)
+ value = utf8_deserializer.loads(infile)
+ options[key] = value
+
+ # Receive the save mode.
+ save_mode = utf8_deserializer.loads(infile)
+
+ # Instantiate a data source.
+ try:
+ data_source = data_source_cls(options=options)
+ except Exception as e:
+ raise PySparkRuntimeError(
+ error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
+ message_parameters={"type": "instance", "error": str(e)},
+ )
+
+ # Instantiate the data source writer.
+ try:
+ writer = data_source.writer(schema, save_mode)
+ except Exception as e:
+ raise PySparkRuntimeError(
+ error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
+ message_parameters={"type": "writer", "error": str(e)},
+ )
+
+ # Create a function that can be used in mapInArrow.
+ import pyarrow as pa
+
+ converters = [
+ ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields
+ ]
+ fields = schema.fieldNames()
+
+ def data_source_write_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
+ def batch_to_rows() -> Iterator[Row]:
+ for batch in iterator:
+ columns = [column.to_pylist() for column in batch.columns]
+ for row in range(0, batch.num_rows):
+ values = [
+ converters[col](columns[col][row]) for col in range(batch.num_columns)
+ ]
+ yield _create_row(fields=fields, values=values)
+
+ res = writer.write(batch_to_rows())
+
+ # Check the commit message has the right type.
+ if not isinstance(res, WriterCommitMessage):
+ raise PySparkRuntimeError(
+ error_class="PYTHON_DATA_SOURCE_WRITE_ERROR",
+ message_parameters={
+ "error": f"return type of the `write` method must be "
+ f"an instance of WriterCommitMessage, but got
{type(res)}"
+ },
+ )
+
+ # Serialize the commit message and return it.
+ pickled = pickleSer.dumps(res)
+
+ # Return the commit message.
+ messages = pa.array([pickled])
+ yield pa.record_batch([messages], names=[return_col_name])
+
+ # Return the pickled write UDF.
+ command = (data_source_write_func, return_type)
+ pickleSer._write_with_length(command, outfile)
+
+ except BaseException as e:
+ handle_worker_exception(e, outfile)
+ sys.exit(-1)
+
+ send_accumulator_updates(outfile)
+
+ # check end of stream
+ if read_int(infile) == SpecialLengths.END_OF_STREAM:
+ write_int(SpecialLengths.END_OF_STREAM, outfile)
+ else:
+ # write a different value to tell JVM to not reuse this worker
+ write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+ sys.exit(-1)
+
+
+if __name__ == "__main__":
+ # Read information about how to connect back to the JVM from the environment.
+ java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+ auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
+ (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ main(sock_file, sock_file)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 113f995968a..b0eaf84fe6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2771,4 +2771,10 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
"parameter" -> toSQLId("charset"),
"charset" -> charset))
}
+
+ def invalidWriterCommitMessageError(details: String): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "INVALID_WRITER_COMMIT_MESSAGE",
+ messageParameters = Map("details" -> details))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index d31b3135d65..00974a7e297 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -24,19 +24,20 @@ import scala.jdk.CollectionConverters._
import net.razorvine.pickle.Pickler
-import org.apache.spark.JobArtifactSet
+import org.apache.spark.{JobArtifactSet, SparkException}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.PythonUDF
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
-import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ
+import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, Write, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BinaryType, DataType, StructType}
@@ -72,11 +73,11 @@ class PythonTableProvider extends TableProvider {
partitioning: Array[Transform],
properties: java.util.Map[String, String]): Table = {
val outputSchema = schema
- new Table with SupportsRead {
+ new Table with SupportsRead with SupportsWrite {
override def name(): String = shortName
override def capabilities(): java.util.Set[TableCapability] =
java.util.EnumSet.of(
- BATCH_READ)
+ BATCH_READ, BATCH_WRITE)
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new ScanBuilder with Batch with Scan {
@@ -103,6 +104,7 @@ class PythonTableProvider extends TableProvider {
new PythonPartitionReaderFactory(
source, readerFunc, outputSchema, jobArtifactUUID)
}
+
override def description: String = "(Python)"
override def supportedCustomMetrics(): Array[CustomMetric] =
@@ -111,6 +113,38 @@ class PythonTableProvider extends TableProvider {
}
override def schema(): StructType = outputSchema
+
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+ new WriteBuilder {
+ override def build(): Write = new Write {
+
+ override def toBatch: BatchWrite = new BatchWrite {
+
+ override def createBatchWriterFactory(
+ physicalInfo: PhysicalWriteInfo): DataWriterFactory = {
+
+ val writeInfo = source.createWriteInfoInPython(
+ shortName,
+ info.schema(),
+ info.options(),
+ SaveMode.Append)
+ PythonBatchWriterFactory(source, writeInfo.func, info.schema(), jobArtifactUUID)
+ }
+
+ // TODO(SPARK-45914): Support commit protocol
+ override def commit(messages: Array[WriterCommitMessage]): Unit = {}
+
+ // TODO(SPARK-45914): Support commit protocol
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {}
+ }
+
+ override def description: String = "(Python)"
+
+ override def supportedCustomMetrics(): Array[CustomMetric] =
+ source.createPythonMetrics()
+ }
+ }
+ }
}
}
@@ -124,27 +158,26 @@ class PythonPartitionReaderFactory(
pickledReadFunc: Array[Byte],
outputSchema: StructType,
jobArtifactUUID: Option[String])
- extends PartitionReaderFactory {
+ extends PartitionReaderFactory with PythonDataSourceSQLMetrics {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
new PartitionReader[InternalRow] {
- // Dummy SQLMetrics. The result is manually reported via DSv2 interface
- // via passing the value to `CustomTaskMetric`. Note that `pythonOtherMetricsDesc`
- // is not used when it is reported. It is to reuse existing Python runner.
- // See also `UserDefinedPythonDataSource.createPythonMetrics`.
- private[this] val metrics: Map[String, SQLMetric] = {
- PythonSQLMetrics.pythonSizeMetricsDesc.keys
- .map(_ -> new SQLMetric("size", -1)).toMap ++
- PythonSQLMetrics.pythonOtherMetricsDesc.keys
- .map(_ -> new SQLMetric("sum", -1)).toMap
- }
- private val outputIter = source.createPartitionReadIteratorInPython(
- partition.asInstanceOf[PythonInputPartition],
- pickledReadFunc,
- outputSchema,
- metrics,
- jobArtifactUUID)
+ private[this] val metrics: Map[String, SQLMetric] = pythonMetrics
+
+ private val outputIter = {
+ val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+ pickledReadFunc,
+ "read_from_data_source",
+ UserDefinedPythonDataSource.readInputSchema,
+ outputSchema,
+ metrics,
+ jobArtifactUUID)
+
+ val part = partition.asInstanceOf[PythonInputPartition]
+ evaluatorFactory.createEvaluator().eval(
+ part.index, Iterator.single(InternalRow(part.pickedPartition)))
+ }
override def next(): Boolean = outputIter.hasNext
@@ -159,9 +192,75 @@ class PythonPartitionReaderFactory(
}
}
+case class PythonWriterCommitMessage(pickledMessage: Array[Byte]) extends WriterCommitMessage
+
+private case class PythonBatchWriterFactory(
+ source: UserDefinedPythonDataSource,
+ pickledWriteFunc: Array[Byte],
+ inputSchema: StructType,
+ jobArtifactUUID: Option[String]) extends DataWriterFactory with PythonDataSourceSQLMetrics {
+ override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
+ new DataWriter[InternalRow] {
+
+ private[this] val metrics: Map[String, SQLMetric] = pythonMetrics
+
+ private var commitMessage: PythonWriterCommitMessage = _
+
+ override def writeAll(records: java.util.Iterator[InternalRow]): Unit = {
+ val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+ pickledWriteFunc,
+ "write_to_data_source",
+ inputSchema,
+ UserDefinedPythonDataSource.writeOutputSchema,
+ metrics,
+ jobArtifactUUID)
+ val outputIter = evaluatorFactory.createEvaluator().eval(partitionId, records.asScala)
+ outputIter.foreach { row =>
+ if (commitMessage == null) {
+ commitMessage = PythonWriterCommitMessage(row.getBinary(0))
+ } else {
+ throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "more than one")
+ }
+ }
+ if (commitMessage == null) {
+ throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "zero")
+ }
+ }
+
+ override def write(record: InternalRow): Unit =
+ SparkException.internalError("write method for Python data source
should not be called.")
+
+ override def commit(): WriterCommitMessage = {
+ commitMessage.asInstanceOf[WriterCommitMessage]
+ }
+
+ override def abort(): Unit = {}
+
+ override def close(): Unit = {}
+
+ override def currentMetricsValues(): Array[CustomTaskMetric] = {
+ source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
+ }
+ }
+ }
+}
+
+trait PythonDataSourceSQLMetrics {
+ // Dummy SQLMetrics. The result is manually reported via DSv2 interface
+ // via passing the value to `CustomTaskMetric`. Note that `pythonOtherMetricsDesc`
+ // is not used when it is reported. It is to reuse existing Python runner.
+ // See also `UserDefinedPythonDataSource.createPythonMetrics`.
+ protected lazy val pythonMetrics: Map[String, SQLMetric] = {
+ PythonSQLMetrics.pythonSizeMetricsDesc.keys
+ .map(_ -> new SQLMetric("size", -1)).toMap ++
+ PythonSQLMetrics.pythonOtherMetricsDesc.keys
+ .map(_ -> new SQLMetric("sum", -1)).toMap
+ }
+}
+
class PythonCustomMetric(
- override val name: String,
- override val description: String) extends CustomMetric {
+ override val name: String,
+ override val description: String) extends CustomMetric {
// To allow the aggregation can be called. See `SQLAppStatusListener.aggregateMetrics`
def this() = this(null, null)
@@ -182,8 +281,6 @@ class PythonCustomTaskMetric(
*/
case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
- private val inputSchema: StructType = new StructType().add("partition", BinaryType)
-
/**
* (Driver-side) Run Python process, and get the pickled Python Data Source
* instance and its schema.
@@ -207,26 +304,44 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
pythonResult: PythonDataSourceCreationResult,
outputSchema: StructType): PythonDataSourceReadInfo = {
new UserDefinedPythonDataSourceReadRunner(
- createPythonFunction(
- pythonResult.dataSource), inputSchema, outputSchema).runInPython()
+ createPythonFunction(pythonResult.dataSource),
+ UserDefinedPythonDataSource.readInputSchema,
+ outputSchema).runInPython()
+ }
+
+ /**
+ * (Driver-side) Run Python process and get pickled write function.
+ */
+ def createWriteInfoInPython(
+ provider: String,
+ inputSchema: StructType,
+ options: CaseInsensitiveStringMap,
+ mode: SaveMode): PythonDataSourceWriteInfo = {
+ new UserDefinedPythonDataSourceWriteRunner(
+ dataSourceCls,
+ provider,
+ inputSchema,
+ options.asCaseSensitiveMap().asScala.toMap,
+ mode).runInPython()
}
/**
- * (Executor-side) Create an iterator that reads the input partitions.
+ * (Executor-side) Create an evaluator factory that executes the Python function.
*/
- def createPartitionReadIteratorInPython(
- partition: PythonInputPartition,
- pickledReadFunc: Array[Byte],
+ def createMapInBatchEvaluatorFactory(
+ pickledFunc: Array[Byte],
+ funcName: String,
+ inputSchema: StructType,
outputSchema: StructType,
metrics: Map[String, SQLMetric],
- jobArtifactUUID: Option[String]): Iterator[InternalRow] = {
- val readerFunc = createPythonFunction(pickledReadFunc)
+ jobArtifactUUID: Option[String]): MapInBatchEvaluatorFactory = {
+ val pythonFunc = createPythonFunction(pickledFunc)
val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
val pythonUDF = PythonUDF(
- name = "read_from_data_source",
- func = readerFunc,
+ name = funcName,
+ func = pythonFunc,
dataType = outputSchema,
children = toAttributes(inputSchema),
evalType = pythonEvalType,
@@ -235,7 +350,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
val conf = SQLConf.get
val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
- val evaluatorFactory = new MapInBatchEvaluatorFactory(
+ new MapInBatchEvaluatorFactory(
toAttributes(outputSchema),
Seq(ChainedPythonFunctions(Seq(pythonUDF.func))),
inputSchema,
@@ -246,10 +361,6 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
pythonRunnerConf,
metrics,
jobArtifactUUID)
-
- val part = partition
- evaluatorFactory.createEvaluator().eval(
- part.index, Iterator.single(InternalRow(part.pickedPartition)))
}
def createPythonMetrics(): Array[CustomMetric] = {
@@ -275,6 +386,18 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
}
}
+object UserDefinedPythonDataSource {
+ /**
+ * The schema of the input to the Python data source read function.
+ */
+ val readInputSchema: StructType = new StructType().add("partition", BinaryType)
+
+ /**
+ * The schema of the output to the Python data source write function.
+ */
+ val writeOutputSchema: StructType = new StructType().add("message", BinaryType)
+}
+
/**
* Used to store the result of creating a Python data source in the Python process.
*/
@@ -402,3 +525,63 @@ class UserDefinedPythonDataSourceReadRunner(
partitions = pickledPartitions.toSeq)
}
}
+
+/**
+ * Hold the results of running [[UserDefinedPythonDataSourceWriteRunner]].
+ */
+case class PythonDataSourceWriteInfo(func: Array[Byte])
+
+/**
+ * A runner that creates a Python data source writer instance and returns a Python function
+ * to be used to write data into the data source.
+ */
+class UserDefinedPythonDataSourceWriteRunner(
+ dataSourceCls: PythonFunction,
+ provider: String,
+ inputSchema: StructType,
+ options: Map[String, String],
+ mode: SaveMode) extends PythonPlannerRunner[PythonDataSourceWriteInfo](dataSourceCls) {
+
+ override val workerModule: String = "pyspark.sql.worker.write_into_data_source"
+
+ override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
+ // Send the Python data source class.
+ PythonWorkerUtils.writePythonFunction(dataSourceCls, dataOut)
+
+ // Send the provider name
+ PythonWorkerUtils.writeUTF(provider, dataOut)
+
+ // Send the input schema
+ PythonWorkerUtils.writeUTF(inputSchema.json, dataOut)
+
+ // Send the return type
+ PythonWorkerUtils.writeUTF(UserDefinedPythonDataSource.writeOutputSchema.json, dataOut)
+
+ // Send the options
+ dataOut.writeInt(options.size)
+ options.iterator.foreach { case (key, value) =>
+ PythonWorkerUtils.writeUTF(key, dataOut)
+ PythonWorkerUtils.writeUTF(value, dataOut)
+ }
+
+ // Send the mode
+ PythonWorkerUtils.writeUTF(mode.toString, dataOut)
+ }
+
+ override protected def receiveFromPython(
+ dataIn: DataInputStream): PythonDataSourceWriteInfo = {
+
+ // Receive the pickled UDF or an exception raised in the Python worker.
+ val length = dataIn.readInt()
+ if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+ val msg = PythonWorkerUtils.readUTF(dataIn)
+ throw QueryCompilationErrors.failToPlanDataSourceError(
+ action = "plan", tpe = "write", msg = msg)
+ }
+
+ // Receive the pickled write function.
+ val writeUdf: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
+
+ PythonDataSourceWriteInfo(func = writeUdf)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index e8a46449ac2..b04569ae554 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.python
+import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation}
import org.apache.spark.sql.test.SharedSparkSession
@@ -25,6 +26,8 @@ import org.apache.spark.sql.types.StructType
class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
import IntegratedUDFTestUtils._
+ setupTestData()
+
private def dataSourceName = "SimpleDataSource"
private def simpleDataSourceReaderScript: String =
"""
@@ -453,4 +456,96 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
assert(metrics.contains(pythonDataReceived.id))
assert(metrics(pythonDataReceived.id).asInstanceOf[String].endsWith("B"))
}
+
+ test("simple data source write") {
+ assume(shouldTestPandasUDFs)
+ val dataSourceScript =
+ s"""
+ |import json
+ |from pyspark import TaskContext
+ |from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage
+ |
+ |class SimpleDataSourceWriter(DataSourceWriter):
+ | def __init__(self, options):
+ | self.options = options
+ |
+ | def write(self, iterator):
+ | context = TaskContext.get()
+ | partition_id = context.partitionId()
+ | path = self.options.get("path")
+ | assert path is not None
+ | output_path = f"{path}/{partition_id}.json"
+ | cnt = 0
+ | with open(output_path, "w") as file:
+ | for row in iterator:
+ | file.write(json.dumps(row.asDict()) + "\\n")
+ | cnt += 1
+ | return WriterCommitMessage()
+ |
+ |class SimpleDataSource(DataSource):
+ | def writer(self, schema, saveMode):
+ | return SimpleDataSourceWriter(self.options)
+ |""".stripMargin
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ Seq(
+ "SELECT * FROM range(0, 5, 1, 3)",
+ "SELECT * FROM testData LIMIT 5",
+ "SELECT * FROM testData3",
+ "SELECT * FROM arrayData"
+ ).foreach { query =>
+ withTempDir { dir =>
+ val df = sql(query)
+ val path = dir.getAbsolutePath
+ df.write.format(dataSourceName).mode("append").save(path)
+ val df2 = spark.read.json(path)
+ checkAnswer(df, df2)
+ }
+ }
+ }
+
+ test("data source write - error cases") {
+ assume(shouldTestPandasUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource, DataSourceWriter
+ |
+ |class SimpleDataSourceWriter(DataSourceWriter):
+ | def write(self, iterator):
+ | num_rows = 0
+ | for row in iterator:
+ | num_rows += 1
+ | if num_rows > 2:
+ | raise Exception("something is wrong")
+ |
+ |class SimpleDataSource(DataSource):
+ | def writer(self, schema, saveMode):
+ | return SimpleDataSourceWriter()
+ |""".stripMargin
+ spark.dataSource.registerPython(dataSourceName,
+ createUserDefinedPythonDataSource(dataSourceName, dataSourceScript))
+
+ withClue("user error") {
+ val error = intercept[SparkException] {
+ spark.range(10).write.format(dataSourceName).mode("append").save()
+ }
+ assert(error.getMessage.contains("something is wrong"))
+ }
+
+ withClue("no commit message") {
+ val error = intercept[SparkException] {
+ spark.range(1).write.format(dataSourceName).mode("append").save()
+ }
+ assert(error.getMessage.contains("PYTHON_DATA_SOURCE_WRITE_ERROR"))
+ }
+
+ withClue("without mode") {
+ val error = intercept[AnalysisException] {
+ spark.range(1).write.format(dataSourceName).save()
+ }
+ // TODO: improve this error message.
+ assert(error.getMessage.contains("TableProvider implementation
SimpleDataSource " +
+ "cannot be written with ErrorIfExists mode, please use Append or
Overwrite modes instead."))
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]