This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cbcc7f8a7f73 [SPARK-55304][SS][PYTHON] Introduce support of Admission
Control and Trigger.AvailableNow in Python data source - streaming reader
cbcc7f8a7f73 is described below
commit cbcc7f8a7f73576e1584d2ac2855640310e062eb
Author: Jungtaek Lim <[email protected]>
AuthorDate: Sat Feb 7 11:14:48 2026 +0900
[SPARK-55304][SS][PYTHON] Introduce support of Admission Control and
Trigger.AvailableNow in Python data source - streaming reader
### What changes were proposed in this pull request?
This PR proposes to introduce support for Admission Control and
Trigger.AvailableNow in the Python data source streaming reader.
To support Admission Control, we propose to change the `DataSourceStreamReader`
interface as follows (shown as a side-by-side comparison):
| **Before** | **After** |
| :---: | :---: |
| `class DataSourceStreamReader(ABC):` | `class DataSourceStreamReader(ABC):` |
| `def initialOffset(self) -> dict` | `def initialOffset(self) -> dict` |
| `def latestOffset(self) -> dict` | `def latestOffset(self, start: dict, limit: ReadLimit) -> dict` |
| | `# NOTE: Optional to implement, default = ReadAllAvailable()` |
| | `def getDefaultReadLimit(self) -> ReadLimit` |
| | `# NOTE: Optional to implement, default = None` |
| | `def reportLatestOffset(self) -> Optional[dict]` |
| `def partitions(self, start: dict, end: dict) -> Sequence[InputPartition]` | `def partitions(self, start: dict, end: dict) -> Sequence[InputPartition]` |
| `abstractmethod def read(self, partition: InputPartition) -> Union[Iterator[Tuple], Iterator["RecordBatch"]]` | `abstractmethod def read(self, partition: InputPartition) -> Union[Iterator[Tuple], Iterator["RecordBatch"]]` |
| `def commit(self, end: dict) -> None` | `def commit(self, end: dict) -> None` |
| `def stop(self) -> None` | `def stop(self) -> None` |
The main changes are as follows:
* The method signature of `latestOffset` is changed. The method remains
mandatory.
* The method `getDefaultReadLimit` is added as an optional method.
* The method `reportLatestOffset` is added as an optional method.
This way, new implementations would support Admission Control by default.
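For illustration, here is a minimal sketch of a reader implementing the new interface. It is modeled on the unit tests added in this PR; the single-partition offset layout and the assumption of 10 available rows are illustrative only.

```
from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
from pyspark.sql.streaming.datasource import ReadAllAvailable, ReadLimit, ReadMaxRows


class CounterStreamReader(DataSourceStreamReader):
    """Hypothetical single-partition source that emits consecutive integers."""

    def initialOffset(self) -> dict:
        return {"index": 0}

    def getDefaultReadLimit(self) -> ReadLimit:
        # Cap each micro-batch at 2 rows unless the engine overrides the limit.
        return ReadMaxRows(2)

    def latestOffset(self, start: dict, limit: ReadLimit) -> dict:
        # Pretend 10 new rows are always available past `start`.
        available = 10
        if isinstance(limit, ReadAllAvailable):
            advance = available
        else:
            assert isinstance(limit, ReadMaxRows)
            advance = min(available, limit.max_rows)
        return {"index": start["index"] + advance}

    def reportLatestOffset(self):
        # Optional: surface the latest known offset for progress reporting.
        return {"index": 1_000_000}

    def partitions(self, start: dict, end: dict):
        return [InputPartition(i) for i in range(start["index"], end["index"])]

    def read(self, partition):
        yield (partition.value,)
```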
We ensure the engine can still handle the old method signature via Python's
built-in `inspect` module (similar to Java's reflection). If `latestOffset` is
implemented without parameters, we fall back to treating the source as one that
does not enable admission control. For all new sources, implementing
`latestOffset` with parameters is strongly recommended.
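A simplified sketch of this detection (the real logic lives in the Python streaming source runner worker; this standalone helper is illustrative only):

```
import inspect

from pyspark.sql.datasource import DataSourceStreamReader


def supports_admission_control(reader: DataSourceStreamReader) -> bool:
    # Old-style readers implement `latestOffset(self)`; new-style readers implement
    # `latestOffset(self, start, limit)`. On the bound method, `self` is not counted,
    # so the old signature has zero parameters.
    sig = inspect.signature(reader.latestOffset)
    return len(sig.parameters) > 0
```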
The `ReadLimit` interface and its built-in implementations will be available
for source implementations to leverage: `ReadAllAvailable`, `ReadMinRows`,
`ReadMaxRows`, `ReadMaxFiles`, and `ReadMaxBytes`. We won't support custom
implementations of the `ReadLimit` interface at this point, since that requires
major effort and we don't see demand; we can plan for it later if strong demand
arises.
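For example, a source could translate a read-limit option into one of the built-in limits when implementing `getDefaultReadLimit`. A rough sketch; the option name `maxRowsPerBatch` is purely hypothetical:

```
from pyspark.sql.streaming.datasource import ReadAllAvailable, ReadLimit, ReadMaxRows


def default_read_limit(options: dict) -> ReadLimit:
    # Translate a hypothetical `maxRowsPerBatch` source option into a built-in limit.
    max_rows = options.get("maxRowsPerBatch")
    if max_rows is not None:
        return ReadMaxRows(int(max_rows))
    return ReadAllAvailable()
```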
We do not make any change to `SimpleDataSourceStreamReader` for Admission
Control, since it is designed for small data fetches and can be considered to
already limit the data. We could still add `ReadLimit` support later if we see
strong demand for limiting the fetch size via source options.
To support `Trigger.AvailableNow`, we propose to introduce a new interface
as follows:
```
class SupportsTriggerAvailableNow(ABC):
    @abstractmethod
    def prepareForTriggerAvailableNow(self) -> None
```
The above interface can be mixed in with both `DataSourceStreamReader` and
`SimpleDataSourceStreamReader`. It won't work with `DataSourceStreamReader`
implementations that keep the old `latestOffset()` method signature, as
mentioned above.
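A minimal sketch of mixing the interface into a `SimpleDataSourceStreamReader`, modeled on the unit tests in this PR; the fixed target offset of 10 is illustrative only:

```
from pyspark.sql.datasource import SimpleDataSourceStreamReader
from pyspark.sql.streaming.datasource import SupportsTriggerAvailableNow


class AvailableNowSimpleReader(SimpleDataSourceStreamReader, SupportsTriggerAvailableNow):
    """Hypothetical reader that stops once the offset captured at query start is reached."""

    def initialOffset(self) -> dict:
        return {"offset": 0}

    def prepareForTriggerAvailableNow(self) -> None:
        # Freeze the target offset at query start; the reader must not go past it.
        self.target_offset = {"offset": 10}

    def read(self, start: dict):
        # Assumes prepareForTriggerAvailableNow() has been called (AvailableNow trigger).
        end = min(start["offset"] + 2, self.target_offset["offset"])
        rows = iter([(i,) for i in range(start["offset"], end)])
        return (rows, {"offset": end})

    def readBetweenOffsets(self, start: dict, end: dict):
        return iter([(i,) for i in range(start["offset"], end["offset"])])

    def commit(self, end: dict) -> None:
        pass
```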
### Why are the changes needed?
This catches up with features already supported in the Scala DSv2 API; we have
received reports from developers that these missing features block them from
implementing some data sources.
### Does this PR introduce _any_ user-facing change?
Yes, users implementing a streaming reader via the Python data source API will
be able to add support for Admission Control and Trigger.AvailableNow, which
had been major missing features.
### How was this patch tested?
New UTs.
### Was this patch authored or co-authored using generative AI tooling?
Co-authored using claude-4.5-sonnet
Closes #54085 from HeartSaVioR/SPARK-55304.
Lead-authored-by: Jungtaek Lim <[email protected]>
Co-authored-by: Jitesh Soni <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
python/pyspark/sql/datasource.py | 79 +++-
python/pyspark/sql/datasource_internal.py | 84 ++--
python/pyspark/sql/streaming/datasource.py | 119 ++++++
python/pyspark/sql/streaming/listener.py | 14 +-
.../streaming/python_streaming_source_runner.py | 113 ++++-
.../sql/tests/test_python_streaming_datasource.py | 285 ++++++++++++-
.../v2/python/PythonMicroBatchStream.scala | 108 ++++-
.../datasources/v2/python/PythonScan.scala | 19 +-
.../streaming/PythonStreamingSourceRunner.scala | 96 +++++
.../streaming/PythonStreamingDataSourceSuite.scala | 461 +++++++++++++++++++--
10 files changed, 1276 insertions(+), 102 deletions(-)
diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index f1908180a3ba..bb73a7a9206b 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -32,6 +32,7 @@ from typing import (
)
from pyspark.sql import Row
+from pyspark.sql.streaming.datasource import ReadAllAvailable, ReadLimit
from pyspark.sql.types import StructType
from pyspark.errors import PySparkNotImplementedError
@@ -714,9 +715,35 @@ class DataSourceStreamReader(ABC):
messageParameters={"feature": "initialOffset"},
)
- def latestOffset(self) -> dict:
+ def latestOffset(self, start: dict, limit: ReadLimit) -> dict:
"""
- Returns the most recent offset available.
+ Returns the most recent offset available given a read limit. The start
offset can be used
+ to figure out how much new data should be read given the limit.
+
+ The `start` will be provided from the return value of
:meth:`initialOffset()` for
+ the very first micro-batch, and for subsequent micro-batches, the
start offset is the
+ ending offset from the previous micro-batch. The source can return the
`start` parameter
+ as it is, if there is no data to process.
+
+ :class:`ReadLimit` can be used by the source to limit the amount of
data returned in this
+ call. The implementation should implement
:meth:`getDefaultReadLimit()` to provide the
+ proper :class:`ReadLimit` if the source can limit the amount of data
returned based on the
+ source options.
+
+        The engine can still call :meth:`latestOffset()` with
:class:`ReadAllAvailable` even if the
+        source produces a different read limit from
:meth:`getDefaultReadLimit()`, to respect the
+        semantics of the trigger. The source must always respect the given
readLimit provided by the
+        engine; e.g. if the readLimit is :class:`ReadAllAvailable`, the source
must ignore the read
+        limit configured through options.
+
+ .. versionadded:: 4.2.0
+
+ Parameters
+ ----------
+ start : dict
+ The start offset of the microbatch to continue reading from.
+ limit : :class:`ReadLimit`
+ The limit on the amount of data to be returned by this call.
Returns
-------
@@ -726,14 +753,58 @@ class DataSourceStreamReader(ABC):
Examples
--------
- >>> def latestOffset(self):
- ... return {"parititon-1": {"index": 3, "closed": True},
"partition-2": {"index": 5}}
+ >>> from pyspark.sql.streaming.datasource import ReadAllAvailable,
ReadMaxRows
+ >>> def latestOffset(self, start, limit):
+ ... # Assume the source has 10 new records between start and
latest offset
+ ... if isinstance(limit, ReadAllAvailable):
+ ... return {"index": start["index"] + 10}
+ ... else: # e.g., limit is ReadMaxRows(5)
+ ... return {"index": start["index"] + min(10, limit.maxRows)}
"""
+ # NOTE: Previous Spark versions didn't have start offset and read
limit parameters for this
+ # method. While Spark will ensure the backward compatibility for
existing data sources, the
+ # new data sources are strongly encouraged to implement this new
method signature.
raise PySparkNotImplementedError(
errorClass="NOT_IMPLEMENTED",
messageParameters={"feature": "latestOffset"},
)
+ def getDefaultReadLimit(self) -> ReadLimit:
+ """
+ Returns the read limits potentially passed to the data source through
options when creating
+ the data source. See the built-in implementations of
:class:`ReadLimit` for available read
+ limits.
+
+ Implementing this method is optional. By default, it returns
:class:`ReadAllAvailable`,
+ which means there is no limit on the amount of data returned by
:meth:`latestOffset()`.
+
+ .. versionadded:: 4.2.0
+ """
+ return ReadAllAvailable()
+
+ def reportLatestOffset(self) -> Optional[dict]:
+ """
+ Returns the most recent offset available. The information is used to
report the latest
+ offset in the streaming query status.
+        The source can return `None` if there is no data to process or the
source does not support
+        this method.
+
+ .. versionadded:: 4.2.0
+
+ Returns
+ -------
+ dict or None
+ A dict or recursive dict whose key and value are primitive types,
which includes
+ Integer, String and Boolean.
+ Returns `None` if the source does not support reporting latest
offset.
+
+ Examples
+ --------
+ >>> def reportLatestOffset(self):
+ ... return {"partition-1": {"index": 100}, "partition-2":
{"index": 200}}
+ """
+ return None
+
def partitions(self, start: dict, end: dict) -> Sequence[InputPartition]:
"""
Returns a list of InputPartition given the start and end offsets. Each
InputPartition
diff --git a/python/pyspark/sql/datasource_internal.py
b/python/pyspark/sql/datasource_internal.py
index 6df0be4192ec..92a968cf0572 100644
--- a/python/pyspark/sql/datasource_internal.py
+++ b/python/pyspark/sql/datasource_internal.py
@@ -19,7 +19,7 @@
import json
import copy
from itertools import chain
-from typing import Iterator, List, Optional, Sequence, Tuple
+from typing import Iterator, List, Sequence, Tuple, Type, Dict
from pyspark.sql.datasource import (
DataSource,
@@ -27,8 +27,17 @@ from pyspark.sql.datasource import (
InputPartition,
SimpleDataSourceStreamReader,
)
+from pyspark.sql.streaming.datasource import (
+ ReadAllAvailable,
+ ReadLimit,
+ ReadMaxBytes,
+ ReadMaxRows,
+ ReadMinRows,
+ ReadMaxFiles,
+)
from pyspark.sql.types import StructType
from pyspark.errors import PySparkNotImplementedError
+from pyspark.errors.exceptions.base import PySparkException
def _streamReader(datasource: DataSource, schema: StructType) ->
"DataSourceStreamReader":
@@ -62,13 +71,9 @@ class _SimpleStreamReaderWrapper(DataSourceStreamReader):
so that :class:`SimpleDataSourceStreamReader` can integrate with streaming
engine like an
ordinary :class:`DataSourceStreamReader`.
- current_offset tracks the latest progress of the record prefetching, it is
initialized to be
- initialOffset() when query start for the first time or initialized to be
the end offset of
- the last planned batch when query restarts.
-
When streaming engine calls latestOffset(), the wrapper calls read() that
starts from
- current_offset, prefetches and cache the data, then updates the
current_offset to be
- the end offset of the new data.
+    the start offset provided via the parameter of latestOffset(), prefetches
and caches the data,
+    and returns the end offset of the new data.
When streaming engine call planInputPartitions(start, end), the wrapper
get the prefetched data
from cache and send it to JVM along with the input partitions.
@@ -79,28 +84,26 @@ class _SimpleStreamReaderWrapper(DataSourceStreamReader):
def __init__(self, simple_reader: SimpleDataSourceStreamReader):
self.simple_reader = simple_reader
- self.initial_offset: Optional[dict] = None
- self.current_offset: Optional[dict] = None
self.cache: List[PrefetchedCacheEntry] = []
def initialOffset(self) -> dict:
- if self.initial_offset is None:
- self.initial_offset = self.simple_reader.initialOffset()
- return self.initial_offset
-
- def latestOffset(self) -> dict:
- # when query start for the first time, use initial offset as the start
offset.
- if self.current_offset is None:
- self.current_offset = self.initialOffset()
- (iter, end) = self.simple_reader.read(self.current_offset)
- self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter))
- self.current_offset = end
+ return self.simple_reader.initialOffset()
+
+ def getDefaultReadLimit(self) -> ReadLimit:
+ # We do not consider providing different read limit on simple stream
reader.
+ return ReadAllAvailable()
+
+ def latestOffset(self, start: dict, limit: ReadLimit) -> dict:
+ assert start is not None, "start offset should not be None"
+ assert isinstance(
+ limit, ReadAllAvailable
+ ), "simple stream reader does not support read limit"
+
+ (iter, end) = self.simple_reader.read(start)
+ self.cache.append(PrefetchedCacheEntry(start, end, iter))
return end
def commit(self, end: dict) -> None:
- if self.current_offset is None:
- self.current_offset = end
-
end_idx = -1
for idx, entry in enumerate(self.cache):
if json.dumps(entry.end) == json.dumps(end):
@@ -112,11 +115,6 @@ class _SimpleStreamReaderWrapper(DataSourceStreamReader):
self.simple_reader.commit(end)
def partitions(self, start: dict, end: dict) -> Sequence["InputPartition"]:
- # when query restart from checkpoint, use the last committed offset as
the start offset.
- # This depends on the streaming engine calling planInputPartitions()
of the last batch
- # in offset log when query restart.
- if self.current_offset is None:
- self.current_offset = end
if len(self.cache) > 0:
assert self.cache[-1].end == end
return [SimpleInputPartition(start, end)]
@@ -144,3 +142,33 @@ class _SimpleStreamReaderWrapper(DataSourceStreamReader):
self, input_partition: SimpleInputPartition # type: ignore[override]
) -> Iterator[Tuple]:
return self.simple_reader.readBetweenOffsets(input_partition.start,
input_partition.end)
+
+
+class ReadLimitRegistry:
+ def __init__(self) -> None:
+ self._registry: Dict[str, Type[ReadLimit]] = {}
+ # Register built-in ReadLimit types
+ self.__register(ReadAllAvailable)
+ self.__register(ReadMinRows)
+ self.__register(ReadMaxRows)
+ self.__register(ReadMaxFiles)
+ self.__register(ReadMaxBytes)
+
+ def __register(self, read_limit_type: Type["ReadLimit"]) -> None:
+ name = read_limit_type.__name__
+ if name in self._registry:
+ raise PySparkException(f"ReadLimit type '{name}' is already
registered.")
+ self._registry[name] = read_limit_type
+
+ def get(self, params_with_type: dict) -> ReadLimit:
+ type_name = params_with_type["_type"]
+ if type_name is None:
+ raise PySparkException("ReadLimit type name is missing.")
+
+ read_limit_type = self._registry.get(type_name)
+ if read_limit_type is None:
+ raise PySparkException("name '{}' is not
registered.".format(type_name))
+
+ params_without_type = params_with_type.copy()
+ del params_without_type["_type"]
+ return read_limit_type(**params_without_type)
diff --git a/python/pyspark/sql/streaming/datasource.py
b/python/pyspark/sql/streaming/datasource.py
new file mode 100644
index 000000000000..97f34133c593
--- /dev/null
+++ b/python/pyspark/sql/streaming/datasource.py
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+
+class ReadLimit:
+ """
+ Specifies limits on how much data to read from a streaming source when
+ determining the latest offset.
+
+ As of Spark 4.2.0, only built-in implementations of :class:`ReadLimit` are
supported. Please
+ refer to the following classes for the supported types:
+
+ - :class:`ReadAllAvailable`
+ - :class:`ReadMinRows`
+ - :class:`ReadMaxRows`
+ - :class:`ReadMaxFiles`
+ - :class:`ReadMaxBytes`
+ """
+
+
+@dataclass
+class ReadAllAvailable(ReadLimit):
+ """
+ A :class:`ReadLimit` that indicates to read all available data, regardless
of the given source
+ options.
+ """
+
+
+@dataclass
+class ReadMinRows(ReadLimit):
+ """
+    A :class:`ReadLimit` that indicates to read a minimum of N rows. If fewer
than N rows are
+    available to read, the source should skip producing a new offset and wait
until more data arrives.
+
+    Note that this semantic does not work well with Trigger.AvailableNow,
since the source
+    may end up waiting forever for more data to arrive. It is the source's
responsibility to
+    handle this case properly.
+ """
+
+ min_rows: int
+
+
+@dataclass
+class ReadMaxRows(ReadLimit):
+ """
+ A :class:`ReadLimit` that indicates to read maximum N rows. The source
should not read more
+ than N rows when determining the latest offset.
+ """
+
+ max_rows: int
+
+
+@dataclass
+class ReadMaxFiles(ReadLimit):
+ """
+ A :class:`ReadLimit` that indicates to read maximum N files. The source
should not read more
+ than N files when determining the latest offset.
+ """
+
+ max_files: int
+
+
+@dataclass
+class ReadMaxBytes(ReadLimit):
+ """
+ A :class:`ReadLimit` that indicates to read maximum N bytes. The source
should not read more
+ than N bytes when determining the latest offset.
+ """
+
+ max_bytes: int
+
+
+class SupportsTriggerAvailableNow(ABC):
+ """
+ A mixin interface for streaming sources that support Trigger.AvailableNow.
This interface can
+ be added to both :class:`DataSourceStreamReader` and
:class:`SimpleDataSourceStreamReader`.
+ """
+
+ @abstractmethod
+ def prepareForTriggerAvailableNow(self) -> None:
+ """
+ This will be called at the beginning of streaming queries with
Trigger.AvailableNow, to let
+ the source record the offset for the current latest data at the time
(a.k.a the target
+ offset for the query). The source must behave as if there is no new
data coming in after
+ the target offset, i.e., the source must not return an offset higher
than the target offset
+ when :meth:`DataSourceStreamReader.latestOffset()` is called.
+
+        The source can extend the semantic of "current latest data" based on
its own logic, but the
+        extended semantic must not violate the expectation that the source
will not read any data
+        added after the time this method is called.
+
+        Note that it is the source's responsibility to ensure that calling
+        :meth:`DataSourceStreamReader.latestOffset()` or
:meth:`SimpleDataSourceStreamReader.read()`
+        after calling this method will eventually reach the target offset and
finally return the
+        same offset as the given start parameter, to indicate that there is no
more data to read. This
+        includes the case where the query is restarted and the source is asked
to read from the
+        offset journaled in the previous run; the source should take care of
exceptional cases, such
+        as a new partition being added during the restart, to ensure that the
query run will be
+        completed at some point.
+ """
+ pass
diff --git a/python/pyspark/sql/streaming/listener.py
b/python/pyspark/sql/streaming/listener.py
index ddda41601f2c..e0ef0c6c4b62 100644
--- a/python/pyspark/sql/streaming/listener.py
+++ b/python/pyspark/sql/streaming/listener.py
@@ -914,12 +914,20 @@ class SourceProgress(dict):
@classmethod
def fromJson(cls, j: Dict[str, Any]) -> "SourceProgress":
+ def _to_json_string(value: Any) -> str:
+ """Convert offset value to JSON string. If already a string,
return as-is.
+ If a dict/list, JSON-encode it."""
+ if isinstance(value, str):
+ return value
+ else:
+ return json.dumps(value)
+
return cls(
jdict=j,
description=j["description"],
- startOffset=str(j["startOffset"]),
- endOffset=str(j["endOffset"]),
- latestOffset=str(j["latestOffset"]),
+ startOffset=_to_json_string(j["startOffset"]),
+ endOffset=_to_json_string(j["endOffset"]),
+ latestOffset=_to_json_string(j["latestOffset"]),
numInputRows=j["numInputRows"],
inputRowsPerSecond=j["inputRowsPerSecond"],
processedRowsPerSecond=j["processedRowsPerSecond"],
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py
b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index ab988eb714cc..31f70a59dbfb 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -19,17 +19,29 @@ import os
import sys
import json
from typing import IO, Iterator, Tuple
+import dataclasses
from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import IllegalArgumentException, PySparkAssertionError
+from pyspark.errors.exceptions.base import PySparkException
from pyspark.serializers import (
read_int,
write_int,
write_with_length,
SpecialLengths,
)
-from pyspark.sql.datasource import DataSource, DataSourceStreamReader
-from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper,
_streamReader
+from pyspark.sql.datasource import (
+ DataSource,
+ DataSourceStreamReader,
+)
+from pyspark.sql.streaming.datasource import (
+ SupportsTriggerAvailableNow,
+)
+from pyspark.sql.datasource_internal import (
+ _SimpleStreamReaderWrapper,
+ _streamReader,
+ ReadLimitRegistry,
+)
from pyspark.sql.pandas.serializers import ArrowStreamSerializer
from pyspark.sql.types import (
_parse_datatype_json_string,
@@ -47,15 +59,26 @@ from pyspark.worker_util import (
utf8_deserializer,
)
+
INITIAL_OFFSET_FUNC_ID = 884
LATEST_OFFSET_FUNC_ID = 885
PARTITIONS_FUNC_ID = 886
COMMIT_FUNC_ID = 887
+CHECK_SUPPORTED_FEATURES_ID = 888
+PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID = 889
+LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID = 890
+GET_DEFAULT_READ_LIMIT_FUNC_ID = 891
+REPORT_LATEST_OFFSET_FUNC_ID = 892
PREFETCHED_RECORDS_NOT_FOUND = 0
NON_EMPTY_PYARROW_RECORD_BATCHES = 1
EMPTY_PYARROW_RECORD_BATCHES = 2
+SUPPORTS_ADMISSION_CONTROL = 1 << 0
+SUPPORTS_TRIGGER_AVAILABLE_NOW = 1 << 1
+
+READ_LIMIT_REGISTRY = ReadLimitRegistry()
+
def initial_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
offset = reader.initialOffset()
@@ -63,7 +86,7 @@ def initial_offset_func(reader: DataSourceStreamReader,
outfile: IO) -> None:
def latest_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
- offset = reader.latestOffset()
+ offset = reader.latestOffset() # type: ignore[call-arg]
write_with_length(json.dumps(offset).encode("utf-8"), outfile)
@@ -116,6 +139,80 @@ def send_batch_func(
write_int(EMPTY_PYARROW_RECORD_BATCHES, outfile)
+def check_support_func(reader: DataSourceStreamReader, outfile: IO) -> None:
+ support_flags = 0
+ if isinstance(reader, _SimpleStreamReaderWrapper):
+        # We consider the `read` method of simple_reader to already have
admission control
+        # built into it.
+ support_flags |= SUPPORTS_ADMISSION_CONTROL
+ if isinstance(reader.simple_reader, SupportsTriggerAvailableNow):
+ support_flags |= SUPPORTS_TRIGGER_AVAILABLE_NOW
+ else:
+ import inspect
+
+ sig = inspect.signature(reader.latestOffset)
+ if len(sig.parameters) == 0:
+ # old signature of latestOffset()
+ pass
+ else:
+            # we don't strictly check the number/type of parameters here - we
leave it to Python to
+            # raise an error when the method is called if the types do not match.
+ support_flags |= SUPPORTS_ADMISSION_CONTROL
+ if isinstance(reader, SupportsTriggerAvailableNow):
+ support_flags |= SUPPORTS_TRIGGER_AVAILABLE_NOW
+ write_int(support_flags, outfile)
+
+
+def prepare_for_trigger_available_now_func(reader: DataSourceStreamReader,
outfile: IO) -> None:
+ if isinstance(reader, _SimpleStreamReaderWrapper):
+ if isinstance(reader.simple_reader, SupportsTriggerAvailableNow):
+ reader.simple_reader.prepareForTriggerAvailableNow()
+ else:
+ raise PySparkException(
+ "prepareForTriggerAvailableNow is not supported by the
underlying simple reader."
+ )
+ else:
+ if isinstance(reader, SupportsTriggerAvailableNow):
+ reader.prepareForTriggerAvailableNow()
+ else:
+ raise PySparkException(
+ "prepareForTriggerAvailableNow is not supported by the stream
reader."
+ )
+ write_int(0, outfile)
+
+
+def latest_offset_admission_control_func(
+ reader: DataSourceStreamReader, infile: IO, outfile: IO
+) -> None:
+ start_offset_dict = json.loads(utf8_deserializer.loads(infile))
+
+ limit = json.loads(utf8_deserializer.loads(infile))
+ limit_obj = READ_LIMIT_REGISTRY.get(limit)
+
+ offset = reader.latestOffset(start_offset_dict, limit_obj)
+ write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
+def get_default_read_limit_func(reader: DataSourceStreamReader, outfile: IO)
-> None:
+ limit = reader.getDefaultReadLimit()
+ limit_as_dict = dataclasses.asdict(limit) | { # type:
ignore[call-overload]
+ "_type": limit.__class__.__name__
+ }
+ write_with_length(json.dumps(limit_as_dict).encode("utf-8"), outfile)
+
+
+def report_latest_offset_func(reader: DataSourceStreamReader, outfile: IO) ->
None:
+ if isinstance(reader, _SimpleStreamReaderWrapper):
+ # We do not consider providing latest offset on simple stream reader.
+ write_int(0, outfile)
+ else:
+ offset = reader.reportLatestOffset()
+ if offset is None:
+ write_int(0, outfile)
+ else:
+ write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
def main(infile: IO, outfile: IO) -> None:
try:
check_python_version(infile)
@@ -176,6 +273,16 @@ def main(infile: IO, outfile: IO) -> None:
)
elif func_id == COMMIT_FUNC_ID:
commit_func(reader, infile, outfile)
+ elif func_id == CHECK_SUPPORTED_FEATURES_ID:
+ check_support_func(reader, outfile)
+ elif func_id == PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID:
+ prepare_for_trigger_available_now_func(reader, outfile)
+ elif func_id == LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID:
+ latest_offset_admission_control_func(reader, infile,
outfile)
+ elif func_id == GET_DEFAULT_READ_LIMIT_FUNC_ID:
+ get_default_read_limit_func(reader, outfile)
+ elif func_id == REPORT_LATEST_OFFSET_FUNC_ID:
+ report_latest_offset_func(reader, outfile)
else:
raise IllegalArgumentException(
errorClass="UNSUPPORTED_OPERATION",
diff --git a/python/pyspark/sql/tests/test_python_streaming_datasource.py
b/python/pyspark/sql/tests/test_python_streaming_datasource.py
index deec97d0da72..bef85f7ba845 100644
--- a/python/pyspark/sql/tests/test_python_streaming_datasource.py
+++ b/python/pyspark/sql/tests/test_python_streaming_datasource.py
@@ -18,6 +18,7 @@ import os
import tempfile
import time
import unittest
+import json
from pyspark.sql.datasource import (
DataSource,
@@ -28,6 +29,12 @@ from pyspark.sql.datasource import (
SimpleDataSourceStreamReader,
WriterCommitMessage,
)
+from pyspark.sql.streaming.datasource import (
+ ReadAllAvailable,
+ ReadLimit,
+ ReadMaxRows,
+ SupportsTriggerAvailableNow,
+)
from pyspark.sql.streaming import StreamingQueryException
from pyspark.sql.types import Row
from pyspark.testing.sqlutils import (
@@ -39,6 +46,42 @@ from pyspark.testing.utils import eventually
from pyspark.testing.sqlutils import ReusedSQLTestCase
+def wait_for_condition(query, condition_fn, timeout_sec=30):
+ """
+ Wait for a condition on a streaming query to be met, with timeout and
error context.
+
+ :param query: StreamingQuery object
+ :param condition_fn: Function that takes query and returns True when
condition is met
+ :param timeout_sec: Timeout in seconds (default 30)
+ :raises TimeoutError: If condition is not met within timeout, with query
context
+ """
+ start_time = time.time()
+ sleep_interval = 0.2
+
+ while not condition_fn(query):
+ elapsed = time.time() - start_time
+ if elapsed >= timeout_sec:
+ # Collect context for debugging
+ exception_info = query.exception()
+ recent_progresses = query.recentProgress
+
+ error_msg = (
+ f"Timeout after {timeout_sec} seconds waiting for condition. "
+ f"Query exception: {exception_info}. "
+ f"Recent progress count: {len(recent_progresses)}. "
+ )
+
+ if recent_progresses:
+ error_msg += f"Last progress: {recent_progresses[-1]}. "
+ error_msg += f"All recent progresses: {recent_progresses}"
+ else:
+ error_msg += "No progress recorded."
+
+ raise TimeoutError(error_msg)
+
+ time.sleep(sleep_interval)
+
+
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class BasePythonStreamingDataSourceTestsMixin:
def test_basic_streaming_data_source_class(self):
@@ -76,9 +119,8 @@ class BasePythonStreamingDataSourceTestsMixin:
def initialOffset(self):
return {"offset": 0}
- def latestOffset(self):
- self.current += 2
- return {"offset": self.current}
+ def latestOffset(self, start, limit):
+ return {"offset": start["offset"] + 2}
def partitions(self, start, end):
return [RangePartition(start["offset"], end["offset"])]
@@ -140,20 +182,144 @@ class BasePythonStreamingDataSourceTestsMixin:
return TestDataSource
- def test_stream_reader(self):
- self.spark.dataSource.register(self._get_test_data_source())
+ def _get_test_data_source_old_latest_offset_signature(self):
+ class RangePartition(InputPartition):
+ def __init__(self, start, end):
+ self.start = start
+ self.end = end
+
+ class TestStreamReader(DataSourceStreamReader):
+ current = 0
+
+ def initialOffset(self):
+ return {"offset": 0}
+
+ def latestOffset(self):
+ self.current += 2
+ return {"offset": self.current}
+
+ def partitions(self, start, end):
+ return [RangePartition(start["offset"], end["offset"])]
+
+ def commit(self, end):
+ pass
+
+ def read(self, partition):
+ start, end = partition.start, partition.end
+ for i in range(start, end):
+ yield (i,)
+
+ class TestDataSource(DataSource):
+ def schema(self):
+ return "id INT"
+
+ def streamReader(self, schema):
+ return TestStreamReader()
+
+ return TestDataSource
+
+ def _get_test_data_source_for_admission_control(self):
+ class TestDataStreamReader(DataSourceStreamReader):
+ def initialOffset(self):
+ return {"partition-1": 0}
+
+ def getDefaultReadLimit(self):
+ return ReadMaxRows(2)
+
+ def latestOffset(self, start: dict, limit: ReadLimit):
+ start_idx = start["partition-1"]
+ if isinstance(limit, ReadAllAvailable):
+ end_offset = start_idx + 10
+ else:
+ assert isinstance(
+ limit, ReadMaxRows
+ ), "Expected ReadMaxRows read limit but got " +
str(type(limit))
+ end_offset = start_idx + limit.max_rows
+ return {"partition-1": end_offset}
+
+ def reportLatestOffset(self):
+ return {"partition-1": 1000000}
+
+ def partitions(self, start: dict, end: dict):
+ start_index = start["partition-1"]
+ end_index = end["partition-1"]
+ return [InputPartition(i) for i in range(start_index,
end_index)]
+
+ def read(self, partition):
+ yield (partition.value,)
+
+ class TestDataSource(DataSource):
+ def schema(self) -> str:
+ return "id INT"
+
+ def streamReader(self, schema):
+ return TestDataStreamReader()
+
+ return TestDataSource
+
+ def _get_test_data_source_for_trigger_available_now(self):
+ class TestDataStreamReader(DataSourceStreamReader,
SupportsTriggerAvailableNow):
+ def initialOffset(self):
+ return {"partition-1": 0}
+
+ def getDefaultReadLimit(self):
+ return ReadMaxRows(2)
+
+ def latestOffset(self, start: dict, limit: ReadLimit):
+ start_idx = start["partition-1"]
+ if isinstance(limit, ReadAllAvailable):
+ end_offset = start_idx + 10
+ else:
+ assert isinstance(
+ limit, ReadMaxRows
+ ), "Expected ReadMaxRows read limit but got " +
str(type(limit))
+ end_offset = min(
+ start_idx + limit.max_rows,
self.desired_end_offset["partition-1"]
+ )
+ return {"partition-1": end_offset}
+
+ def reportLatestOffset(self):
+ return {"partition-1": 1000000}
+
+ def prepareForTriggerAvailableNow(self) -> None:
+ self.desired_end_offset = {"partition-1": 10}
+
+ def partitions(self, start: dict, end: dict):
+ start_index = start["partition-1"]
+ end_index = end["partition-1"]
+ return [InputPartition(i) for i in range(start_index,
end_index)]
+
+ def read(self, partition):
+ yield (partition.value,)
+
+ class TestDataSource(DataSource):
+ def schema(self) -> str:
+ return "id INT"
+
+ def streamReader(self, schema):
+ return TestDataStreamReader()
+
+ return TestDataSource
+
+ def _test_stream_reader(self, test_data_source):
+ self.spark.dataSource.register(test_data_source)
df = self.spark.readStream.format("TestDataSource").load()
def check_batch(df, batch_id):
assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 +
1)])
q = df.writeStream.foreachBatch(check_batch).start()
- while len(q.recentProgress) < 10:
- time.sleep(0.2)
+ wait_for_condition(q, lambda query: len(query.recentProgress) >= 10)
q.stop()
q.awaitTermination()
self.assertIsNone(q.exception(), "No exception has to be propagated.")
+ def test_stream_reader(self):
+ self._test_stream_reader(self._get_test_data_source())
+
+ def test_stream_reader_old_latest_offset_signature(self):
+
self._test_stream_reader(self._get_test_data_source_old_latest_offset_signature())
+
def test_stream_reader_pyarrow(self):
import pyarrow as pa
@@ -198,8 +364,7 @@ class BasePythonStreamingDataSourceTestsMixin:
.option("checkpointLocation", checkpoint_dir.name)
.start(output_dir.name)
)
- while not q.recentProgress:
- time.sleep(0.2)
+ wait_for_condition(q, lambda query: len(query.recentProgress) > 0)
q.stop()
q.awaitTermination()
@@ -214,6 +379,55 @@ class BasePythonStreamingDataSourceTestsMixin:
assertDataFrameEqual(df, expected_data)
+ def test_stream_reader_admission_control_trigger_once(self):
+
self.spark.dataSource.register(self._get_test_data_source_for_admission_control())
+ df = self.spark.readStream.format("TestDataSource").load()
+
+ def check_batch(df, batch_id):
+ assertDataFrameEqual(df, [Row(x) for x in range(10)])
+
+ q = df.writeStream.trigger(once=True).foreachBatch(check_batch).start()
+ q.awaitTermination()
+ self.assertIsNone(q.exception(), "No exception has to be propagated.")
+ self.assertEqual(len(q.recentProgress), 1)
+ self.assertEqual(q.lastProgress.numInputRows, 10)
+ self.assertEqual(q.lastProgress.sources[0].numInputRows, 10)
+ self.assertEqual(
+ json.loads(q.lastProgress.sources[0].latestOffset),
{"partition-1": 1000000}
+ )
+
+ def test_stream_reader_admission_control_processing_time_trigger(self):
+
self.spark.dataSource.register(self._get_test_data_source_for_admission_control())
+ df = self.spark.readStream.format("TestDataSource").load()
+
+ def check_batch(df, batch_id):
+ assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 +
1)])
+
+ q = df.writeStream.foreachBatch(check_batch).start()
+ wait_for_condition(q, lambda query: len(query.recentProgress) >= 10)
+ q.stop()
+ q.awaitTermination()
+ self.assertIsNone(q.exception(), "No exception has to be propagated.")
+
+ def test_stream_reader_trigger_available_now(self):
+
self.spark.dataSource.register(self._get_test_data_source_for_trigger_available_now())
+ df = self.spark.readStream.format("TestDataSource").load()
+
+ def check_batch(df, batch_id):
+ assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 +
1)])
+
+ q =
df.writeStream.foreachBatch(check_batch).trigger(availableNow=True).start()
+ q.awaitTermination(timeout=30)
+ self.assertIsNone(q.exception(), "No exception has to be propagated.")
+ # 2 rows * 5 batches = 10 rows
+ self.assertEqual(len(q.recentProgress), 5)
+ for progress in q.recentProgress:
+ self.assertEqual(progress.numInputRows, 2)
+ self.assertEqual(q.lastProgress.sources[0].numInputRows, 2)
+ self.assertEqual(
+ json.loads(q.lastProgress.sources[0].latestOffset),
{"partition-1": 1000000}
+ )
+
def test_simple_stream_reader(self):
class SimpleStreamReader(SimpleDataSourceStreamReader):
def initialOffset(self):
@@ -246,12 +460,55 @@ class BasePythonStreamingDataSourceTestsMixin:
assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 +
1)])
q = df.writeStream.foreachBatch(check_batch).start()
- while len(q.recentProgress) < 10:
- time.sleep(0.2)
+ wait_for_condition(q, lambda query: len(query.recentProgress) >= 10)
q.stop()
q.awaitTermination()
self.assertIsNone(q.exception(), "No exception has to be propagated.")
+ def test_simple_stream_reader_trigger_available_now(self):
+ class SimpleStreamReader(SimpleDataSourceStreamReader,
SupportsTriggerAvailableNow):
+ def initialOffset(self):
+ return {"offset": 0}
+
+ def read(self, start: dict):
+ start_idx = start["offset"]
+ end_offset = min(start_idx + 2,
self.desired_end_offset["offset"])
+ it = iter([(i,) for i in range(start_idx, end_offset)])
+ return (it, {"offset": end_offset})
+
+ def commit(self, end):
+ pass
+
+ def readBetweenOffsets(self, start: dict, end: dict):
+ start_idx = start["offset"]
+ end_idx = end["offset"]
+ return iter([(i,) for i in range(start_idx, end_idx)])
+
+ def prepareForTriggerAvailableNow(self) -> None:
+ self.desired_end_offset = {"offset": 10}
+
+ class SimpleDataSource(DataSource):
+ def schema(self):
+ return "id INT"
+
+ def simpleStreamReader(self, schema):
+ return SimpleStreamReader()
+
+ self.spark.dataSource.register(SimpleDataSource)
+ df = self.spark.readStream.format("SimpleDataSource").load()
+
+ def check_batch(df, batch_id):
+            # the last offset for the data is 9 since the desired end offset
is 10
+            # a batch isn't triggered when there is no data, so each batch has
either one or two rows
+ if batch_id * 2 + 1 > 9:
+ assertDataFrameEqual(df, [Row(batch_id * 2)])
+ else:
+ assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2
+ 1)])
+
+ q =
df.writeStream.foreachBatch(check_batch).trigger(availableNow=True).start()
+ q.awaitTermination(timeout=30)
+ self.assertIsNone(q.exception(), "No exception has to be propagated.")
+
def test_stream_writer(self):
input_dir =
tempfile.TemporaryDirectory(prefix="test_data_stream_write_input")
output_dir =
tempfile.TemporaryDirectory(prefix="test_data_stream_write_output")
@@ -268,8 +525,7 @@ class BasePythonStreamingDataSourceTestsMixin:
.option("checkpointLocation", checkpoint_dir.name)
.start(output_dir.name)
)
- while not q.recentProgress:
- time.sleep(0.2)
+ wait_for_condition(q, lambda query: len(query.recentProgress) > 0)
# Test stream writer write and commit.
# The first microbatch contain 30 rows and 2 partitions.
@@ -283,8 +539,7 @@ class BasePythonStreamingDataSourceTestsMixin:
# Test StreamWriter write and abort.
# When row id > 50, write tasks throw exception and fail.
# 1.txt is written by StreamWriter.abort() to record the failure.
- while q.exception() is None:
- time.sleep(0.2)
+ wait_for_condition(q, lambda query: query.exception() is not None)
assertDataFrameEqual(
self.spark.read.text(os.path.join(output_dir.name, "1.txt")),
[Row("failed in batch 1")],
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
index 50ea7616061c..e7f4f6a564dc 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.execution.datasources.v2.python
import org.apache.spark.SparkEnv
+import org.apache.spark.api.python.PythonFunction
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.read.{InputPartition,
PartitionReaderFactory}
-import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset,
MicroBatchStream, Offset}
+import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset,
MicroBatchStream, Offset, ReadLimit, SupportsAdmissionControl,
SupportsTriggerAvailableNow}
import
org.apache.spark.sql.execution.datasources.v2.python.PythonMicroBatchStream.nextStreamId
import
org.apache.spark.sql.execution.python.streaming.PythonStreamingSourceRunner
import org.apache.spark.sql.types.StructType
@@ -28,18 +29,16 @@ import org.apache.spark.storage.{PythonStreamBlockId,
StorageLevel}
case class PythonStreamingSourceOffset(json: String) extends Offset
-class PythonMicroBatchStream(
+case class PythonStreamingSourceReadLimit(json: String) extends ReadLimit
+
+abstract class PythonMicroBatchStreamBase(
ds: PythonDataSourceV2,
shortName: String,
outputSchema: StructType,
- options: CaseInsensitiveStringMap
- )
+ options: CaseInsensitiveStringMap,
+ runner: PythonStreamingSourceRunner)
extends MicroBatchStream
- with Logging
- with AcceptsLatestSeenOffset {
- private def createDataSourceFunc =
- ds.source.createPythonFunction(
- ds.getOrCreateDataSourceInPython(shortName, options,
Some(outputSchema)).dataSource)
+ with Logging {
private val streamId = nextStreamId
private var nextBlockId = 0L
@@ -49,10 +48,6 @@ class PythonMicroBatchStream(
// from python to JVM.
private var cachedInputPartition: Option[(String, String,
PythonStreamingInputPartition)] = None
- private val runner: PythonStreamingSourceRunner =
- new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema)
- runner.init()
-
override def initialOffset(): Offset =
PythonStreamingSourceOffset(runner.initialOffset())
override def latestOffset(): Offset =
PythonStreamingSourceOffset(runner.latestOffset())
@@ -83,12 +78,6 @@ class PythonMicroBatchStream(
}
}
- override def setLatestSeenOffset(offset: Offset): Unit = {
- // Call planPartition on python with an empty offset range to initialize
the start offset
- // for the prefetching of simple reader.
- runner.partitions(offset.json(), offset.json())
- }
-
private lazy val readInfo: PythonDataSourceReadInfo = {
ds.getOrCreateReadInfo(shortName, options, outputSchema, isStreaming =
true)
}
@@ -110,10 +99,91 @@ class PythonMicroBatchStream(
override def deserializeOffset(json: String): Offset =
PythonStreamingSourceOffset(json)
}
+class PythonMicroBatchStream(
+ ds: PythonDataSourceV2,
+ shortName: String,
+ outputSchema: StructType,
+ options: CaseInsensitiveStringMap,
+ runner: PythonStreamingSourceRunner)
+ extends PythonMicroBatchStreamBase(ds, shortName, outputSchema, options,
runner)
+ with AcceptsLatestSeenOffset {
+
+ override def setLatestSeenOffset(offset: Offset): Unit = {
+ // Call planPartition on python with an empty offset range to initialize
the start offset
+ // for the prefetching of simple reader.
+ runner.partitions(offset.json(), offset.json())
+ }
+}
+
+class PythonMicroBatchStreamWithAdmissionControl(
+ ds: PythonDataSourceV2,
+ shortName: String,
+ outputSchema: StructType,
+ options: CaseInsensitiveStringMap,
+ runner: PythonStreamingSourceRunner)
+ extends PythonMicroBatchStreamBase(ds, shortName, outputSchema, options,
runner)
+ with SupportsAdmissionControl {
+
+ override def latestOffset(): Offset = {
+ throw new IllegalStateException("latestOffset without parameters is not
expected to be " +
+ "called. Please use latestOffset(startOffset: Offset, limit: ReadLimit)
instead.")
+ }
+
+ override def latestOffset(startOffset: Offset, limit: ReadLimit): Offset = {
+ PythonStreamingSourceOffset(runner.latestOffset(startOffset, limit))
+ }
+
+ override def getDefaultReadLimit: ReadLimit = {
+ val readLimitJson = runner.getDefaultReadLimit()
+ PythonStreamingSourceReadLimit(readLimitJson)
+ }
+
+ override def reportLatestOffset(): Offset = {
+ val offsetJson = runner.reportLatestOffset()
+ if (offsetJson == null) {
+ null
+ } else {
+ PythonStreamingSourceOffset(offsetJson)
+ }
+ }
+}
+
+class PythonMicroBatchStreamWithTriggerAvailableNow(
+ ds: PythonDataSourceV2,
+ shortName: String,
+ outputSchema: StructType,
+ options: CaseInsensitiveStringMap,
+ runner: PythonStreamingSourceRunner)
+ extends PythonMicroBatchStreamWithAdmissionControl(ds, shortName,
outputSchema, options, runner)
+ with SupportsTriggerAvailableNow {
+
+ override def prepareForTriggerAvailableNow(): Unit = {
+ runner.prepareForTriggerAvailableNow()
+ }
+}
+
object PythonMicroBatchStream {
private var currentId = 0
def nextStreamId: Int = synchronized {
currentId = currentId + 1
currentId
}
+
+ def createPythonStreamingSourceRunner(
+ ds: PythonDataSourceV2,
+ shortName: String,
+ outputSchema: StructType,
+ options: CaseInsensitiveStringMap): PythonStreamingSourceRunner = {
+
+    // The methods below were called during the construction of
PythonMicroBatchStream, so there is no
+    // timing/sequencing issue in calling them here.
+ def createDataSourceFunc: PythonFunction =
+ ds.source.createPythonFunction(
+ ds.getOrCreateDataSourceInPython(
+ shortName,
+ options,
+ Some(outputSchema)).dataSource)
+
+ new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema)
+ }
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
index a133c40cde60..9e3effe7d441 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -35,8 +35,23 @@ class PythonScan(
) extends Scan with SupportsMetadata {
override def toBatch: Batch = new PythonBatch(ds, shortName, outputSchema,
options)
- override def toMicroBatchStream(checkpointLocation: String):
MicroBatchStream =
- new PythonMicroBatchStream(ds, shortName, outputSchema, options)
+ override def toMicroBatchStream(checkpointLocation: String):
MicroBatchStream = {
+ val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+ ds, shortName, outputSchema, options)
+ runner.init()
+
+ val supportedFeatures = runner.checkSupportedFeatures()
+
+ if (supportedFeatures.triggerAvailableNow) {
+ new PythonMicroBatchStreamWithTriggerAvailableNow(
+ ds, shortName, outputSchema, options, runner)
+ } else if (supportedFeatures.admissionControl) {
+ new PythonMicroBatchStreamWithAdmissionControl(
+ ds, shortName, outputSchema, options, runner)
+ } else {
+ new PythonMicroBatchStream(ds, shortName, outputSchema, options, runner)
+ }
+ }
override def description: String = "(Python)"
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
index 638a2e9d2062..74cf32c46921 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
@@ -33,7 +33,9 @@ import org.apache.spark.internal.LogKeys.PYTHON_EXEC
import org.apache.spark.internal.config.BUFFER_SIZE
import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.streaming.{Offset,
ReadAllAvailable, ReadLimit}
import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
+import
org.apache.spark.sql.execution.datasources.v2.python.PythonStreamingSourceReadLimit
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
@@ -46,11 +48,19 @@ object PythonStreamingSourceRunner {
val LATEST_OFFSET_FUNC_ID = 885
val PARTITIONS_FUNC_ID = 886
val COMMIT_FUNC_ID = 887
+ val CHECK_SUPPORTED_FEATURES_ID = 888
+ val PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID = 889
+ val LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID = 890
+ val GET_DEFAULT_READ_LIMIT_FUNC_ID = 891
+ val REPORT_LATEST_OFFSET_FUNC_ID = 892
// Status code for JVM to decide how to receive prefetched record batches
// for simple stream reader.
val PREFETCHED_RECORDS_NOT_FOUND = 0
val NON_EMPTY_PYARROW_RECORD_BATCHES = 1
val EMPTY_PYARROW_RECORD_BATCHES = 2
+ val READ_ALL_AVAILABLE_JSON = """{"_type": "ReadAllAvailable"}"""
+
+ case class SupportedFeatures(admissionControl: Boolean, triggerAvailableNow:
Boolean)
}
/**
@@ -130,6 +140,65 @@ class PythonStreamingSourceRunner(
}
}
+ def checkSupportedFeatures(): SupportedFeatures = {
+ dataOut.writeInt(CHECK_SUPPORTED_FEATURES_ID)
+ dataOut.flush()
+
+ val featureBits = dataIn.readInt()
+ if (featureBits == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+ val msg = PythonWorkerUtils.readUTF(dataIn)
+ throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+ action = "checkSupportedFeatures", msg)
+ }
+ val admissionControl = (featureBits & (1 << 0)) == 1
+ val availableNow = (featureBits & (1 << 1)) == (1 << 1)
+
+ SupportedFeatures(admissionControl, availableNow)
+ }
+
+ def getDefaultReadLimit(): String = {
+ dataOut.writeInt(GET_DEFAULT_READ_LIMIT_FUNC_ID)
+ dataOut.flush()
+
+ val len = dataIn.readInt()
+ if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+ val msg = PythonWorkerUtils.readUTF(dataIn)
+ throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+ action = "getDefaultReadLimit", msg)
+ }
+
+ PythonWorkerUtils.readUTF(len, dataIn)
+ }
+
+ def reportLatestOffset(): String = {
+ dataOut.writeInt(REPORT_LATEST_OFFSET_FUNC_ID)
+ dataOut.flush()
+
+ val len = dataIn.readInt()
+ if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+ val msg = PythonWorkerUtils.readUTF(dataIn)
+ throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+ action = "reportLatestOffset", msg)
+ }
+
+ if (len == 0) {
+ null
+ } else {
+ PythonWorkerUtils.readUTF(len, dataIn)
+ }
+ }
+
+ def prepareForTriggerAvailableNow(): Unit = {
+ dataOut.writeInt(PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID)
+ dataOut.flush()
+ val status = dataIn.readInt()
+ if (status == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+ val msg = PythonWorkerUtils.readUTF(dataIn)
+ throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+ action = "prepareForTriggerAvailableNow", msg)
+ }
+ }
+
/**
* Invokes latestOffset() function of the stream reader and receive the
return value.
*/
@@ -145,6 +214,33 @@ class PythonStreamingSourceRunner(
PythonWorkerUtils.readUTF(len, dataIn)
}
+ def latestOffset(startOffset: Offset, limit: ReadLimit): String = {
+ dataOut.writeInt(LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID)
+ PythonWorkerUtils.writeUTF(startOffset.json, dataOut)
+ limit match {
+ case _: ReadAllAvailable =>
+ // NOTE: we need to use a constant here to match the Python side given
the engine can
+ // decide by itself to use ReadAllAvailable and the Python side
version of the instance
+ // isn't available here.
+ PythonWorkerUtils.writeUTF(READ_ALL_AVAILABLE_JSON, dataOut)
+
+ case p: PythonStreamingSourceReadLimit =>
+ PythonWorkerUtils.writeUTF(p.json, dataOut)
+
+ case _ =>
+ throw new UnsupportedOperationException("Unsupported ReadLimit type: "
+
+ s"${limit.getClass.getName}")
+ }
+ dataOut.flush()
+ val len = dataIn.readInt()
+ if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+ val msg = PythonWorkerUtils.readUTF(dataIn)
+ throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+ action = "latestOffset", msg)
+ }
+ PythonWorkerUtils.readUTF(len, dataIn)
+ }
+
/**
* Invokes initialOffset() function of the stream reader and receive the
return value.
*/
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
index 0e33b6e55a43..664bc42b64c4 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
@@ -17,14 +17,15 @@
package org.apache.spark.sql.execution.python.streaming
import java.io.File
-import java.util.concurrent.CountDownLatch
+import java.util.concurrent.{CountDownLatch, TimeUnit}
import scala.concurrent.duration._
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import
org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource,
shouldTestPandasUDFs}
-import
org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2,
PythonMicroBatchStream, PythonStreamingSourceOffset}
+import org.apache.spark.sql.connector.read.streaming.ReadLimit
+import
org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2,
PythonMicroBatchStream, PythonMicroBatchStreamWithAdmissionControl,
PythonStreamingSourceOffset, PythonStreamingSourceReadLimit}
import org.apache.spark.sql.execution.python.PythonDataSourceSuiteBase
import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger
import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog,
OffsetSeqLog}
@@ -208,8 +209,11 @@ class PythonStreamingDataSourceSimpleSuite extends
PythonDataSourceSuiteBase {
.format("json")
.start(outputDir.getAbsolutePath)
- while (q.recentProgress.length < 5) {
- Thread.sleep(200)
+ eventually(timeout(30.seconds)) {
+ assert(q.recentProgress.length >= 5,
+ s"Expected at least 5 progress updates but got
${q.recentProgress.length}. " +
+ s"Query exception: ${q.exception}. " +
+ s"Recent progress: ${q.recentProgress.mkString(", ")}")
}
q.stop()
q.awaitTermination()
@@ -249,12 +253,18 @@ class PythonStreamingDataSourceSimpleSuite extends
PythonDataSourceSuiteBase {
pythonDs.setShortName("ErrorDataSource")
def testMicroBatchStreamError(action: String, msg: String)(
- func: PythonMicroBatchStream => Unit): Unit = {
- val stream = new PythonMicroBatchStream(
+ func: PythonMicroBatchStreamWithAdmissionControl => Unit): Unit = {
+ val options = CaseInsensitiveStringMap.empty()
+ val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+ pythonDs, errorDataSourceName, inputSchema, options)
+ runner.init()
+
+ val stream = new PythonMicroBatchStreamWithAdmissionControl(
pythonDs,
errorDataSourceName,
inputSchema,
- CaseInsensitiveStringMap.empty()
+ options,
+ runner
)
val err = intercept[SparkException] {
func(stream)
@@ -277,16 +287,6 @@ class PythonStreamingDataSourceSimpleSuite extends
PythonDataSourceSuiteBase {
stream =>
stream.initialOffset()
}
-
- // User don't need to implement latestOffset for
SimpleDataSourceStreamReader.
- // The latestOffset method of simple stream reader invokes initialOffset()
and read()
- // So the not implemented method is initialOffset.
- testMicroBatchStreamError(
- "latestOffset",
- "[NOT_IMPLEMENTED] initialOffset is not implemented") {
- stream =>
- stream.latestOffset()
- }
}
test("read() method throw error in SimpleDataSourceStreamReader") {
@@ -314,12 +314,18 @@ class PythonStreamingDataSourceSimpleSuite extends
PythonDataSourceSuiteBase {
pythonDs.setShortName("ErrorDataSource")
def testMicroBatchStreamError(action: String, msg: String)(
- func: PythonMicroBatchStream => Unit): Unit = {
- val stream = new PythonMicroBatchStream(
+ func: PythonMicroBatchStreamWithAdmissionControl => Unit): Unit = {
+ val options = CaseInsensitiveStringMap.empty()
+ val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+ pythonDs, errorDataSourceName, inputSchema, options)
+ runner.init()
+
+ val stream = new PythonMicroBatchStreamWithAdmissionControl(
pythonDs,
errorDataSourceName,
inputSchema,
- CaseInsensitiveStringMap.empty()
+ options,
+ runner
)
val err = intercept[SparkException] {
func(stream)
@@ -337,7 +343,60 @@ class PythonStreamingDataSourceSimpleSuite extends
PythonDataSourceSuiteBase {
}
testMicroBatchStreamError("latestOffset", "Exception: error reading
available data") { stream =>
- stream.latestOffset()
+ stream.latestOffset(PythonStreamingSourceOffset("""{"partition": 0}"""),
+ ReadLimit.allAvailable())
+ }
+ }
+
+ test("SimpleDataSourceStreamReader with Trigger.AvailableNow") {
+ assume(shouldTestPandasUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource
+ |from pyspark.sql.datasource import SimpleDataSourceStreamReader
+ |from pyspark.sql.streaming.datasource import
SupportsTriggerAvailableNow
+ |
+ |class SimpleDataStreamReader(SimpleDataSourceStreamReader,
SupportsTriggerAvailableNow):
+ | def initialOffset(self):
+ | return {"partition-1": 0}
+ | def read(self, start: dict):
+ | start_idx = start["partition-1"]
+ | end_offset = min(start_idx + 2, self.desired_end_offset)
+ | it = iter([(i, ) for i in range(start_idx, end_offset)])
+ | return (it, {"partition-1": end_offset})
+ | def readBetweenOffsets(self, start: dict, end: dict):
+ | start_idx = start["partition-1"]
+ | end_idx = end["partition-1"]
+ | return iter([(i, ) for i in range(start_idx, end_idx)])
+ | def prepareForTriggerAvailableNow(self):
+ | self.desired_end_offset = 10
+ |
+ |class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return "id INT"
+ | def simpleStreamReader(self, schema):
+ | return SimpleDataStreamReader()
+ |""".stripMargin
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+ withTempDir { dir =>
+ val path = dir.getAbsolutePath
+ val checkpointDir = new File(path, "checkpoint")
+ val outputDir = new File(path, "output")
+ val df = spark.readStream.format(dataSourceName).load()
+ val q = df.writeStream
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .format("json")
+ .trigger(Trigger.AvailableNow())
+ .start(outputDir.getAbsolutePath)
+ q.awaitTermination(waitTimeout.toMillis)
+ val rowCount = spark.read.format("json").load(outputDir.getAbsolutePath).count()
+ assert(rowCount === 10)
+ checkAnswer(
+ spark.read.format("json").load(outputDir.getAbsolutePath),
+ (0 until rowCount.toInt).map(Row(_))
+ )
}
}
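
For reference, the same flow outside the Scala test harness looks roughly like the following standalone PySpark sketch. It is illustrative only: the demo class names, the "available_now_demo" source name, and the checkpoint/output paths are placeholders, and the `pyspark.sql.streaming.datasource` import path simply mirrors the test script above.

```
# Hedged, standalone sketch of the SimpleDataSourceStreamReader +
# SupportsTriggerAvailableNow pattern exercised by the test above.
# Class names, the source name, and the filesystem paths are placeholders.
from pyspark.sql import SparkSession
from pyspark.sql.datasource import DataSource, SimpleDataSourceStreamReader
from pyspark.sql.streaming.datasource import SupportsTriggerAvailableNow


class AvailableNowSimpleReader(SimpleDataSourceStreamReader, SupportsTriggerAvailableNow):
    def initialOffset(self):
        return {"partition-1": 0}

    def read(self, start: dict):
        # Read at most 2 rows per call, never past the end offset captured in
        # prepareForTriggerAvailableNow() (called before planning under AvailableNow).
        start_idx = start["partition-1"]
        end_offset = min(start_idx + 2, self.desired_end_offset)
        it = iter([(i,) for i in range(start_idx, end_offset)])
        return (it, {"partition-1": end_offset})

    def readBetweenOffsets(self, start: dict, end: dict):
        return iter([(i,) for i in range(start["partition-1"], end["partition-1"])])

    def prepareForTriggerAvailableNow(self):
        # Pin the end offset once so AvailableNow drains exactly 10 rows and stops.
        self.desired_end_offset = 10


class AvailableNowDemoSource(DataSource):
    @classmethod
    def name(cls):
        return "available_now_demo"

    def schema(self) -> str:
        return "id INT"

    def simpleStreamReader(self, schema):
        return AvailableNowSimpleReader()


if __name__ == "__main__":
    spark = SparkSession.builder.getOrCreate()
    spark.dataSource.register(AvailableNowDemoSource)
    query = (
        spark.readStream.format("available_now_demo").load()
        .writeStream
        .format("json")
        .option("checkpointLocation", "/tmp/available_now_demo/checkpoint")
        .trigger(availableNow=True)  # PySpark spelling of Trigger.AvailableNow()
        .start("/tmp/available_now_demo/output")
    )
    query.awaitTermination()
```

Registering the class and loading it by its short name is roughly what the Scala harness does through createUserDefinedPythonDataSource and registerPython.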
@@ -459,11 +518,18 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
spark.dataSource.registerPython(dataSourceName, dataSource)
val pythonDs = new PythonDataSourceV2
pythonDs.setShortName("SimpleDataSource")
+
+ val options = CaseInsensitiveStringMap.empty()
+ val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+ pythonDs, dataSourceName, inputSchema, options)
+ runner.init()
+
val stream = new PythonMicroBatchStream(
pythonDs,
dataSourceName,
inputSchema,
- CaseInsensitiveStringMap.empty()
+ options,
+ runner
)
var startOffset = stream.initialOffset()
@@ -611,6 +677,203 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
q.awaitTermination()
}
+ private val testAdmissionControlScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource
+ |from pyspark.sql.datasource import (
+ | DataSourceStreamReader,
+ | InputPartition,
+ |)
+ |from pyspark.sql.streaming.datasource import (
+ | ReadAllAvailable,
+ | ReadLimit,
+ | ReadMaxRows,
+ |)
+ |
+ |class TestDataStreamReader(
+ | DataSourceStreamReader,
+ |):
+ | def initialOffset(self):
+ | return {"partition-1": 0}
+ | def getDefaultReadLimit(self):
+ | return ReadMaxRows(2)
+ | def latestOffset(self, start: dict, limit: ReadLimit):
+ | start_idx = start["partition-1"]
+ | if isinstance(limit, ReadAllAvailable):
+ | end_offset = start_idx + 10
+ | else:
+ | assert isinstance(limit, ReadMaxRows), ("Expected ReadMaxRows read "
+ | "limit but got "
+ | + str(type(limit)))
+ | end_offset = start_idx + limit.max_rows
+ | return {"partition-1": end_offset}
+ | def reportLatestOffset(self):
+ | return {"partition-1": 1000000}
+ | def partitions(self, start: dict, end: dict):
+ | start_index = start["partition-1"]
+ | end_index = end["partition-1"]
+ | return [InputPartition(i) for i in range(start_index, end_index)]
+ | def read(self, partition):
+ | yield (partition.value,)
+ |
+ |class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return "id INT"
+ | def streamReader(self, schema):
+ | return TestDataStreamReader()
+ |""".stripMargin
+
+ private val testAvailableNowScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource
+ |from pyspark.sql.datasource import (
+ | DataSourceStreamReader,
+ | InputPartition,
+ |)
+ |from pyspark.sql.streaming.datasource import (
+ | ReadAllAvailable,
+ | ReadLimit,
+ | ReadMaxRows,
+ | SupportsTriggerAvailableNow
+ |)
+ |
+ |class TestDataStreamReader(
+ | DataSourceStreamReader,
+ | SupportsTriggerAvailableNow
+ |):
+ | def initialOffset(self):
+ | return {"partition-1": 0}
+ | def getDefaultReadLimit(self):
+ | return ReadMaxRows(2)
+ | def latestOffset(self, start: dict, limit: ReadLimit):
+ | start_idx = start["partition-1"]
+ | if isinstance(limit, ReadAllAvailable):
+ | end_offset = start_idx + 5
+ | else:
+ | assert isinstance(limit, ReadMaxRows), ("Expected ReadMaxRows read "
+ | "limit but got "
+ | + str(type(limit)))
+ | end_offset = start_idx + limit.max_rows
+ | end_offset = min(end_offset, self.desired_end_offset)
+ | return {"partition-1": end_offset}
+ | def reportLatestOffset(self):
+ | return {"partition-1": 1000000}
+ | def prepareForTriggerAvailableNow(self):
+ | self.desired_end_offset = 10
+ | def partitions(self, start: dict, end: dict):
+ | start_index = start["partition-1"]
+ | end_index = end["partition-1"]
+ | return [InputPartition(i) for i in range(start_index, end_index)]
+ | def read(self, partition):
+ | yield (partition.value,)
+ |
+ |class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return "id INT"
+ | def streamReader(self, schema):
+ | return TestDataStreamReader()
+ |""".stripMargin
+
+ test("DataSourceStreamReader with Admission Control, Trigger.Once") {
+ assume(shouldTestPandasUDFs)
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, testAdmissionControlScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+ withTempDir { dir =>
+ val path = dir.getAbsolutePath
+ val checkpointDir = new File(path, "checkpoint")
+ val outputDir = new File(path, "output")
+ val df = spark.readStream.format(dataSourceName).load()
+ val q = df.writeStream
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .format("json")
+ // Use Trigger.Once here intentionally to test reads with admission control.
+ .trigger(Trigger.Once())
+ .start(outputDir.getAbsolutePath)
+ q.awaitTermination(waitTimeout.toMillis)
+
+ assert(q.recentProgress.length === 1)
+ assert(q.lastProgress.numInputRows === 10)
+ assert(q.lastProgress.sources(0).numInputRows === 10)
+ assert(q.lastProgress.sources(0).latestOffset === """{"partition-1": 1000000}""")
+
+ val rowCount = spark.read.format("json").load(outputDir.getAbsolutePath).count()
+ assert(rowCount === 10)
+ checkAnswer(
+ spark.read.format("json").load(outputDir.getAbsolutePath),
+ (0 until rowCount.toInt).map(Row(_))
+ )
+ }
+ }
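
The expected counts in this test and the next one follow directly from the offset math in testAdmissionControlScript. A hedged walkthrough, using local stand-in classes instead of the real ReadLimit types from pyspark.sql.streaming.datasource:

```
# Hedged walkthrough of the admission control arithmetic asserted above.
# ReadAllAvailable / ReadMaxRows here are local stand-ins for the built-in
# limits imported in testAdmissionControlScript; only the offset math from
# TestDataStreamReader.latestOffset is reproduced.


class ReadAllAvailable:  # stand-in: "no limit" marker
    pass


class ReadMaxRows:  # stand-in exposing the max_rows field the reader consults
    def __init__(self, max_rows: int):
        self.max_rows = max_rows


def latest_offset(start: dict, limit) -> dict:
    start_idx = start["partition-1"]
    if isinstance(limit, ReadAllAvailable):
        end_offset = start_idx + 10
    else:
        end_offset = start_idx + limit.max_rows
    return {"partition-1": end_offset}


start = {"partition-1": 0}

# Trigger.Once plans a single batch with no read limit, hence 10 input rows at once.
assert latest_offset(start, ReadAllAvailable()) == {"partition-1": 10}

# A processing time trigger honors the default ReadMaxRows(2), hence 2 rows per batch.
assert latest_offset(start, ReadMaxRows(2)) == {"partition-1": 2}
```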
+
+ test("DataSourceStreamReader with Admission Control, processing time
trigger") {
+ assume(shouldTestPandasUDFs)
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, testAdmissionControlScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+ withTempDir { dir =>
+ val path = dir.getAbsolutePath
+ val checkpointDir = new File(path, "checkpoint")
+ val df = spark.readStream.format(dataSourceName).load()
+
+ val stopSignal = new CountDownLatch(1)
+
+ val q = df.writeStream
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .foreachBatch((df: DataFrame, batchId: Long) => {
+ df.cache()
+ checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+ if (batchId == 10) stopSignal.countDown()
+ })
+ .trigger(Trigger.ProcessingTime(0))
+ .start()
+ stopSignal.await()
+ q.stop()
+ q.awaitTermination()
+
+ assert(q.recentProgress.length >= 10)
+ q.recentProgress.foreach { progress =>
+ assert(progress.numInputRows === 2)
+ assert(progress.sources(0).numInputRows === 2)
+ assert(progress.sources(0).latestOffset === """{"partition-1": 1000000}""")
+ }
+ }
+ }
+
+ test("DataSourceStreamReader with Trigger.AvailableNow") {
+ assume(shouldTestPandasUDFs)
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, testAvailableNowScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+ withTempDir { dir =>
+ val path = dir.getAbsolutePath
+ val checkpointDir = new File(path, "checkpoint")
+ val outputDir = new File(path, "output")
+ val df = spark.readStream.format(dataSourceName).load()
+ val q = df.writeStream
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .format("json")
+ .trigger(Trigger.AvailableNow())
+ .start(outputDir.getAbsolutePath)
+ q.awaitTermination(waitTimeout.toMillis)
+
+ // 2 rows * 5 batches = 10 rows
+ assert(q.recentProgress.length === 5)
+ q.recentProgress.foreach { progress =>
+ assert(progress.numInputRows === 2)
+ assert(progress.sources(0).numInputRows === 2)
+ assert(progress.sources(0).latestOffset === """{"partition-1": 1000000}""")
+ }
+
+ val rowCount = spark.read.format("json").load(outputDir.getAbsolutePath).count()
+ assert(rowCount === 10)
+ checkAnswer(
+ spark.read.format("json").load(outputDir.getAbsolutePath),
+ (0 until rowCount.toInt).map(Row(_))
+ )
+ }
+ }
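
The 5 batches of 2 rows asserted above come from ReadMaxRows(2) combined with the end offset that prepareForTriggerAvailableNow pins at 10. A hedged, plain-Python sketch of that arithmetic:

```
# Hedged sketch of the batch planning arithmetic for the AvailableNow test above.
# Plain ints stand in for offsets; no Spark or ReadLimit objects are involved.
desired_end_offset = 10   # pinned by prepareForTriggerAvailableNow in testAvailableNowScript
max_rows_per_batch = 2    # getDefaultReadLimit returns ReadMaxRows(2)

offset = 0
batch_sizes = []
while offset < desired_end_offset:
    end = min(offset + max_rows_per_batch, desired_end_offset)
    batch_sizes.append(end - offset)  # rows planned for this micro-batch
    offset = end

# Matches the assertions: recentProgress.length === 5, numInputRows === 2 each.
assert batch_sizes == [2, 2, 2, 2, 2]
```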
+
test("Error creating stream reader") {
assume(shouldTestPandasUDFs)
val dataSourceScript =
@@ -705,12 +968,19 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
pythonDs.setShortName("ErrorDataSource")
def testMicroBatchStreamError(action: String, msg: String)(
- func: PythonMicroBatchStream => Unit): Unit = {
- val stream = new PythonMicroBatchStream(
+ func: PythonMicroBatchStreamWithAdmissionControl => Unit): Unit = {
+ val options = CaseInsensitiveStringMap.empty()
+ val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+ pythonDs, errorDataSourceName, inputSchema, options)
+ runner.init()
+
+ // The new default for the Python stream reader is the admission control variant.
+ val stream = new PythonMicroBatchStreamWithAdmissionControl(
pythonDs,
errorDataSourceName,
inputSchema,
- CaseInsensitiveStringMap.empty()
+ options,
+ runner
)
val err = intercept[SparkException] {
func(stream)
@@ -734,12 +1004,14 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
stream.initialOffset()
}
+ val offset = PythonStreamingSourceOffset("{\"offset\": \"2\"}")
testMicroBatchStreamError("latestOffset", "[NOT_IMPLEMENTED] latestOffset
is not implemented") {
stream =>
- stream.latestOffset()
+ val readLimit = PythonStreamingSourceReadLimit(
+ PythonStreamingSourceRunner.READ_ALL_AVAILABLE_JSON)
+ stream.latestOffset(offset, readLimit)
}
- val offset = PythonStreamingSourceOffset("{\"offset\": \"2\"}")
testMicroBatchStreamError("planPartitions", "[NOT_IMPLEMENTED] partitions
is not implemented") {
stream =>
stream.planInputPartitions(offset, offset)
@@ -767,11 +1039,17 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
def testMicroBatchStreamError(action: String, msg: String)(
func: PythonMicroBatchStream => Unit): Unit = {
+ val options = CaseInsensitiveStringMap.empty()
+ val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+ pythonDs, errorDataSourceName, inputSchema, options)
+ runner.init()
+
val stream = new PythonMicroBatchStream(
pythonDs,
errorDataSourceName,
inputSchema,
- CaseInsensitiveStringMap.empty()
+ options,
+ runner
)
val err = intercept[SparkException] {
func(stream)
@@ -804,6 +1082,133 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
stream.commit(offset)
}
}
+
+ test("empty batch for stream when latestOffset produces the same offset with
start") {
+ assume(shouldTestPandasUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource, DataSourceStreamReader, InputPartition
+ |
+ |class ConditionalEmptyBatchReader(DataSourceStreamReader):
+ | call_count = 0
+ |
+ | def initialOffset(self):
+ | return {"offset": 0}
+ |
+ | def latestOffset(self, start, limit):
+ | self.call_count += 1
+ | # For odd batches (call count - 1 is odd), return the same offset
+ | # (simulating no new data)
+ | # For even batches, advance the offset by 2
+ | if (self.call_count - 1) % 2 == 1:
+ | # Return current offset without advancing
+ | return start
+ | else:
+ | # Advance offset by 2
+ | return {"offset": start["offset"] + 2}
+ |
+ | def partitions(self, start: dict, end: dict):
+ | start_offset = start["offset"]
+ | end_offset = end["offset"]
+ | # Create partitions for the range [start, end)
+ | return [InputPartition(i) for i in range(start_offset, end_offset)]
+ |
+ | def commit(self, end: dict):
+ | pass
+ |
+ | def read(self, partition):
+ | # Yield a value with a marker to identify this is from Python source
+ | yield (partition.value, 1000)
+ |
+ |class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return "id INT, source INT"
+ |
+ | def streamReader(self, schema):
+ | return ConditionalEmptyBatchReader()
+ |""".stripMargin
+
+ val dataSource =
+ createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+
+ val pythonDF = spark.readStream.format(dataSourceName).load()
+
+ // Create a rate source that produces data every microbatch
+ val rateDF = spark.readStream
+ .format("rate-micro-batch")
+ .option("rowsPerBatch", "2")
+ .load()
+ .selectExpr("CAST(value AS INT) as id", "2000 as source")
+
+ // Union the two sources
+ val unionDF = pythonDF.union(rateDF)
+
+ val stopSignal = new CountDownLatch(1)
+ var batchesWithoutPythonData = 0
+ var batchesWithPythonData = 0
+
+ val q = unionDF.writeStream
+ .foreachBatch((df: DataFrame, batchId: Long) => {
+ df.cache()
+ val pythonRows = df.filter("source = 1000").count()
+ val rateRows = df.filter("source = 2000").count()
+
+ // Rate source should always produce 2 rows per batch
+ assert(
+ rateRows == 2,
+ s"Batch $batchId: Expected 2 rows from rate source, got $rateRows"
+ )
+
+ // Python source should produce 0 rows for odd batches (empty batches)
+ // and 2 rows for even batches
+ if (batchId % 2 == 1) {
+ // Odd batch - Python source should return same offset, producing empty batch
+ assert(
+ pythonRows == 0,
+ s"Batch $batchId: Expected 0 rows from Python source (empty
batch), got $pythonRows"
+ )
+ batchesWithoutPythonData += 1
+ } else {
+ // Even batch - Python source should advance offset and produce data
+ assert(
+ pythonRows == 2,
+ s"Batch $batchId: Expected 2 rows from Python source, got
$pythonRows"
+ )
+ batchesWithPythonData += 1
+ }
+
+ if (batchId >= 7) stopSignal.countDown()
+ })
+ .trigger(ProcessingTimeTrigger(0))
+ .start()
+
+ eventually(timeout(waitTimeout)) {
+ assert(
+ stopSignal.await(1, TimeUnit.SECONDS),
+ s"""
+ |Streaming query did not reach specific microbatch in time,
+ |# of batches with data from python stream source: $batchesWithPythonData,
+ |# of batches without data from python stream source: $batchesWithoutPythonData,
+ |recentProgress: ${q.recentProgress.mkString("[", ", ", "]")},
+ |exception (if any): ${q.exception}
+ |""".stripMargin
+ )
+ }
+
+ q.stop()
+ q.awaitTermination()
+
+ // Verify that we observed both types of batches
+ assert(
+ batchesWithoutPythonData >= 4,
+ s"Expected at least 4 batches without Python data, got
$batchesWithoutPythonData"
+ )
+ assert(
+ batchesWithPythonData >= 4,
+ s"Expected at least 4 batches with Python data, got
$batchesWithPythonData"
+ )
+ }
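
The empty-batch behaviour exercised above reduces to partitions(start, end) seeing an unchanged offset. A hedged sketch of that range math, with plain ints standing in for InputPartition:

```
# Hedged sketch: when latestOffset returns the start offset unchanged, the
# planned range is empty and the Python source contributes no partitions.
def planned_partitions(start: dict, end: dict) -> list:
    # Same range-based planning as ConditionalEmptyBatchReader.partitions.
    return list(range(start["offset"], end["offset"]))

unchanged = {"offset": 4}
advanced = {"offset": 6}

assert planned_partitions(unchanged, unchanged) == []      # odd batch: empty
assert planned_partitions(unchanged, advanced) == [4, 5]   # even batch: 2 rows
```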
}
class PythonStreamingDataSourceWriteSuite extends PythonDataSourceSuiteBase {