This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch WIP-python-data-source-admission-control-trigger-availablenow-change-the-method-signature
in repository https://gitbox.apache.org/repos/asf/spark.git
commit 35e4ec1714bf1f67008b03e0396fabae7a8b7ed8 Author: Jungtaek Lim <[email protected]> AuthorDate: Wed Jan 21 14:18:41 2026 +0900 address admission control and trigger availableNow for stream reader, built-in read limit, testing... --- python/pyspark/sql/datasource.py | 54 ------ python/pyspark/sql/datasource_internal.py | 46 ++++- python/pyspark/sql/streaming/datasource.py | 207 +++++++++++++++++++++ .../streaming/python_streaming_source_runner.py | 51 +++-- .../sql/tests/test_python_streaming_datasource.py | 76 ++++++++ .../v2/python/PythonMicroBatchStream.scala | 16 ++ .../streaming/PythonStreamingSourceRunner.scala | 59 ++++-- .../streaming/PythonStreamingDataSourceSuite.scala | 206 +++++++++++++++++++- 8 files changed, 625 insertions(+), 90 deletions(-) diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py index 854f67217acf..6adfd79f2631 100644 --- a/python/pyspark/sql/datasource.py +++ b/python/pyspark/sql/datasource.py @@ -909,60 +909,6 @@ class SimpleDataSourceStreamReader(ABC): ... -class ReadLimit(ABC): - pass - - -class ReadAllAvailable(ReadLimit): - def __init__(self): - pass - - -class SupportsAdmissionControl(ABC): - @abstractmethod - def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict: - """ - FIXME: docstring needed - - /** - * Returns the most recent offset available given a read limit. The start offset can be used - * to figure out how much new data should be read given the limit. Users should implement this - * method instead of latestOffset for a MicroBatchStream or getOffset for Source. - * <p> - * When this method is called on a `Source`, the source can return `null` if there is no - * data to process. In addition, for the very first micro-batch, the `startOffset` will be - * null as well. - * <p> - * When this method is called on a MicroBatchStream, the `startOffset` will be `initialOffset` - * for the very first micro-batch. The source can return `null` if there is no data to process. - */ - """ - pass - - -class SupportsTriggerAvailableNow(ABC): - @abstractmethod - def prepareForTriggerAvailableNow(self) -> None: - """ - FIXME: docstring needed - - /** - * This will be called at the beginning of streaming queries with Trigger.AvailableNow, to let the - * source record the offset for the current latest data at the time (a.k.a the target offset for - * the query). The source will behave as if there is no new data coming in after the target - * offset, i.e., the source will not return an offset higher than the target offset when - * {@link #latestOffset(Offset, ReadLimit) latestOffset} is called. - * <p> - * Note that there is an exception on the first uncommitted batch after a restart, where the end - * offset is not derived from the current latest offset. Sources need to take special - * considerations if wanting to assert such relation. One possible way is to have an internal - * flag in the source to indicate whether it is Trigger.AvailableNow, set the flag in this method, - * and record the target offset in the first call of - * {@link #latestOffset(Offset, ReadLimit) latestOffset}. 
- */ - """ - pass - class DataSourceWriter(ABC): """ diff --git a/python/pyspark/sql/datasource_internal.py b/python/pyspark/sql/datasource_internal.py index 8b0fa5eb0a5f..51ef94f9a325 100644 --- a/python/pyspark/sql/datasource_internal.py +++ b/python/pyspark/sql/datasource_internal.py @@ -19,16 +19,22 @@ import json import copy from itertools import chain -from typing import Iterator, List, Optional, Sequence, Tuple +from typing import Iterator, List, Optional, Sequence, Tuple, Type from pyspark.sql.datasource import ( DataSource, DataSourceStreamReader, InputPartition, + SimpleDataSourceStreamReader, +) +from pyspark.sql.streaming.datasource import ( ReadAllAvailable, ReadLimit, - SimpleDataSourceStreamReader, SupportsAdmissionControl, + CompositeReadLimit, + ReadMaxBytes, + ReadMaxRows, + ReadMinRows, ) from pyspark.sql.types import StructType from pyspark.errors import PySparkNotImplementedError @@ -91,14 +97,9 @@ class _SimpleStreamReaderWrapper(DataSourceStreamReader, SupportsAdmissionContro self.initial_offset = self.simple_reader.initialOffset() return self.initial_offset - def latestOffset(self) -> dict: - # when query start for the first time, use initial offset as the start offset. - if self.current_offset is None: - self.current_offset = self.initialOffset() - (iter, end) = self.simple_reader.read(self.current_offset) - self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter)) - self.current_offset = end - return end + def getDefaultReadLimit(self) -> ReadLimit: + # We do not consider providing different read limit on simple stream reader. + return ReadAllAvailable() def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict: if self.current_offset is None: @@ -163,3 +164,28 @@ class _SimpleStreamReaderWrapper(DataSourceStreamReader, SupportsAdmissionContro self, input_partition: SimpleInputPartition # type: ignore[override] ) -> Iterator[Tuple]: return self.simple_reader.readBetweenOffsets(input_partition.start, input_partition.end) + + +class ReadLimitRegistry: + def __init__(self) -> None: + self._registry: Dict[str, Type[ReadLimit]] = {} + self.__register(ReadAllAvailable.type_name(), ReadAllAvailable) + self.__register(ReadMinRows.type_name(), ReadMinRows) + self.__register(ReadMaxRows.type_name(), ReadMaxRows) + self.__register(ReadMaxBytes.type_name(), ReadMaxBytes) + self.__register(CompositeReadLimit.type_name(), CompositeReadLimit) + + def __register(self, type_name: str, read_limit_type: Type["ReadLimit"]) -> None: + if type_name in self._registry: + # FIXME: error class? + raise Exception(f"ReadLimit type '{type_name}' is already registered.") + + self._registry[type_name] = read_limit_type + + def get(self, type_name: str, params: dict) -> ReadLimit: + read_limit_type = self._registry[type_name] + if read_limit_type is None: + raise Exception("type_name '{}' is not registered.".format(type_name)) + params_without_type = params.copy() + del params_without_type["type"] + return read_limit_type.load(params_without_type) diff --git a/python/pyspark/sql/streaming/datasource.py b/python/pyspark/sql/streaming/datasource.py new file mode 100644 index 000000000000..82afbb83c583 --- /dev/null +++ b/python/pyspark/sql/streaming/datasource.py @@ -0,0 +1,207 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABC, abstractmethod +from typing import List, Optional + +class ReadLimit(ABC): + @classmethod + @abstractmethod + def type_name(cls) -> str: + pass + + @classmethod + @abstractmethod + def load(cls, params: dict) -> "ReadLimit": + pass + + def dump(self) -> dict: + params = self._dump() + params.update({"type": self.type_name()}) + return params + + @abstractmethod + def _dump(self) -> dict: + pass + + +class ReadAllAvailable(ReadLimit): + @classmethod + def type_name(cls) -> str: + return "ReadAllAvailable" + + @classmethod + def load(cls, params: dict) -> "ReadAllAvailable": + return ReadAllAvailable() + + def _dump(self) -> dict: + return {} + + +class ReadMinRows(ReadLimit): + def __init__(self, min_rows: int) -> None: + self.min_rows = min_rows + + @classmethod + def type_name(cls) -> str: + return "ReadMinRows" + + @classmethod + def load(cls, params: dict) -> "ReadMinRows": + return ReadMinRows(params["min_rows"]) + + def _dump(self) -> dict: + return {"min_rows": self.min_rows} + + +class ReadMaxRows(ReadLimit): + def __init__(self, max_rows: int) -> None: + self.max_rows = max_rows + + @classmethod + def type_name(cls) -> str: + return "ReadMaxRows" + + @classmethod + def load(cls, params: dict) -> "ReadMaxRows": + return ReadMaxRows(params["max_rows"]) + + def _dump(self) -> dict: + return {"max_rows": self.max_rows} + + +class ReadMaxFiles(ReadLimit): + def __init__(self, max_files: int) -> None: + self.max_files = max_files + + @classmethod + def type_name(cls) -> str: + return "ReadMaxFiles" + + @classmethod + def load(cls, params: dict) -> "ReadMaxFiles": + return ReadMaxFiles(params["max_files"]) + + def _dump(self) -> dict: + return {"max_files": self.max_files} + + +class ReadMaxBytes(ReadLimit): + def __init__(self, max_bytes: int) -> None: + self.max_bytes = max_bytes + + @classmethod + def type_name(cls) -> str: + return "ReadMaxBytes" + + @classmethod + def load(cls, params: dict) -> "ReadMaxBytes": + return ReadMaxBytes(params["max_bytes"]) + + def _dump(self) -> dict: + return {"max_bytes": self.max_bytes} + + +class CompositeReadLimit(ReadLimit): + def __init__(self, readLimits: List[ReadLimit]) -> None: + self.readLimits = readLimits + + @classmethod + def type_name(cls) -> str: + return "CompositeReadLimit" + + @classmethod + def load(cls, params: dict) -> "CompositeReadLimit": + read_limits = [] + for rl_params in params["readLimits"]: + rl_type = rl_params["type"] + rl = READ_LIMIT_REGISTRY.get(rl_type, rl_params) + read_limits.append(rl) + return CompositeReadLimit(read_limits) + + def _dump(self) -> dict: + return {"readLimits": [rl.dump() for rl in self.readLimits]} + + +class SupportsAdmissionControl(ABC): + def getDefaultReadLimit(self) -> ReadLimit: + """ + FIXME: docstring needed + + /** + * Returns the read limits potentially passed to the data source through options when creating + * the data source. 
+ */ + """ + return ReadAllAvailable() + + @abstractmethod + def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict: + """ + FIXME: docstring needed + + /** + * Returns the most recent offset available given a read limit. The start offset can be used + * to figure out how much new data should be read given the limit. Users should implement this + * method instead of latestOffset for a MicroBatchStream or getOffset for Source. + * <p> + * When this method is called on a `Source`, the source can return `null` if there is no + * data to process. In addition, for the very first micro-batch, the `startOffset` will be + * null as well. + * <p> + * When this method is called on a MicroBatchStream, the `startOffset` will be `initialOffset` + * for the very first micro-batch. The source can return `null` if there is no data to process. + */ + """ + pass + + def reportLatestOffset(self) -> Optional[dict]: + """ + FIXME: docstring needed + + /** + * Returns the most recent offset available. + * <p> + * The source can return `null`, if there is no data to process or the source does not support + * to this method. + */ + """ + return None + + +class SupportsTriggerAvailableNow(ABC): + @abstractmethod + def prepareForTriggerAvailableNow(self) -> None: + """ + FIXME: docstring needed + + /** + * This will be called at the beginning of streaming queries with Trigger.AvailableNow, to let the + * source record the offset for the current latest data at the time (a.k.a the target offset for + * the query). The source will behave as if there is no new data coming in after the target + * offset, i.e., the source will not return an offset higher than the target offset when + * {@link #latestOffset(Offset, ReadLimit) latestOffset} is called. + * <p> + * Note that there is an exception on the first uncommitted batch after a restart, where the end + * offset is not derived from the current latest offset. Sources need to take special + * considerations if wanting to assert such relation. One possible way is to have an internal + * flag in the source to indicate whether it is Trigger.AvailableNow, set the flag in this method, + * and record the target offset in the first call of + * {@link #latestOffset(Offset, ReadLimit) latestOffset}. 
+ */ + """ + pass diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py index 54bf12843232..192f729a6fb9 100644 --- a/python/pyspark/sql/streaming/python_streaming_source_runner.py +++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py @@ -31,11 +31,13 @@ from pyspark.serializers import ( from pyspark.sql.datasource import ( DataSource, DataSourceStreamReader, +) +from pyspark.sql.streaming.datasource import ( ReadAllAvailable, SupportsAdmissionControl, - SupportsTriggerAvailableNow + SupportsTriggerAvailableNow, ) -from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper, _streamReader +from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper, _streamReader, ReadLimitRegistry from pyspark.sql.pandas.serializers import ArrowStreamSerializer from pyspark.sql.types import ( _parse_datatype_json_string, @@ -60,6 +62,8 @@ COMMIT_FUNC_ID = 887 CHECK_SUPPORTED_FEATURES_ID = 888 PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID = 889 LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID = 890 +GET_DEFAULT_READ_LIMIT_FUNC_ID = 891 +REPORT_LATEST_OFFSET_FUNC_ID = 892 PREFETCHED_RECORDS_NOT_FOUND = 0 NON_EMPTY_PYARROW_RECORD_BATCHES = 1 @@ -68,9 +72,12 @@ EMPTY_PYARROW_RECORD_BATCHES = 2 SUPPORTS_ADMISSION_CONTROL = 1 SUPPORTS_TRIGGER_AVAILABLE_NOW = 1 << 1 +READ_LIMIT_REGISTRY = ReadLimitRegistry() + def initial_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None: offset = reader.initialOffset() + # raise Exception(f"Debug info for initial offset: offset: {offset}, json: {json.dumps(offset).encode('utf-8')}") write_with_length(json.dumps(offset).encode("utf-8"), outfile) @@ -165,19 +172,37 @@ def prepare_for_trigger_available_now_func(reader, outfile): def latest_offset_admission_control_func(reader, infile, outfile): start_offset_dict = json.loads(utf8_deserializer.loads(infile)) - limit_type = read_int(infile) - if limit_type == 0: - # ReadAllAvailable - limit = ReadAllAvailable() - else: - # FIXME: raise error - # FIXME: code for not supported? - raise Exception("Only ReadAllAvailable is supported for latestOffsetAdmissionControl.") + limit = json.loads(utf8_deserializer.loads(infile)) + limit_obj = READ_LIMIT_REGISTRY.get(limit["type"], limit) - offset = reader.latestOffset(start_offset_dict, limit) + offset = reader.latestOffset(start_offset_dict, limit_obj) write_with_length(json.dumps(offset).encode("utf-8"), outfile) +def get_default_read_limit_func(reader, outfile): + if isinstance(reader, SupportsAdmissionControl): + limit = reader.getDefaultReadLimit() + else: + limit = READ_ALL_AVAILABLE + + write_with_length(json.dumps(limit.dump()).encode("utf-8"), outfile) + + +def report_latest_offset_func(reader, outfile): + if isinstance(reader, _SimpleStreamReaderWrapper): + # We do not consider providing latest offset on simple stream reader. 
+ write_int(0, outfile) + else: + if isinstance(reader, SupportsAdmissionControl): + offset = reader.reportLatestOffset() + if offset is None: + write_int(0, outfile) + else: + write_with_length(json.dumps(offset).encode("utf-8"), outfile) + else: + write_int(0, outfile) + + def main(infile: IO, outfile: IO) -> None: try: check_python_version(infile) @@ -244,6 +269,10 @@ def main(infile: IO, outfile: IO) -> None: prepare_for_trigger_available_now_func(reader, outfile) elif func_id == LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID: latest_offset_admission_control_func(reader, infile, outfile) + elif func_id == GET_DEFAULT_READ_LIMIT_FUNC_ID: + get_default_read_limit_func(reader, outfile) + elif func_id == REPORT_LATEST_OFFSET_FUNC_ID: + report_latest_offset_func(reader, outfile) else: raise IllegalArgumentException( errorClass="UNSUPPORTED_OPERATION", diff --git a/python/pyspark/sql/tests/test_python_streaming_datasource.py b/python/pyspark/sql/tests/test_python_streaming_datasource.py index 1bb0c0895be9..b894a5fcfec6 100644 --- a/python/pyspark/sql/tests/test_python_streaming_datasource.py +++ b/python/pyspark/sql/tests/test_python_streaming_datasource.py @@ -140,6 +140,53 @@ class BasePythonStreamingDataSourceTestsMixin: return TestDataSource + def _get_test_data_source_for_admission_control(self): + from pyspark.sql.streaming.datasource import ( + ReadAllAvailable, + ReadLimit, + ReadMaxRows, + SupportsAdmissionControl, + ) + + class TestDataStreamReader( + DataSourceStreamReader, + SupportsAdmissionControl + ): + def initialOffset(self): + return {"partition-1": 0} + + def getDefaultReadLimit(self): + return ReadMaxRows(2) + + def latestOffset(self, start: dict, readLimit: ReadLimit): + start_idx = start["partition-1"] + if isinstance(readLimit, ReadAllAvailable): + end_offset = start_idx + 10 + else: + assert isinstance(readLimit, ReadMaxRows), ("Expected ReadMaxRows read limit but got " + + str(type(readLimit))) + end_offset = start_idx + readLimit.max_rows + return {"partition-1": end_offset} + + def reportLatestOffset(self): + return {"partition-1": 1000000} + + def partitions(self, start: dict, end: dict): + start_index = start["partition-1"] + end_index = end["partition-1"] + return [InputPartition(i) for i in range(start_index, end_index)] + + def read(self, partition): + yield (partition.value,) + + class TestDataSource(DataSource): + def schema(self) -> str: + return "id INT" + def streamReader(self, schema): + return TestDataStreamReader() + + return TestDataSource + def test_stream_reader(self): self.spark.dataSource.register(self._get_test_data_source()) df = self.spark.readStream.format("TestDataSource").load() @@ -214,6 +261,35 @@ class BasePythonStreamingDataSourceTestsMixin: assertDataFrameEqual(df, expected_data) + def test_stream_reader_admission_control_trigger_once(self): + self.spark.dataSource.register(self._get_test_data_source_for_admission_control()) + df = self.spark.readStream.format("TestDataSource").load() + + def check_batch(df, batch_id): + assertDataFrameEqual(df, [Row(x) for x in range(10)]) + + q = df.writeStream.trigger(once=True).foreachBatch(check_batch).start() + q.awaitTermination() + self.assertIsNone(q.exception(), "No exception has to be propagated.") + self.assertTrue(q.recentProgress.length == 1) + self.assertTrue(q.lastProgress.numInputRows == 10) + self.assertTrue(q.lastProgress.sources[0].numInputRows == 10) + self.assertTrue(q.lastProgress.sources[0].latestOffset == """{"partition-1": 1000000}""") + + def 
test_stream_reader_admission_control_processing_time_trigger(self): + self.spark.dataSource.register(self._get_test_data_source_for_admission_control()) + df = self.spark.readStream.format("TestDataSource").load() + + def check_batch(df, batch_id): + assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)]) + + q = df.writeStream.foreachBatch(check_batch).start() + while len(q.recentProgress) < 10: + time.sleep(0.2) + q.stop() + q.awaitTermination() + self.assertIsNone(q.exception(), "No exception has to be propagated.") + def test_simple_stream_reader(self): class SimpleStreamReader(SimpleDataSourceStreamReader): def initialOffset(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala index e4fef9e6763c..319083a9d8bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala @@ -29,6 +29,8 @@ import org.apache.spark.storage.{PythonStreamBlockId, StorageLevel} case class PythonStreamingSourceOffset(json: String) extends Offset +case class PythonStreamingSourceReadLimit(json: String) extends ReadLimit + class PythonMicroBatchStream( ds: PythonDataSourceV2, shortName: String, @@ -122,6 +124,20 @@ class PythonMicroBatchStreamWithAdmissionControl( override def latestOffset(startOffset: Offset, limit: ReadLimit): Offset = { PythonStreamingSourceOffset(runner.latestOffset(startOffset, limit)) } + + override def getDefaultReadLimit: ReadLimit = { + val readLimitJson = runner.getDefaultReadLimit() + PythonStreamingSourceReadLimit(readLimitJson) + } + + override def reportLatestOffset(): Offset = { + val offsetJson = runner.reportLatestOffset() + if (offsetJson == null) { + null + } else { + PythonStreamingSourceOffset(offsetJson) + } + } } class PythonMicroBatchStreamWithTriggerAvailableNow( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala index 36e4d09041ef..836962735d08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala @@ -20,25 +20,23 @@ package org.apache.spark.sql.execution.python.streaming import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream} import java.nio.channels.Channels - import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ - import org.apache.arrow.vector.ipc.ArrowStreamReader - import org.apache.spark.SparkEnv import org.apache.spark.api.python.{PythonFunction, PythonWorker, PythonWorkerFactory, PythonWorkerUtils, SpecialLengths} -import org.apache.spark.internal.{Logging, LogKeys} +import org.apache.spark.internal.{LogKeys, Logging} import org.apache.spark.internal.LogKeys.PYTHON_EXEC import org.apache.spark.internal.config.BUFFER_SIZE import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read.streaming.{Offset, ReadAllAvailable, ReadLimit} import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors} +import org.apache.spark.sql.execution.datasources.v2.python.PythonStreamingSourceReadLimit import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} object PythonStreamingSourceRunner { // When the python process for python_streaming_source_runner receives one of the @@ -50,11 +48,14 @@ object PythonStreamingSourceRunner { val CHECK_SUPPORTED_FEATURES_ID = 888 val PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID = 889 val LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID = 890 + val GET_DEFAULT_READ_LIMIT_FUNC_ID = 891 + val REPORT_LATEST_OFFSET_FUNC_ID = 892 // Status code for JVM to decide how to receive prefetched record batches // for simple stream reader. val PREFETCHED_RECORDS_NOT_FOUND = 0 val NON_EMPTY_PYARROW_RECORD_BATCHES = 1 val EMPTY_PYARROW_RECORD_BATCHES = 2 + val READ_ALL_AVAILABLE_JSON = """{"type": "ReadAllAvailable"}""" case class SupportedFeatures(admissionControl: Boolean, triggerAvailableNow: Boolean) } @@ -145,17 +146,48 @@ class PythonStreamingSourceRunner( throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError( action = "checkSupportedFeatures", msg) } - val admissionControl = (featureBits | (1 << 0)) == 1 - val availableNow = (featureBits | (1 << 1)) == (1 << 1) + val admissionControl = (featureBits & (1 << 0)) == 1 + val availableNow = (featureBits & (1 << 1)) == (1 << 1) SupportedFeatures(admissionControl, availableNow) } + def getDefaultReadLimit(): String = { + dataOut.writeInt(GET_DEFAULT_READ_LIMIT_FUNC_ID) + dataOut.flush() + + val len = dataIn.readInt() + if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError( + action = "getDefaultReadLimit", msg) + } + + PythonWorkerUtils.readUTF(len, dataIn) + } + + def reportLatestOffset(): String = { + dataOut.writeInt(REPORT_LATEST_OFFSET_FUNC_ID) + dataOut.flush() + + val len = dataIn.readInt() + if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError( + action = "reportLatestOffset", msg) + } + + if (len == 0) { + null + } else { + PythonWorkerUtils.readUTF(len, dataIn) + } + } + def prepareForTriggerAvailableNow(): Unit = { dataOut.writeInt(PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID) dataOut.flush() val status = dataIn.readInt() - // FIXME: code for not supported? if (status == SpecialLengths.PYTHON_EXCEPTION_THROWN) { val msg = PythonWorkerUtils.readUTF(dataIn) throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError( @@ -183,12 +215,15 @@ class PythonStreamingSourceRunner( PythonWorkerUtils.writeUTF(startOffset.json, dataOut) limit match { case _: ReadAllAvailable => - dataOut.writeInt(0) - dataOut.flush() + // NOTE: we need to use a constant here to match the Python side given the engine can + // decide by itself to use ReadAllAvailable and the Python side version of the instance + // isn't available here. 
+ PythonWorkerUtils.writeUTF(READ_ALL_AVAILABLE_JSON, dataOut) + + case p: PythonStreamingSourceReadLimit => + PythonWorkerUtils.writeUTF(p.json, dataOut) case _ => - // FIXME: Add support for other ReadLimit types - // throw QueryExecutionErrors.unsupportedReadLimitTypeError(limit.getClass.getName) throw new UnsupportedOperationException("Unsupported ReadLimit type: " + s"${limit.getClass.getName}") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala index 330a0513d360..eb5f7475636f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala @@ -350,7 +350,8 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase { val dataSourceScript = s""" |from pyspark.sql.datasource import DataSource - |from pyspark.sql.datasource import SimpleDataSourceStreamReader, SupportsTriggerAvailableNow + |from pyspark.sql.datasource import SimpleDataSourceStreamReader + |from pyspark.sql.streaming.datasource import SupportsTriggerAvailableNow | |class SimpleDataStreamReader(SimpleDataSourceStreamReader, SupportsTriggerAvailableNow): | def initialOffset(self): @@ -673,6 +674,205 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { q.awaitTermination() } + private val testAdmissionControlScript = + s""" + |from pyspark.sql.datasource import DataSource + |from pyspark.sql.datasource import ( + | DataSourceStreamReader, + | InputPartition, + |) + |from pyspark.sql.streaming.datasource import ( + | ReadAllAvailable, + | ReadLimit, + | ReadMaxRows, + | SupportsAdmissionControl, + |) + | + |class TestDataStreamReader( + | DataSourceStreamReader, + | SupportsAdmissionControl + |): + | def initialOffset(self): + | return {"partition-1": 0} + | def getDefaultReadLimit(self): + | return ReadMaxRows(2) + | def latestOffset(self, start: dict, readLimit: ReadLimit): + | start_idx = start["partition-1"] + | if isinstance(readLimit, ReadAllAvailable): + | end_offset = start_idx + 10 + | else: + | assert isinstance(readLimit, ReadMaxRows), ("Expected ReadMaxRows read limit but got " + | + str(type(readLimit))) + | end_offset = start_idx + readLimit.max_rows + | return {"partition-1": end_offset} + | def reportLatestOffset(self): + | return {"partition-1": 1000000} + | def partitions(self, start: dict, end: dict): + | start_index = start["partition-1"] + | end_index = end["partition-1"] + | return [InputPartition(i) for i in range(start_index, end_index)] + | def read(self, partition): + | yield (partition.value,) + | + |class $dataSourceName(DataSource): + | def schema(self) -> str: + | return "id INT" + | def streamReader(self, schema): + | return TestDataStreamReader() + |""".stripMargin + + private val testAvailableNowScript = + s""" + |from pyspark.sql.datasource import DataSource + |from pyspark.sql.datasource import ( + | DataSourceStreamReader, + | InputPartition, + |) + |from pyspark.sql.streaming.datasource import ( + | ReadAllAvailable, + | ReadLimit, + | ReadMaxRows, + | SupportsAdmissionControl, + | SupportsTriggerAvailableNow + |) + | + |class TestDataStreamReader( + | DataSourceStreamReader, + | SupportsAdmissionControl, + | SupportsTriggerAvailableNow + |): + | def initialOffset(self): + | return 
{"partition-1": 0} + | def getDefaultReadLimit(self): + | return ReadMaxRows(2) + | def latestOffset(self, start: dict, readLimit: ReadLimit): + | start_idx = start["partition-1"] + | if isinstance(readLimit, ReadAllAvailable): + | end_offset = start_idx + 5 + | else: + | assert isinstance(readLimit, ReadMaxRows), ("Expected ReadMaxRows read limit but got " + | + str(type(readLimit))) + | end_offset = start_idx + readLimit.max_rows + | end_offset = min(end_offset, self.desired_end_offset) + | return {"partition-1": end_offset} + | def reportLatestOffset(self): + | return {"partition-1": 1000000} + | def prepareForTriggerAvailableNow(self): + | self.desired_end_offset = 10 + | def partitions(self, start: dict, end: dict): + | start_index = start["partition-1"] + | end_index = end["partition-1"] + | return [InputPartition(i) for i in range(start_index, end_index)] + | def read(self, partition): + | yield (partition.value,) + | + |class $dataSourceName(DataSource): + | def schema(self) -> str: + | return "id INT" + | def streamReader(self, schema): + | return TestDataStreamReader() + |""".stripMargin + + test("DataSourceStreamReader with Admission Control, Trigger.Once") { + assume(shouldTestPandasUDFs) + val dataSource = createUserDefinedPythonDataSource(dataSourceName, testAdmissionControlScript) + spark.dataSource.registerPython(dataSourceName, dataSource) + assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName)) + withTempDir { dir => + val path = dir.getAbsolutePath + val checkpointDir = new File(path, "checkpoint") + val outputDir = new File(path, "output") + val df = spark.readStream.format(dataSourceName).load() + val q = df.writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .format("json") + // Use Trigger.Once here by intention to test read with admission control. 
+ .trigger(Trigger.Once()) + .start(outputDir.getAbsolutePath) + q.awaitTermination(waitTimeout.toMillis) + + assert(q.recentProgress.length === 1) + assert(q.lastProgress.numInputRows === 10) + assert(q.lastProgress.sources(0).numInputRows === 10) + assert(q.lastProgress.sources(0).latestOffset === """{"partition-1": 1000000}""") + + val rowCount = spark.read.format("json").load(outputDir.getAbsolutePath).count() + assert(rowCount === 10) + checkAnswer( + spark.read.format("json").load(outputDir.getAbsolutePath), + (0 until rowCount.toInt).map(Row(_)) + ) + } + } + + test("DataSourceStreamReader with Admission Control, processing time trigger") { + assume(shouldTestPandasUDFs) + val dataSource = createUserDefinedPythonDataSource(dataSourceName, testAdmissionControlScript) + spark.dataSource.registerPython(dataSourceName, dataSource) + assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName)) + withTempDir { dir => + val path = dir.getAbsolutePath + val checkpointDir = new File(path, "checkpoint") + val df = spark.readStream.format(dataSourceName).load() + + val stopSignal = new CountDownLatch(1) + + val q = df.writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .foreachBatch((df: DataFrame, batchId: Long) => { + df.cache() + checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1))) + if (batchId == 10) stopSignal.countDown() + }) + .trigger(Trigger.ProcessingTime(0)) + .start() + stopSignal.await() + q.stop() + q.awaitTermination() + + assert(q.recentProgress.length >= 10) + q.recentProgress.foreach { progress => + assert(progress.numInputRows === 2) + assert(progress.sources(0).numInputRows === 2) + assert(progress.sources(0).latestOffset === """{"partition-1": 1000000}""") + } + } + } + + test("DataSourceStreamReader with Trigger.AvailableNow") { + assume(shouldTestPandasUDFs) + val dataSource = createUserDefinedPythonDataSource(dataSourceName, testAvailableNowScript) + spark.dataSource.registerPython(dataSourceName, dataSource) + assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName)) + withTempDir { dir => + val path = dir.getAbsolutePath + val checkpointDir = new File(path, "checkpoint") + val outputDir = new File(path, "output") + val df = spark.readStream.format(dataSourceName).load() + val q = df.writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .format("json") + .trigger(Trigger.AvailableNow()) + .start(outputDir.getAbsolutePath) + q.awaitTermination(waitTimeout.toMillis) + + // 2 rows * 5 batches = 10 rows + assert(q.recentProgress.length === 5) + q.recentProgress.foreach { progress => + assert(progress.numInputRows === 2) + assert(progress.sources(0).numInputRows === 2) + assert(progress.sources(0).latestOffset === """{"partition-1": 1000000}""") + } + + val rowCount = spark.read.format("json").load(outputDir.getAbsolutePath).count() + assert(rowCount === 10) + checkAnswer( + spark.read.format("json").load(outputDir.getAbsolutePath), + (0 until rowCount.toInt).map(Row(_)) + ) + } + } + test("Error creating stream reader") { assume(shouldTestPandasUDFs) val dataSourceScript = @@ -770,7 +970,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { func: PythonMicroBatchStream => Unit): Unit = { val options = CaseInsensitiveStringMap.empty() val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner( - pythonDs, dataSourceName, inputSchema, options) + pythonDs, errorDataSourceName, inputSchema, options) runner.init() val stream = new 
PythonMicroBatchStream( @@ -837,7 +1037,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { func: PythonMicroBatchStream => Unit): Unit = { val options = CaseInsensitiveStringMap.empty() val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner( - pythonDs, dataSourceName, inputSchema, options) + pythonDs, errorDataSourceName, inputSchema, options) runner.init() val stream = new PythonMicroBatchStream( --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
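
[Editor's note] A minimal sketch of the read-limit serialization contract this patch sets up between the JVM runner and the Python worker, assuming the new pyspark.sql.streaming.datasource module from this WIP branch is importable. All class, method, and constant names below are taken from the diff above; none of this is a released PySpark API.

    import json

    # Added by this patch; not available in released PySpark.
    from pyspark.sql.streaming.datasource import ReadAllAvailable, ReadMaxRows

    # PythonStreamingSourceRunner.latestOffset() writes the start offset JSON followed by
    # the read limit JSON; dump() tags the payload with the class's type_name() so the
    # worker can look the concrete ReadLimit subclass up in ReadLimitRegistry.
    limit_json = json.dumps(ReadMaxRows(2).dump())
    # -> '{"max_rows": 2, "type": "ReadMaxRows"}'

    # latest_offset_admission_control_func() reverses the process: parse the JSON, strip
    # the "type" tag (as ReadLimitRegistry.get() does), and rebuild the limit via the
    # matching load() classmethod before passing it to reader.latestOffset(start, limit).
    params = json.loads(limit_json)
    params.pop("type")
    restored = ReadMaxRows.load(params)
    assert restored.max_rows == 2

    # When the engine decides to read everything (e.g. Trigger.Once), the JVM side writes
    # the constant READ_ALL_AVAILABLE_JSON, which round-trips to a ReadAllAvailable instance.
    assert isinstance(ReadAllAvailable.load({}), ReadAllAvailable)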
