This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch WIP-python-data-source-admission-control-trigger-availablenow-change-the-method-signature
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 9b0a9b34b071bc87590941a6d4c9cd7c339bc0eb
Author: Jungtaek Lim <[email protected]>
AuthorDate: Mon Jan 19 14:58:48 2026 +0900

    WIP python data source Trigger.AvailableNow
---
 python/pyspark/sql/datasource.py                    |  55 ++++++++++
 python/pyspark/sql/datasource_internal.py           |  21 +++-
 .../streaming/python_streaming_source_runner.py     |  70 ++++++++++++-
 .../v2/python/PythonMicroBatchStream.scala          |  64 ++++++++++--
 .../datasources/v2/python/PythonScan.scala          |  19 +++-
 .../streaming/PythonStreamingSourceRunner.scala     |  58 +++++++++++
 .../streaming/PythonStreamingDataSourceSuite.scala  | 116 +++++++++++++++++----
 7 files changed, 369 insertions(+), 34 deletions(-)

diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index f1908180a3ba..854f67217acf 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -909,6 +909,61 @@ class SimpleDataSourceStreamReader(ABC):
         ...
 
 
+class ReadLimit(ABC):
+    pass
+
+
+class ReadAllAvailable(ReadLimit):
+    def __init__(self):
+        pass
+
+
+class SupportsAdmissionControl(ABC):
+    @abstractmethod
+    def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict:
+        """
+        Returns the most recent offset available given a read limit. The `start` offset can be
+        used to figure out how much new data should be read given the limit. Readers that
+        support admission control should implement this method instead of
+        :meth:`DataSourceStreamReader.latestOffset`.
+
+        For the very first micro-batch, `start` is the initial offset of the reader.
+        """
+        pass
+
+
+class SupportsTriggerAvailableNow(ABC):
+    @abstractmethod
+    def prepareForTriggerAvailableNow(self) -> None:
+        """
+        Called at the beginning of a streaming query that runs with Trigger.AvailableNow, to let
+        the source record the offset of the latest data available at that time (a.k.a. the
+        target offset of the query). The source should behave as if no new data arrives after
+        the target offset, i.e. it must not return an offset beyond the target offset from
+        `latestOffset(start, readLimit)`.
+
+        Note that there is an exception for the first uncommitted batch after a restart, where
+        the end offset is not derived from the current latest offset. Sources that want to
+        assert that relation need to take special care; one possible way is to keep an internal
+        flag indicating a Trigger.AvailableNow run, set the flag in this method, and record the
+        target offset in the first call of `latestOffset(start, readLimit)`.
+        """
+        pass
+
+
 class DataSourceWriter(ABC):
     """
     A base class for data source writers. Data source writers are responsible for saving
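For readers following this WIP, a rough sketch (not part of the patch) of how a user-defined stream reader could opt into the two new mixins above; the CounterStreamReader class, its offset layout, and the hard-coded numbers are made up for illustration:

    from pyspark.sql.datasource import (
        DataSourceStreamReader,
        InputPartition,
        ReadAllAvailable,
        ReadLimit,
        SupportsAdmissionControl,
        SupportsTriggerAvailableNow,
    )


    class CounterStreamReader(
        DataSourceStreamReader, SupportsAdmissionControl, SupportsTriggerAvailableNow
    ):
        """Emits consecutive integers; offsets look like {"offset": <int>}."""

        def __init__(self):
            self._target = None  # only set for Trigger.AvailableNow runs

        def initialOffset(self) -> dict:
            return {"offset": 0}

        def prepareForTriggerAvailableNow(self) -> None:
            # Freeze the target offset; latestOffset() must never go beyond it.
            self._target = {"offset": 1000}  # whatever is "latest" at query start

        def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict:
            # Only ReadAllAvailable is wired up in this WIP.
            assert isinstance(readLimit, ReadAllAvailable)
            end = start["offset"] + 100
            if self._target is not None:
                end = min(end, self._target["offset"])
            return {"offset": end}

        def partitions(self, start: dict, end: dict):
            return [InputPartition((start["offset"], end["offset"]))]

        def read(self, partition):
            lo, hi = partition.value
            return iter([(i,) for i in range(lo, hi)])

With the selection logic added to PythonScan further down, a reader like this would be planned as PythonMicroBatchStreamWithTriggerAvailableNow.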
diff --git a/python/pyspark/sql/datasource_internal.py b/python/pyspark/sql/datasource_internal.py
index 6df0be4192ec..8b0fa5eb0a5f 100644
--- a/python/pyspark/sql/datasource_internal.py
+++ b/python/pyspark/sql/datasource_internal.py
@@ -25,7 +25,10 @@ from pyspark.sql.datasource import (
     DataSource,
     DataSourceStreamReader,
     InputPartition,
+    ReadAllAvailable,
+    ReadLimit,
     SimpleDataSourceStreamReader,
+    SupportsAdmissionControl,
 )
 from pyspark.sql.types import StructType
 from pyspark.errors import PySparkNotImplementedError
@@ -56,7 +59,7 @@ class PrefetchedCacheEntry:
         self.iterator = iterator
 
 
-class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+class _SimpleStreamReaderWrapper(DataSourceStreamReader, SupportsAdmissionControl):
     """
     A private class that wrap :class:`SimpleDataSourceStreamReader` in prefetch and cache pattern,
     so that :class:`SimpleDataSourceStreamReader` can integrate with streaming engine like an
@@ -97,6 +100,22 @@ class _SimpleStreamReaderWrapper(DataSourceStreamReader):
         self.current_offset = end
         return end
 
+    def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict:
+        if self.current_offset is None:
+            assert start is not None, "start offset should not be None"
+            self.current_offset = start
+        else:
+            assert self.current_offset == start, (
+                "start offset does not match current offset. "
+                f"current: {self.current_offset}, start: {start}"
+            )
+
+        assert isinstance(
+            readLimit, ReadAllAvailable
+        ), "simple stream reader does not support read limit"
+
+        (iter, end) = self.simple_reader.read(self.current_offset)
+        self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter))
+        self.current_offset = end
+        return end
+
     def commit(self, end: dict) -> None:
         if self.current_offset is None:
             self.current_offset = end
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index ab988eb714cc..54bf12843232 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -28,7 +28,13 @@ from pyspark.serializers import (
     write_with_length,
     SpecialLengths,
 )
-from pyspark.sql.datasource import DataSource, DataSourceStreamReader
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceStreamReader,
+    ReadAllAvailable,
+    SupportsAdmissionControl,
+    SupportsTriggerAvailableNow,
+)
 from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper, _streamReader
 from pyspark.sql.pandas.serializers import ArrowStreamSerializer
 from pyspark.sql.types import (
@@ -51,11 +57,17 @@ INITIAL_OFFSET_FUNC_ID = 884
 LATEST_OFFSET_FUNC_ID = 885
 PARTITIONS_FUNC_ID = 886
 COMMIT_FUNC_ID = 887
+CHECK_SUPPORTED_FEATURES_ID = 888
+PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID = 889
+LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID = 890
 
 PREFETCHED_RECORDS_NOT_FOUND = 0
 NON_EMPTY_PYARROW_RECORD_BATCHES = 1
 EMPTY_PYARROW_RECORD_BATCHES = 2
 
+SUPPORTS_ADMISSION_CONTROL = 1
+SUPPORTS_TRIGGER_AVAILABLE_NOW = 1 << 1
+
 
 def initial_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
     offset = reader.initialOffset()
@@ -116,6 +128,56 @@ def send_batch_func(
         write_int(EMPTY_PYARROW_RECORD_BATCHES, outfile)
 
 
+def check_support_func(reader: DataSourceStreamReader, outfile: IO) -> None:
+    support_flags = 0
+    if isinstance(reader, _SimpleStreamReaderWrapper):
+        # We consider the `read` method of the simple reader to already have admission control
+        # built into it.
+        support_flags |= SUPPORTS_ADMISSION_CONTROL
+        if isinstance(reader.simple_reader, SupportsTriggerAvailableNow):
+            support_flags |= SUPPORTS_TRIGGER_AVAILABLE_NOW
+    else:
+        if isinstance(reader, SupportsAdmissionControl):
+            support_flags |= SUPPORTS_ADMISSION_CONTROL
+        if isinstance(reader, SupportsTriggerAvailableNow):
+            support_flags |= SUPPORTS_TRIGGER_AVAILABLE_NOW
+    write_int(support_flags, outfile)
+
+
+def prepare_for_trigger_available_now_func(reader: DataSourceStreamReader, outfile: IO) -> None:
+    if isinstance(reader, _SimpleStreamReaderWrapper):
+        if isinstance(reader.simple_reader, SupportsTriggerAvailableNow):
+            reader.simple_reader.prepareForTriggerAvailableNow()
+        else:
+            # FIXME: code for not supported? or should it be assertion?
+            raise Exception(
+                "prepareForTriggerAvailableNow is not supported by the "
+                "underlying simple reader."
+            )
+    else:
+        if isinstance(reader, SupportsTriggerAvailableNow):
+            reader.prepareForTriggerAvailableNow()
+        else:
+            # FIXME: code for not supported? or should it be assertion?
+            raise Exception(
+                "prepareForTriggerAvailableNow is not supported by the stream reader."
+            )
+    write_int(0, outfile)
+
+
+def latest_offset_admission_control_func(
+    reader: DataSourceStreamReader, infile: IO, outfile: IO
+) -> None:
+    start_offset_dict = json.loads(utf8_deserializer.loads(infile))
+
+    limit_type = read_int(infile)
+    if limit_type == 0:
+        # ReadAllAvailable
+        limit = ReadAllAvailable()
+    else:
+        # FIXME: use a proper error class once other ReadLimit types are supported
+        raise Exception("Only ReadAllAvailable is supported for latestOffsetAdmissionControl.")
+
+    offset = reader.latestOffset(start_offset_dict, limit)
+    write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
 def main(infile: IO, outfile: IO) -> None:
     try:
         check_python_version(infile)
@@ -176,6 +238,12 @@ def main(infile: IO, outfile: IO) -> None:
                 )
             elif func_id == COMMIT_FUNC_ID:
                 commit_func(reader, infile, outfile)
+            elif func_id == CHECK_SUPPORTED_FEATURES_ID:
+                check_support_func(reader, outfile)
+            elif func_id == PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID:
+                prepare_for_trigger_available_now_func(reader, outfile)
+            elif func_id == LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID:
+                latest_offset_admission_control_func(reader, infile, outfile)
             else:
                 raise IllegalArgumentException(
                     errorClass="UNSUPPORTED_OPERATION",
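As an aside (not part of the patch), the capability handshake is just a bitmask. A minimal Python sketch of how the two worker-side flags compose and how the JVM side is expected to decode them; decode_features is a hypothetical helper used only for illustration:

    # Values mirror the constants defined in python_streaming_source_runner.py above.
    SUPPORTS_ADMISSION_CONTROL = 1
    SUPPORTS_TRIGGER_AVAILABLE_NOW = 1 << 1


    def decode_features(feature_bits: int) -> dict:
        # Decoding must use bitwise AND; OR-ing the mask would report every feature as present.
        return {
            "admission_control": (feature_bits & SUPPORTS_ADMISSION_CONTROL) != 0,
            "trigger_available_now": (feature_bits & SUPPORTS_TRIGGER_AVAILABLE_NOW) != 0,
        }


    # A reader supporting both features advertises 0b11 == 3.
    assert decode_features(SUPPORTS_ADMISSION_CONTROL | SUPPORTS_TRIGGER_AVAILABLE_NOW) == {
        "admission_control": True,
        "trigger_available_now": True,
    }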
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
index 50ea7616061c..e4fef9e6763c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
@@ -17,9 +17,10 @@
 package org.apache.spark.sql.execution.datasources.v2.python
 
 import org.apache.spark.SparkEnv
+import org.apache.spark.api.python.PythonFunction
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
-import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, MicroBatchStream, Offset}
+import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, MicroBatchStream, Offset, ReadLimit, SupportsAdmissionControl, SupportsTriggerAvailableNow}
 import org.apache.spark.sql.execution.datasources.v2.python.PythonMicroBatchStream.nextStreamId
 import org.apache.spark.sql.execution.python.streaming.PythonStreamingSourceRunner
 import org.apache.spark.sql.types.StructType
@@ -32,14 +33,12 @@ class PythonMicroBatchStream(
     ds: PythonDataSourceV2,
     shortName: String,
     outputSchema: StructType,
-    options: CaseInsensitiveStringMap
+    options: CaseInsensitiveStringMap,
+    runner: PythonStreamingSourceRunner
   ) extends MicroBatchStream with Logging with AcceptsLatestSeenOffset {
 
-  private def createDataSourceFunc =
-    ds.source.createPythonFunction(
-      ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)).dataSource)
-
   private val streamId = nextStreamId
   private var nextBlockId = 0L
@@ -49,10 +48,6 @@ class PythonMicroBatchStream(
   // from python to JVM.
   private var cachedInputPartition: Option[(String, String, PythonStreamingInputPartition)] = None
 
-  private val runner: PythonStreamingSourceRunner =
-    new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema)
-  runner.init()
-
   override def initialOffset(): Offset = PythonStreamingSourceOffset(runner.initialOffset())
 
   override def latestOffset(): Offset = PythonStreamingSourceOffset(runner.latestOffset())
@@ -110,10 +105,61 @@ class PythonMicroBatchStream(
   override def deserializeOffset(json: String): Offset = PythonStreamingSourceOffset(json)
 }
 
+class PythonMicroBatchStreamWithAdmissionControl(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType,
+    options: CaseInsensitiveStringMap,
+    runner: PythonStreamingSourceRunner)
+  extends PythonMicroBatchStream(ds, shortName, outputSchema, options, runner)
+  with SupportsAdmissionControl {
+
+  override def latestOffset(): Offset = {
+    throw new IllegalStateException("latestOffset without parameters is not expected to be " +
+      "called. Please use latestOffset(startOffset: Offset, limit: ReadLimit) instead.")
+  }
+
+  override def latestOffset(startOffset: Offset, limit: ReadLimit): Offset = {
+    PythonStreamingSourceOffset(runner.latestOffset(startOffset, limit))
+  }
+}
+
+class PythonMicroBatchStreamWithTriggerAvailableNow(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType,
+    options: CaseInsensitiveStringMap,
+    runner: PythonStreamingSourceRunner)
+  extends PythonMicroBatchStreamWithAdmissionControl(ds, shortName, outputSchema, options, runner)
+  with SupportsTriggerAvailableNow {
+
+  override def prepareForTriggerAvailableNow(): Unit = {
+    runner.prepareForTriggerAvailableNow()
+  }
+}
+
 object PythonMicroBatchStream {
   private var currentId = 0
 
   def nextStreamId: Int = synchronized {
     currentId = currentId + 1
     currentId
   }
+
+  def createPythonStreamingSourceRunner(
+      ds: PythonDataSourceV2,
+      shortName: String,
+      outputSchema: StructType,
+      options: CaseInsensitiveStringMap): PythonStreamingSourceRunner = {
+
+    // These calls were previously made while constructing PythonMicroBatchStream, so invoking
+    // them here introduces no timing or ordering issue.
+    def createDataSourceFunc: PythonFunction =
+      ds.source.createPythonFunction(
+        ds.getOrCreateDataSourceInPython(
+          shortName,
+          options,
+          Some(outputSchema)).dataSource)
+
+    new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
index a133c40cde60..9e3effe7d441 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -35,8 +35,23 @@ class PythonScan(
   ) extends Scan with SupportsMetadata {
 
   override def toBatch: Batch = new PythonBatch(ds, shortName, outputSchema, options)
 
-  override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream =
-    new PythonMicroBatchStream(ds, shortName, outputSchema, options)
+  override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
+    val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+      ds, shortName, outputSchema, options)
+    runner.init()
+
+    val supportedFeatures = runner.checkSupportedFeatures()
+
+    if (supportedFeatures.triggerAvailableNow) {
+      new PythonMicroBatchStreamWithTriggerAvailableNow(
+        ds, shortName, outputSchema, options, runner)
+    } else if (supportedFeatures.admissionControl) {
+      new PythonMicroBatchStreamWithAdmissionControl(
+        ds, shortName, outputSchema, options, runner)
+    } else {
+      new PythonMicroBatchStream(ds, shortName, outputSchema, options, runner)
+    }
+  }
 
   override def description: String = "(Python)"
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
index 270d816e9bd9..36e4d09041ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
@@ -33,6 +33,7 @@ import org.apache.spark.internal.LogKeys.PYTHON_EXEC
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.streaming.{Offset, ReadAllAvailable, ReadLimit}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -46,11 +47,16 @@ object PythonStreamingSourceRunner {
   val LATEST_OFFSET_FUNC_ID = 885
   val PARTITIONS_FUNC_ID = 886
   val COMMIT_FUNC_ID = 887
+  val CHECK_SUPPORTED_FEATURES_ID = 888
+  val PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID = 889
+  val LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID = 890
 
   // Status code for JVM to decide how to receive prefetched record batches
   // for simple stream reader.
   val PREFETCHED_RECORDS_NOT_FOUND = 0
   val NON_EMPTY_PYARROW_RECORD_BATCHES = 1
   val EMPTY_PYARROW_RECORD_BATCHES = 2
+
+  case class SupportedFeatures(admissionControl: Boolean, triggerAvailableNow: Boolean)
 }
 
 /**
@@ -129,6 +135,34 @@ class PythonStreamingSourceRunner(
     }
   }
 
+  def checkSupportedFeatures(): SupportedFeatures = {
+    dataOut.writeInt(CHECK_SUPPORTED_FEATURES_ID)
+    dataOut.flush()
+
+    val featureBits = dataIn.readInt()
+    if (featureBits == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "checkSupportedFeatures", msg)
+    }
+    val admissionControl = (featureBits & (1 << 0)) != 0
+    val availableNow = (featureBits & (1 << 1)) != 0
+
+    SupportedFeatures(admissionControl, availableNow)
+  }
+
+  def prepareForTriggerAvailableNow(): Unit = {
+    dataOut.writeInt(PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID)
+    dataOut.flush()
+    val status = dataIn.readInt()
+    // FIXME: code for not supported?
+    if (status == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "prepareForTriggerAvailableNow", msg)
+    }
+  }
+
   /**
    * Invokes latestOffset() function of the stream reader and receive the return value.
    */
@@ -144,6 +178,30 @@ class PythonStreamingSourceRunner(
     PythonWorkerUtils.readUTF(len, dataIn)
   }
 
+  def latestOffset(startOffset: Offset, limit: ReadLimit): String = {
+    dataOut.writeInt(LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID)
+    PythonWorkerUtils.writeUTF(startOffset.json, dataOut)
+    limit match {
+      case _: ReadAllAvailable =>
+        dataOut.writeInt(0)
+
+      case _ =>
+        // FIXME: Add support for other ReadLimit types
+        // throw QueryExecutionErrors.unsupportedReadLimitTypeError(limit.getClass.getName)
+        throw new UnsupportedOperationException("Unsupported ReadLimit type: " +
+          s"${limit.getClass.getName}")
+    }
+    dataOut.flush()
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "latestOffset", msg)
+    }
+    PythonWorkerUtils.readUTF(len, dataIn)
+  }
+
   /**
    * Invokes initialOffset() function of the stream reader and receive the return value.
   */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
index 0e33b6e55a43..330a0513d360 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
@@ -24,7 +24,8 @@ import scala.concurrent.duration._
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, shouldTestPandasUDFs}
-import org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, PythonMicroBatchStream, PythonStreamingSourceOffset}
+import org.apache.spark.sql.connector.read.streaming.ReadLimit
+import org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, PythonMicroBatchStream, PythonMicroBatchStreamWithAdmissionControl, PythonStreamingSourceOffset}
 import org.apache.spark.sql.execution.python.PythonDataSourceSuiteBase
 import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger
 import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, OffsetSeqLog}
@@ -249,12 +250,18 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase {
     pythonDs.setShortName("ErrorDataSource")
 
     def testMicroBatchStreamError(action: String, msg: String)(
-        func: PythonMicroBatchStream => Unit): Unit = {
-      val stream = new PythonMicroBatchStream(
+        func: PythonMicroBatchStreamWithAdmissionControl => Unit): Unit = {
+      val options = CaseInsensitiveStringMap.empty()
+      val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+        pythonDs, errorDataSourceName, inputSchema, options)
+      runner.init()
+
+      val stream = new PythonMicroBatchStreamWithAdmissionControl(
         pythonDs,
         errorDataSourceName,
         inputSchema,
-        CaseInsensitiveStringMap.empty()
+        options,
+        runner
       )
       val err = intercept[SparkException] {
         func(stream)
@@ -277,16 +284,6 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase {
       stream =>
         stream.initialOffset()
     }
-
-    // User don't need to implement latestOffset for SimpleDataSourceStreamReader.
-    // The latestOffset method of simple stream reader invokes initialOffset() and read()
-    // So the not implemented method is initialOffset.
-    testMicroBatchStreamError(
-      "latestOffset",
-      "[NOT_IMPLEMENTED] initialOffset is not implemented") {
-      stream =>
-        stream.latestOffset()
-    }
   }
 
   test("read() method throw error in SimpleDataSourceStreamReader") {
@@ -314,12 +311,18 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase {
     pythonDs.setShortName("ErrorDataSource")
 
     def testMicroBatchStreamError(action: String, msg: String)(
-        func: PythonMicroBatchStream => Unit): Unit = {
-      val stream = new PythonMicroBatchStream(
+        func: PythonMicroBatchStreamWithAdmissionControl => Unit): Unit = {
+      val options = CaseInsensitiveStringMap.empty()
+      val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+        pythonDs, errorDataSourceName, inputSchema, options)
+      runner.init()
+
+      val stream = new PythonMicroBatchStreamWithAdmissionControl(
         pythonDs,
         errorDataSourceName,
         inputSchema,
-        CaseInsensitiveStringMap.empty()
+        options,
+        runner
       )
       val err = intercept[SparkException] {
         func(stream)
@@ -337,7 +340,59 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase {
     }
     testMicroBatchStreamError("latestOffset", "Exception: error reading available data") {
       stream =>
-        stream.latestOffset()
+        stream.latestOffset(PythonStreamingSourceOffset("""{"partition": 0}"""),
+          ReadLimit.allAvailable())
+    }
+  }
+
+  test("SimpleDataSourceStreamReader with Trigger.AvailableNow") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |from pyspark.sql.datasource import SimpleDataSourceStreamReader, SupportsTriggerAvailableNow
+         |
+         |class SimpleDataStreamReader(SimpleDataSourceStreamReader, SupportsTriggerAvailableNow):
+         |    def initialOffset(self):
+         |        return {"partition-1": 0}
+         |    def read(self, start: dict):
+         |        start_idx = start["partition-1"]
+         |        end_offset = min(start_idx + 2, self.desired_end_offset)
+         |        it = iter([(i, ) for i in range(start_idx, end_offset)])
+         |        return (it, {"partition-1": end_offset})
+         |    def readBetweenOffsets(self, start: dict, end: dict):
+         |        start_idx = start["partition-1"]
+         |        end_idx = end["partition-1"]
+         |        return iter([(i, ) for i in range(start_idx, end_idx)])
+         |    def prepareForTriggerAvailableNow(self):
+         |        self.desired_end_offset = 10
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |    def simpleStreamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      val outputDir = new File(path, "output")
+      val df = spark.readStream.format(dataSourceName).load()
+      val q = df.writeStream
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .format("json")
+        .trigger(Trigger.AvailableNow())
+        .start(outputDir.getAbsolutePath)
+      q.awaitTermination(waitTimeout.toMillis)
+      val rowCount = spark.read.format("json").load(outputDir.getAbsolutePath).count()
+      assert(rowCount === 10)
+      checkAnswer(
+        spark.read.format("json").load(outputDir.getAbsolutePath),
+        (0 until rowCount.toInt).map(Row(_))
+      )
     }
   }
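For context (not part of the patch), end-to-end usage from PySpark would look roughly like the sketch below. CounterDataSource, my_package, the "counter" short name, and the /tmp paths are hypothetical; register(), readStream, and trigger(availableNow=True) are existing PySpark APIs.

    from pyspark.sql import SparkSession

    # Hypothetical DataSource whose streamReader()/simpleStreamReader() returns a reader
    # implementing the new mixins; not part of this patch.
    from my_package import CounterDataSource

    spark = SparkSession.builder.getOrCreate()
    spark.dataSource.register(CounterDataSource)

    # "counter" stands in for whatever CounterDataSource.name() returns.
    df = spark.readStream.format("counter").load()

    # availableNow=True maps to Trigger.AvailableNow: the query drains everything up to the
    # target offset recorded by prepareForTriggerAvailableNow() and then stops.
    query = (
        df.writeStream.format("json")
        .option("checkpointLocation", "/tmp/counter-checkpoint")
        .trigger(availableNow=True)
        .start("/tmp/counter-output")
    )
    query.awaitTermination()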
@@ -459,11 +514,18 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     spark.dataSource.registerPython(dataSourceName, dataSource)
     val pythonDs = new PythonDataSourceV2
     pythonDs.setShortName("SimpleDataSource")
+
+    val options = CaseInsensitiveStringMap.empty()
+    val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+      pythonDs, dataSourceName, inputSchema, options)
+    runner.init()
+
     val stream = new PythonMicroBatchStream(
       pythonDs,
       dataSourceName,
       inputSchema,
-      CaseInsensitiveStringMap.empty()
+      options,
+      runner
     )
 
     var startOffset = stream.initialOffset()
@@ -706,11 +768,17 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
     def testMicroBatchStreamError(action: String, msg: String)(
         func: PythonMicroBatchStream => Unit): Unit = {
+      val options = CaseInsensitiveStringMap.empty()
+      val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+        pythonDs, errorDataSourceName, inputSchema, options)
+      runner.init()
+
       val stream = new PythonMicroBatchStream(
         pythonDs,
         errorDataSourceName,
         inputSchema,
-        CaseInsensitiveStringMap.empty()
+        options,
+        runner
       )
       val err = intercept[SparkException] {
         func(stream)
@@ -767,11 +835,17 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
     def testMicroBatchStreamError(action: String, msg: String)(
         func: PythonMicroBatchStream => Unit): Unit = {
+      val options = CaseInsensitiveStringMap.empty()
+      val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+        pythonDs, errorDataSourceName, inputSchema, options)
+      runner.init()
+
       val stream = new PythonMicroBatchStream(
         pythonDs,
         errorDataSourceName,
         inputSchema,
-        CaseInsensitiveStringMap.empty()
+        options,
+        runner
       )
       val err = intercept[SparkException] {
         func(stream)

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
