This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5ac39181fe87 [SPARK-47273][SS][PYTHON] implement Python data stream writer interface 5ac39181fe87 is described below commit 5ac39181fe87aba4eab66ff2590bbc16349c0bab Author: Chaoqin Li <chaoqin...@databricks.com> AuthorDate: Wed Mar 27 12:51:36 2024 +0900 [SPARK-47273][SS][PYTHON] implement Python data stream writer interface ### What changes were proposed in this pull request? Reuse PythonPartitionWriter to implement the serialization and execution of the write callback on executors. Implement a Python worker process that runs the Python streaming data sink committer and communicates with the JVM through a socket on the Spark driver. For each Python streaming data sink instance, a long-lived Python worker process is created. Inside that process, the Python write committer receives abort or commit calls and sends the result back through the socket. ### Why are the changes needed? To support developing Spark streaming sinks in Python, we need to implement a Python stream writer interface. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit and integration tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45305 from chaoqin-li1123/python_stream_writer. 
Authored-by: Chaoqin Li <chaoqin...@databricks.com> Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> --- python/pyspark/sql/datasource.py | 94 ++++++++ .../sql/worker/python_streaming_sink_runner.py | 140 +++++++++++ .../pyspark/sql/worker/write_into_data_source.py | 10 +- .../python/PythonStreamingSinkCommitRunner.scala | 133 +++++++++++ .../v2/python/PythonStreamingWrite.scala | 84 +++++++ .../datasources/v2/python/PythonTable.scala | 4 +- .../datasources/v2/python/PythonWrite.scala | 12 +- .../v2/python/UserDefinedPythonDataSource.scala | 11 +- .../spark/sql/streaming/DataStreamWriter.scala | 5 + .../python/PythonStreamingDataSourceSuite.scala | 261 ++++++++++++++++++++- 10 files changed, 744 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py index 803765e83093..c08b5b7af77f 100644 --- a/python/pyspark/sql/datasource.py +++ b/python/pyspark/sql/datasource.py @@ -160,6 +160,29 @@ class DataSource(ABC): message_parameters={"feature": "writer"}, ) + def streamWriter(self, schema: StructType, overwrite: bool) -> "DataSourceStreamWriter": + """ + Returns a :class:`DataSourceStreamWriter` instance for writing data into a streaming sink. + + The implementation is required for writable streaming data sources. + + Parameters + ---------- + schema : :class:`StructType` + The schema of the data to be written. + overwrite : bool + A flag indicating whether to overwrite existing data when writing current microbatch. + + Returns + ------- + writer : :class:`DataSourceStreamWriter` + A writer instance for writing data into a streaming sink. + """ + raise PySparkNotImplementedError( + error_class="NOT_IMPLEMENTED", + message_parameters={"feature": "streamWriter"}, + ) + def streamReader(self, schema: StructType) -> "DataSourceStreamReader": """ Returns a :class:`DataSourceStreamReader` instance for reading streaming data. @@ -513,6 +536,77 @@ class DataSourceWriter(ABC): ... 
+class DataSourceStreamWriter(ABC): + """ + A base class for data stream writers. Data stream writers are responsible for writing + the data to the streaming sink. + + .. versionadded: 4.0.0 + """ + + @abstractmethod + def write(self, iterator: Iterator[Row]) -> "WriterCommitMessage": + """ + Writes data into the streaming sink. + + This method is called on executors to write data to the streaming data sink in + each microbatch. It accepts an iterator of input data and returns a single row + representing a commit message, or None if there is no commit message. + + The driver collects commit messages, if any, from all executors and passes them + to the ``commit`` method if all tasks run successfully. If any task fails, the + ``abort`` method will be called with the collected commit messages. + + Parameters + ---------- + iterator : Iterator[Row] + An iterator of input data. + + Returns + ------- + WriterCommitMessage : a serializable commit message + """ + ... + + def commit(self, messages: List["WriterCommitMessage"], batchId: int) -> None: + """ + Commits this microbatch with a list of commit messages. + + This method is invoked on the driver when all tasks run successfully. The + commit messages are collected from the ``write`` method call from each task, + and are passed to this method. The implementation should use the commit messages + to commit the microbatch in the streaming sink. + + Parameters + ---------- + messages : List[WriterCommitMessage] + A list of commit messages. + batchId: int + An integer that uniquely identifies a batch of data being written. + The integer increase by 1 with each microbatch processed. + """ + ... + + def abort(self, messages: List["WriterCommitMessage"], batchId: int) -> None: + """ + Aborts this microbatch due to task failures. + + This method is invoked on the driver when one or more tasks failed. The commit + messages are collected from the ``write`` method call from each task, and are + passed to this method. 
The implementation should use the commit messages to + abort the microbatch in the streaming sink. + + Parameters + ---------- + messages : List[WriterCommitMessage] + A list of commit messages. + batchId: int + An integer that uniquely identifies a batch of data being written. + The integer increase by 1 with each microbatch processed. + """ + ... + + class WriterCommitMessage: """ A commit message returned by the :meth:`DataSourceWriter.write` and will be diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py b/python/pyspark/sql/worker/python_streaming_sink_runner.py new file mode 100644 index 000000000000..d4f81da5aceb --- /dev/null +++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py @@ -0,0 +1,140 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import sys +from typing import IO + +from pyspark.accumulators import _accumulatorRegistry +from pyspark.errors import PySparkAssertionError, PySparkRuntimeError +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import ( + read_bool, + read_int, + read_long, + write_int, + SpecialLengths, +) +from pyspark.sql.datasource import DataSource, WriterCommitMessage +from pyspark.sql.types import ( + _parse_datatype_json_string, + StructType, +) +from pyspark.util import handle_worker_exception +from pyspark.worker_util import ( + check_python_version, + read_command, + pickleSer, + send_accumulator_updates, + setup_memory_limits, + setup_spark_files, + utf8_deserializer, +) + + +def main(infile: IO, outfile: IO) -> None: + try: + check_python_version(infile) + setup_spark_files(infile) + + memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1")) + setup_memory_limits(memory_limit_mb) + + _accumulatorRegistry.clear() + + # Receive the data source instance. + data_source = read_command(pickleSer, infile) + + if not isinstance(data_source, DataSource): + raise PySparkAssertionError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "a Python data source instance of type 'DataSource'", + "actual": f"'{type(data_source).__name__}'", + }, + ) + # Receive the data source output schema. + schema_json = utf8_deserializer.loads(infile) + schema = _parse_datatype_json_string(schema_json) + if not isinstance(schema, StructType): + raise PySparkAssertionError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "an output schema of type 'StructType'", + "actual": f"'{type(schema).__name__}'", + }, + ) + # Receive the `overwrite` flag. + overwrite = read_bool(infile) + # Instantiate data source reader. + try: + writer = data_source.streamWriter(schema=schema, overwrite=overwrite) + # Initialization succeed. 
+ write_int(0, outfile) + outfile.flush() + + # handle method call from socket + while True: + num_messages = read_int(infile) + commit_messages = [] + for _ in range(num_messages): + message = pickleSer._read_with_length(infile) + if message is not None and not isinstance(message, WriterCommitMessage): + raise PySparkAssertionError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "an instance of WriterCommitMessage", + "actual": f"'{type(message).__name__}'", + }, + ) + commit_messages.append(message) + batch_id = read_long(infile) + abort = read_bool(infile) + # Commit or abort the Python data source write. + # Note the commit messages can be None if there are failed tasks. + if abort: + writer.abort(commit_messages, batch_id) # type: ignore[arg-type] + else: + writer.commit(commit_messages, batch_id) # type: ignore[arg-type] + write_int(0, outfile) + outfile.flush() + except Exception as e: + error_msg = "data source {} throw exception: {}".format(data_source.name, e) + raise PySparkRuntimeError( + error_class="PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", + message_parameters={"action": "commitOrAbort", "error": error_msg}, + ) + except BaseException as e: + handle_worker_exception(e, outfile) + sys.exit(-1) + send_accumulator_updates(outfile) + + # check end of stream + if read_int(infile) == SpecialLengths.END_OF_STREAM: + write_int(SpecialLengths.END_OF_STREAM, outfile) + else: + # write a different value to tell JVM to not reuse this worker + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + sys.exit(-1) + + +if __name__ == "__main__": + # Read information about how to connect back to the JVM from the environment. 
+ java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + main(sock_file, sock_file) diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py index 0ba6fc6eb17f..490ede9ab0f2 100644 --- a/python/pyspark/sql/worker/write_into_data_source.py +++ b/python/pyspark/sql/worker/write_into_data_source.py @@ -152,11 +152,17 @@ def main(infile: IO, outfile: IO) -> None: # Receive the `overwrite` flag. overwrite = read_bool(infile) + is_streaming = read_bool(infile) + # Instantiate a data source. data_source = data_source_cls(options=options) # type: ignore - # Instantiate the data source writer. - writer = data_source.writer(schema, overwrite) + if is_streaming: + # Instantiate the streaming data source writer. + writer = data_source.streamWriter(schema, overwrite) + else: + # Instantiate the data source writer. + writer = data_source.writer(schema, overwrite) # type: ignore[assignment] # Create a function that can be used in mapInArrow. import pyarrow as pa diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala new file mode 100644 index 000000000000..a444fdfff7d9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream} + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.SparkEnv +import org.apache.spark.api.python.{PythonFunction, PythonWorker, PythonWorkerFactory, PythonWorkerUtils, SpecialLengths} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.BUFFER_SIZE +import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT +import org.apache.spark.sql.connector.write.WriterCommitMessage +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types.StructType + +/** + * This class is a proxy to invoke commit or abort methods in Python DataSourceStreamWriter. + * A runner spawns a python worker process. In the main function, set up communication + * between JVM and python process through socket and create a DataSourceStreamWriter instance. + * In an infinite loop, the python worker process receive write commit messages + * from the socket, then commit or abort a microbatch. 
+ */ +class PythonStreamingSinkCommitRunner( + func: PythonFunction, + schema: StructType, + overwrite: Boolean) extends Logging { + val workerModule: String = "pyspark.sql.worker.python_streaming_sink_runner" + + private val conf = SparkEnv.get.conf + protected val bufferSize: Int = conf.get(BUFFER_SIZE) + protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) + + private val envVars: java.util.Map[String, String] = func.envVars + private val pythonExec: String = func.pythonExec + private var pythonWorker: Option[PythonWorker] = None + private var pythonWorkerFactory: Option[PythonWorkerFactory] = None + protected val pythonVer: String = func.pythonVer + + private var dataOut: DataOutputStream = null + private var dataIn: DataInputStream = null + + /** + * Initializes the Python worker for running the streaming sink committer. + */ + def init(): Unit = { + logInfo(s"Initializing Python runner pythonExec: $pythonExec") + val env = SparkEnv.get + + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + envVars.put("SPARK_LOCAL_DIRS", localdir) + + envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) + envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) + + val workerFactory = + new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap, false) + val (worker: PythonWorker, _) = workerFactory.createSimpleWorker(blockingMode = true) + pythonWorker = Some(worker) + pythonWorkerFactory = Some(workerFactory) + + val stream = new BufferedOutputStream( + pythonWorker.get.channel.socket().getOutputStream, bufferSize) + dataOut = new DataOutputStream(stream) + + PythonWorkerUtils.writePythonVersion(pythonVer, dataOut) + + val pythonIncludes = func.pythonIncludes.asScala.toSet + PythonWorkerUtils.writeSparkFiles(Some("streaming_job"), pythonIncludes, dataOut) + + // Send the user function to python process + PythonWorkerUtils.writePythonFunction(func, dataOut) + + 
PythonWorkerUtils.writeUTF(schema.json, dataOut) + + dataOut.writeBoolean(overwrite) + + dataOut.flush() + + dataIn = new DataInputStream( + new BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, bufferSize)) + + val initStatus = dataIn.readInt() + if (initStatus == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError( + action = "initialize streaming sink", msg) + } + } + + init() + + def commitOrAbort( + messages: Array[WriterCommitMessage], + batchId: Long, + abort: Boolean): Unit = { + dataOut.writeInt(messages.length) + messages.foreach { message => + // Commit messages can be null if there are task failures. + if (message == null) { + dataOut.writeInt(SpecialLengths.NULL) + } else { + PythonWorkerUtils.writeBytes( + message.asInstanceOf[PythonWriterCommitMessage].pickledMessage, dataOut) + } + } + dataOut.writeLong(batchId) + dataOut.writeBoolean(abort) + dataOut.flush() + val status = dataIn.readInt() + if (status == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + val action = if (abort) "abort" else "commit" + throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(action, msg) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala new file mode 100644 index 000000000000..483fd5a4e0a1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.python + +import org.apache.spark.JobArtifactSet +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.write._ +import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} +import org.apache.spark.sql.types.StructType + +/** + * A [[streamingWrite]] for python data source writing. + * Responsible for generating the writer factory, committing or aborting a microbatch. + * */ +class PythonStreamingWrite( + ds: PythonDataSourceV2, + shortName: String, + info: LogicalWriteInfo, + isTruncate: Boolean) extends StreamingWrite { + + // Store the pickled data source writer instance. 
+ private var pythonDataSourceWriter: Array[Byte] = _ + + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + private def createDataSourceFunc = + ds.source.createPythonFunction( + ds.getOrCreateDataSourceInPython(shortName, info.options(), Some(info.schema())).dataSource + ) + + private lazy val pythonStreamingSinkCommitRunner = + new PythonStreamingSinkCommitRunner(createDataSourceFunc, info.schema(), isTruncate) + + override def createStreamingWriterFactory( + physicalInfo: PhysicalWriteInfo): StreamingDataWriterFactory = { + val writeInfo = ds.source.createWriteInfoInPython( + shortName, + info.schema(), + info.options(), + isTruncate, + isStreaming = true) + + pythonDataSourceWriter = writeInfo.writer + + new PythonStreamingWriterFactory(ds.source, writeInfo.func, info.schema(), jobArtifactUUID) + } + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + pythonStreamingSinkCommitRunner.commitOrAbort(messages, epochId, false) + } + + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + pythonStreamingSinkCommitRunner.commitOrAbort(messages, epochId, true) + } +} + +class PythonStreamingWriterFactory( + source: UserDefinedPythonDataSource, + pickledWriteFunc: Array[Byte], + inputSchema: StructType, + jobArtifactUUID: Option[String]) + extends PythonBatchWriterFactory(source, pickledWriteFunc, inputSchema, jobArtifactUUID) + with StreamingDataWriterFactory { + override def createWriter( + partitionId: Int, + taskId: Long, + epochId: Long): DataWriter[InternalRow] = { + createWriter(partitionId, taskId) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala index 6bea97795a35..0476650a60bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2.python import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} -import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, TRUNCATE} +import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, STREAMING_WRITE, TRUNCATE} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.types.StructType @@ -32,7 +32,7 @@ class PythonTable( override def name(): String = shortName override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of( - BATCH_READ, BATCH_WRITE, TRUNCATE) + BATCH_READ, BATCH_WRITE, STREAMING_WRITE, TRUNCATE) override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new PythonScanBuilder(ds, shortName, outputSchema, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala index a10a18e43f64..447221715264 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.datasources.v2.python import org.apache.spark.JobArtifactSet import org.apache.spark.sql.connector.metric.CustomMetric -import org.apache.spark.sql.connector.write._ +import org.apache.spark.sql.connector.write.{BatchWrite, _} +import org.apache.spark.sql.connector.write.streaming.StreamingWrite class PythonWrite( @@ -32,12 +33,18 @@ class PythonWrite( override def toBatch: BatchWrite = new PythonBatchWrite(ds, shortName, info, 
isTruncate) + override def toStreaming: StreamingWrite = + new PythonStreamingWrite(ds, shortName, info, isTruncate) + override def description: String = "(Python)" override def supportedCustomMetrics(): Array[CustomMetric] = ds.source.createPythonMetrics() } +/** + * A [[BatchWrite]] for python data source writing. Responsible for generating the writer factory. + * */ class PythonBatchWrite( ds: PythonDataSourceV2, shortName: String, @@ -56,7 +63,8 @@ class PythonBatchWrite( shortName, info.schema(), info.options(), - isTruncate) + isTruncate, + isStreaming = false) pythonDataSourceWriter = writeInfo.writer diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala index 9b8219c4dc2d..0586d1fd4bc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala @@ -81,13 +81,15 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { provider: String, inputSchema: StructType, options: CaseInsensitiveStringMap, - overwrite: Boolean): PythonDataSourceWriteInfo = { + overwrite: Boolean, + isStreaming: Boolean): PythonDataSourceWriteInfo = { new UserDefinedPythonDataSourceWriteRunner( dataSourceCls, provider, inputSchema, options.asCaseSensitiveMap().asScala.toMap, - overwrite).runInPython() + overwrite, + isStreaming).runInPython() } /** @@ -369,7 +371,8 @@ private class UserDefinedPythonDataSourceWriteRunner( provider: String, inputSchema: StructType, options: Map[String, String], - overwrite: Boolean) extends PythonPlannerRunner[PythonDataSourceWriteInfo](dataSourceCls) { + overwrite: Boolean, + isStreaming: Boolean) extends PythonPlannerRunner[PythonDataSourceWriteInfo](dataSourceCls) { override val 
workerModule: String = "pyspark.sql.worker.write_into_data_source" @@ -395,6 +398,8 @@ private class UserDefinedPythonDataSourceWriteRunner( // Send the `overwrite` flag dataOut.writeBoolean(overwrite) + + dataOut.writeBoolean(isStreaming) } override protected def receiveFromPython( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index be25cd0dda3b..1db03c5d816f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2 import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -395,6 +396,10 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } else { None } + provider match { + case p: PythonDataSourceV2 => p.setShortName(source) + case _ => + } val table = DataSourceV2Utils.getTableFromProvider( provider, dsOptions, userSpecifiedSchema = outputSchema) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala index f022e353edd7..42eaa492be73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala @@ -16,15 +16,24 @@ */ package org.apache.spark.sql.execution.python +import java.io.File + +import scala.concurrent.duration._ + import org.apache.spark.SparkException -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, shouldTestPandasUDFs} import org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, PythonMicroBatchStream, PythonStreamingSourceOffset} +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { + import testImplicits._ + + val waitTimeout = 15.seconds + protected def simpleDataStreamReaderScript: String = """ |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition @@ -65,6 +74,40 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { | yield (2, partition.value) |""".stripMargin + protected def simpleDataStreamWriterScript: String = + s""" + |import json + |import uuid + |import os + |from pyspark import TaskContext + |from pyspark.sql.datasource import DataSource, DataSourceStreamWriter + |from pyspark.sql.datasource import WriterCommitMessage + | + |class SimpleDataSourceStreamWriter(DataSourceStreamWriter): + | def __init__(self, options, overwrite): + | self.options = options + | self.overwrite = overwrite + | + | def write(self, iterator): + | context = TaskContext.get() + | partition_id = context.partitionId() + | path = self.options.get("path") + | assert path is not None + | output_path = os.path.join(path, f"{partition_id}.json") + | cnt = 0 + | mode = "w" if self.overwrite else "a" + | with open(output_path, mode) as file: + | for row in iterator: + | 
file.write(json.dumps(row.asDict()) + "\\n") + | return WriterCommitMessage() + | + |class SimpleDataSource(DataSource): + | def schema(self) -> str: + | return "id INT" + | def streamWriter(self, schema, overwrite): + | return SimpleDataSourceStreamWriter(self.options, overwrite) + |""".stripMargin + private val errorDataSourceName = "ErrorDataSource" test("simple data stream source") { @@ -230,4 +273,220 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { stream => stream.commit(offset) } } + + Seq("append", "complete").foreach { mode => + test(s"data source stream write - $mode mode") { + assume(shouldTestPandasUDFs) + val dataSource = + createUserDefinedPythonDataSource(dataSourceName, simpleDataStreamWriterScript) + spark.dataSource.registerPython(dataSourceName, dataSource) + val inputData = MemoryStream[Int] + withTempDir { dir => + val path = dir.getAbsolutePath + val checkpointDir = new File(path, "checkpoint") + checkpointDir.mkdir() + val outputDir = new File(path, "output") + outputDir.mkdir() + val streamDF = if (mode == "append") { + inputData.toDF() + } else { + // Complete mode only supports stateful aggregation + inputData.toDF() + .groupBy("value").count() + } + def resultDf: DataFrame = spark.read.format("json") + .load(outputDir.getAbsolutePath) + val q = streamDF + .writeStream + .format(dataSourceName) + .outputMode(mode) + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .start(outputDir.getAbsolutePath) + + inputData.addData(1, 2, 3) + eventually(timeout(waitTimeout)) { + if (mode == "append") { + checkAnswer( + resultDf, + Seq(Row(1), Row(2), Row(3))) + } else { + checkAnswer( + resultDf.select("value", "count"), + Seq(Row(1, 1), Row(2, 1), Row(3, 1))) + } + } + + inputData.addData(1, 4) + eventually(timeout(waitTimeout)) { + if (mode == "append") { + checkAnswer( + resultDf, + Seq(Row(1), Row(2), Row(3), Row(4), Row(1))) + } else { + checkAnswer( + resultDf.select("value", "count"), + Seq(Row(1, 2), 
Row(2, 1), Row(3, 1), Row(4, 1))) + } + } + + q.stop() + q.awaitTermination() + assert(q.exception.isEmpty) + } + } + } + + test("streaming sink write commit and abort") { + assume(shouldTestPandasUDFs) + // The data source write the number of rows and partitions into batchId.json in + // the output directory in commit() function. If aborting a microbatch, it writes + // batchId.txt into output directory. + val dataSourceScript = + s""" + |import json + |import os + |from dataclasses import dataclass + |from pyspark import TaskContext + |from pyspark.sql.datasource import DataSource, DataSourceStreamWriter, WriterCommitMessage + | + |@dataclass + |class SimpleCommitMessage(WriterCommitMessage): + | partition_id: int + | count: int + | + |class SimpleDataSourceStreamWriter(DataSourceStreamWriter): + | def __init__(self, options): + | self.options = options + | self.path = self.options.get("path") + | assert self.path is not None + | + | def write(self, iterator): + | context = TaskContext.get() + | partition_id = context.partitionId() + | cnt = 0 + | for row in iterator: + | if row.value > 50: + | raise Exception("invalid value") + | cnt += 1 + | return SimpleCommitMessage(partition_id=partition_id, count=cnt) + | + | def commit(self, messages, batchId) -> None: + | status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages)) + | + | with open(os.path.join(self.path, f"{batchId}.json"), "a") as file: + | file.write(json.dumps(status) + "\\n") + | + | def abort(self, messages, batchId) -> None: + | with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file: + | file.write(f"failed in batch {batchId}") + | + |class SimpleDataSource(DataSource): + | def streamWriter(self, schema, overwrite): + | return SimpleDataSourceStreamWriter(self.options) + |""".stripMargin + val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) + spark.dataSource.registerPython(dataSourceName, dataSource) + val inputData = 
MemoryStream[Int](numPartitions = 3) + withTempDir { dir => + val path = dir.getAbsolutePath + val checkpointDir = new File(path, "checkpoint") + checkpointDir.mkdir() + val outputDir = new File(path, "output") + outputDir.mkdir() + val q = inputData.toDF() + .writeStream + .format(dataSourceName) + .outputMode("append") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .start(outputDir.getAbsolutePath) + + def metadataDf: DataFrame = spark.read.format("json") + .load(outputDir.getAbsolutePath) + + // Batch 0-2 should succeed and json commit files are written. + inputData.addData(1 to 30) + eventually(timeout(waitTimeout)) { + checkAnswer(metadataDf, Seq(Row(3, 30))) + } + + inputData.addData(31 to 50) + eventually(timeout(waitTimeout)) { + checkAnswer(metadataDf, Seq(Row(3, 30), Row(3, 20))) + } + + // Write and commit an empty batch. + inputData.addData(Seq.empty) + eventually(timeout(waitTimeout)) { + checkAnswer(metadataDf, Seq(Row(3, 30), Row(3, 20), Row(3, 0))) + } + + // The sink throws exception when encountering value > 50 in batch 3. + // The streamWriter will write error message in 3.txt during abort(). + inputData.addData(51 to 100) + eventually(timeout(waitTimeout)) { + checkAnswer( + spark.read.text(outputDir.getAbsolutePath + "/3.txt"), + Seq(Row("failed in batch 3"))) + } + + q.stop() + assert(q.exception.get.message.contains("invalid value")) + } + } + + test("python streaming sink: invalid write mode") { + assume(shouldTestPandasUDFs) + // Append and update output modes are accepted for a stateless query, while complete + // mode (which requires a streaming aggregation) and an unknown mode are rejected.
+ + val dataSource = createUserDefinedPythonDataSource(dataSourceName, simpleDataStreamWriterScript) + spark.dataSource.registerPython(dataSourceName, dataSource) + + withTempDir { dir => + val path = dir.getAbsolutePath + val checkpointDir = new File(path, "checkpoint") + checkpointDir.mkdir() + val outputDir = new File(path, "output") + outputDir.mkdir() + + def runQuery(mode: String): Unit = { + val inputData = MemoryStream[Int] + withTempDir { dir => + val path = dir.getAbsolutePath + val checkpointDir = new File(path, "checkpoint") + checkpointDir.mkdir() + val outputDir = new File(path, "output") + outputDir.mkdir() + val q = inputData.toDF() + .writeStream + .format(dataSourceName) + .outputMode(mode) + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .start(outputDir.getAbsolutePath) + q.stop() + q.awaitTermination() + } + } + + runQuery("append") + runQuery("update") + + // Complete mode is not supported for stateless query. + checkError( + exception = intercept[AnalysisException] { + runQuery("complete") + }, + errorClass = "_LEGACY_ERROR_TEMP_3102", + parameters = Map( + "msg" -> ("Complete output mode not supported when there are no streaming aggregations" + + " on streaming DataFrames/Datasets"))) + + // Query should fail in planning with "invalid" mode. + val error2 = intercept[IllegalArgumentException] { + runQuery("invalid") + } + assert(error2.getMessage.contains("invalid")) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org