This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 042804ad545c [SPARK-48567][SS] StreamingQuery.lastProgress should return the actual StreamingQueryProgress
042804ad545c is described below
commit 042804ad545c88afe69c149b25baea00fc213708
Author: Wei Liu <[email protected]>
AuthorDate: Tue Jun 18 08:44:44 2024 +0900
[SPARK-48567][SS] StreamingQuery.lastProgress should return the actual StreamingQueryProgress
### What changes were proposed in this pull request?
This PR is created after the discussion in this closed one:
https://github.com/apache/spark/pull/46886
I was trying to fix a bug (in Connect, `query.lastProgress` doesn't have
`numInputRows`, `inputRowsPerSecond`, and `processedRowsPerSecond`), and we
reached the conclusion that what is proposed in this PR should be the ultimate fix.
In Python, for both classic Spark and Spark Connect, the return type of
`lastProgress` is `Dict` (and `recentProgress` is `List[Dict]`), but in Scala
it's the actual `StreamingQueryProgress` object:
https://github.com/apache/spark/blob/1a5d22aa2ffe769435be4aa6102ef961c55b9593/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala#L94-L101
This API discrepancy causes confusion: in Scala, users can write
`query.lastProgress.batchId`, while in Python they have to write
`query.lastProgress["batchId"]`.
This PR makes `StreamingQuery.lastProgress` return the actual
`StreamingQueryProgress` (and `StreamingQuery.recentProgress` return
`List[StreamingQueryProgress]`).
To avoid a breaking change, we make `StreamingQueryProgress` a subclass of
`dict`, so existing code that uses dictionary-style access (e.g.
`query.lastProgress["id"]`) still works.
### Why are the changes needed?
API parity
### Does this PR introduce _any_ user-facing change?
Yes, now `StreamingQuery.lastProgress` returns the actual
`StreamingQueryProgress` (and `StreamingQuery.recentProgress` returns
`List[StreamingQueryProgress]`).
### How was this patch tested?
Added unit test
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #46921 from WweiL/SPARK-48567-lastProgress.
Authored-by: Wei Liu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/streaming/query.py | 9 +-
python/pyspark/sql/streaming/listener.py | 228 +++++++++++++--------
python/pyspark/sql/streaming/query.py | 13 +-
.../pyspark/sql/tests/streaming/test_streaming.py | 44 +++-
.../sql/tests/streaming/test_streaming_listener.py | 32 ++-
5 files changed, 227 insertions(+), 99 deletions(-)
diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py
index 98ecdc4966c7..cc1e2e220188 100644
--- a/python/pyspark/sql/connect/streaming/query.py
+++ b/python/pyspark/sql/connect/streaming/query.py
@@ -33,6 +33,7 @@ from pyspark.sql.streaming.listener import (
QueryProgressEvent,
QueryIdleEvent,
QueryTerminatedEvent,
+ StreamingQueryProgress,
)
from pyspark.sql.streaming.query import (
StreamingQuery as PySparkStreamingQuery,
@@ -110,21 +111,21 @@ class StreamingQuery:
status.__doc__ = PySparkStreamingQuery.status.__doc__
@property
- def recentProgress(self) -> List[Dict[str, Any]]:
+ def recentProgress(self) -> List[StreamingQueryProgress]:
cmd = pb2.StreamingQueryCommand()
cmd.recent_progress = True
progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
- return [json.loads(p) for p in progress]
+ return [StreamingQueryProgress.fromJson(json.loads(p)) for p in progress]
recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__
@property
- def lastProgress(self) -> Optional[Dict[str, Any]]:
+ def lastProgress(self) -> Optional[StreamingQueryProgress]:
cmd = pb2.StreamingQueryCommand()
cmd.last_progress = True
progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
if len(progress) > 0:
- return json.loads(progress[-1])
+ return StreamingQueryProgress.fromJson(json.loads(progress[-1]))
else:
return None
diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py
index 2aa63cdb91ab..6cc2cc3fa2b8 100644
--- a/python/pyspark/sql/streaming/listener.py
+++ b/python/pyspark/sql/streaming/listener.py
@@ -397,10 +397,13 @@ class QueryTerminatedEvent:
return self._errorClassOnException
-class StreamingQueryProgress:
+class StreamingQueryProgress(dict):
"""
.. versionadded:: 3.4.0
+ .. versionchanged:: 4.0.0
+ Becomes a subclass of dict
+
Notes
-----
This API is evolving.
@@ -426,23 +429,25 @@ class StreamingQueryProgress:
jprogress: Optional["JavaObject"] = None,
jdict: Optional[Dict[str, Any]] = None,
):
+ super().__init__(
+ id=id,
+ runId=runId,
+ name=name,
+ timestamp=timestamp,
+ batchId=batchId,
+ batchDuration=batchDuration,
+ durationMs=durationMs,
+ eventTime=eventTime,
+ stateOperators=stateOperators,
+ sources=sources,
+ sink=sink,
+ numInputRows=numInputRows,
+ inputRowsPerSecond=inputRowsPerSecond,
+ processedRowsPerSecond=processedRowsPerSecond,
+ observedMetrics=observedMetrics,
+ )
self._jprogress: Optional["JavaObject"] = jprogress
self._jdict: Optional[Dict[str, Any]] = jdict
- self._id: uuid.UUID = id
- self._runId: uuid.UUID = runId
- self._name: Optional[str] = name
- self._timestamp: str = timestamp
- self._batchId: int = batchId
- self._batchDuration: int = batchDuration
- self._durationMs: Dict[str, int] = durationMs
- self._eventTime: Dict[str, str] = eventTime
- self._stateOperators: List[StateOperatorProgress] = stateOperators
- self._sources: List[SourceProgress] = sources
- self._sink: SinkProgress = sink
- self._numInputRows: int = numInputRows
- self._inputRowsPerSecond: float = inputRowsPerSecond
- self._processedRowsPerSecond: float = processedRowsPerSecond
- self._observedMetrics: Dict[str, Row] = observedMetrics
@classmethod
def fromJObject(cls, jprogress: "JavaObject") -> "StreamingQueryProgress":
@@ -489,9 +494,11 @@ class StreamingQueryProgress:
stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]],
sources=[SourceProgress.fromJson(s) for s in j["sources"]],
sink=SinkProgress.fromJson(j["sink"]),
- numInputRows=j["numInputRows"],
- inputRowsPerSecond=j["inputRowsPerSecond"],
- processedRowsPerSecond=j["processedRowsPerSecond"],
+ numInputRows=j["numInputRows"] if "numInputRows" in j else None,
+ inputRowsPerSecond=j["inputRowsPerSecond"] if "inputRowsPerSecond"
in j else None,
+ processedRowsPerSecond=j["processedRowsPerSecond"]
+ if "processedRowsPerSecond" in j
+ else None,
observedMetrics={
k: Row(*row_dict.keys())(*row_dict.values())  # Assume no nested rows
for k, row_dict in j["observedMetrics"].items()
@@ -506,7 +513,10 @@ class StreamingQueryProgress:
A unique query id that persists across restarts. See
py:meth:`~pyspark.sql.streaming.StreamingQuery.id`.
"""
- return self._id
+ # Before Spark 4.0, StreamingQuery.lastProgress returns a dict, which casts id and runId
+ # to string. But here they are UUID.
+ # To prevent breaking change, do not cast them to string when accessed with attribute.
+ return super().__getitem__("id")
@property
def runId(self) -> uuid.UUID:
@@ -514,21 +524,24 @@ class StreamingQueryProgress:
A query id that is unique for every start/restart. See
py:meth:`~pyspark.sql.streaming.StreamingQuery.runId`.
"""
- return self._runId
+ # Before Spark 4.0, StreamingQuery.lastProgress returns a dict, which casts id and runId
+ # to string. But here they are UUID.
+ # To prevent breaking change, do not cast them to string when accessed with attribute.
+ return super().__getitem__("runId")
@property
def name(self) -> Optional[str]:
"""
User-specified name of the query, `None` if not specified.
"""
- return self._name
+ return self["name"]
@property
def timestamp(self) -> str:
"""
The timestamp to start a query.
"""
- return self._timestamp
+ return self["timestamp"]
@property
def batchId(self) -> int:
@@ -538,21 +551,21 @@ class StreamingQueryProgress:
Similarly, when there is no data to be processed, the batchId will not be incremented.
"""
- return self._batchId
+ return self["batchId"]
@property
def batchDuration(self) -> int:
"""
The process duration of each batch.
"""
- return self._batchDuration
+ return self["batchDuration"]
@property
def durationMs(self) -> Dict[str, int]:
"""
The amount of time taken to perform various operations in milliseconds.
"""
- return self._durationMs
+ return self["durationMs"]
@property
def eventTime(self) -> Dict[str, str]:
@@ -570,21 +583,21 @@ class StreamingQueryProgress:
All timestamps are in ISO8601 format, i.e. UTC timestamps.
"""
- return self._eventTime
+ return self["eventTime"]
@property
def stateOperators(self) -> List["StateOperatorProgress"]:
"""
Information about operators in the query that store state.
"""
- return self._stateOperators
+ return self["stateOperators"]
@property
def sources(self) -> List["SourceProgress"]:
"""
detailed statistics on data being read from each of the streaming sources.
"""
- return self._sources
+ return self["sources"]
@property
def sink(self) -> "SinkProgress":
@@ -592,32 +605,41 @@ class StreamingQueryProgress:
A unique query id that persists across restarts. See
py:meth:`~pyspark.sql.streaming.StreamingQuery.id`.
"""
- return self._sink
+ return self["sink"]
@property
def observedMetrics(self) -> Dict[str, Row]:
- return self._observedMetrics
+ return self["observedMetrics"]
@property
def numInputRows(self) -> int:
"""
The aggregate (across all sources) number of records processed in a trigger.
"""
- return self._numInputRows
+ if self["numInputRows"] is not None:
+ return self["numInputRows"]
+ else:
+ return sum(s.numInputRows for s in self.sources)
@property
def inputRowsPerSecond(self) -> float:
"""
The aggregate (across all sources) rate of data arriving.
"""
- return self._inputRowsPerSecond
+ if self["inputRowsPerSecond"] is not None:
+ return self["inputRowsPerSecond"]
+ else:
+ return sum(s.inputRowsPerSecond for s in self.sources)
@property
def processedRowsPerSecond(self) -> float:
"""
The aggregate (across all sources) rate at which Spark is processing data.
"""
- return self._processedRowsPerSecond
+ if self["processedRowsPerSecond"] is not None:
+ return self["processedRowsPerSecond"]
+ else:
+ return sum(s.processedRowsPerSecond for s in self.sources)
@property
def json(self) -> str:
@@ -641,14 +663,29 @@ class StreamingQueryProgress:
else:
return json.dumps(self._jdict, indent=4)
+ def __getitem__(self, key: str) -> Any:
+ # Before Spark 4.0, StreamingQuery.lastProgress returns a dict, which casts id and runId
+ # to string. But here they are UUID.
+ # To prevent breaking change, also cast them to string when accessed with __getitem__.
+ if key == "id" or key == "runId":
+ return str(super().__getitem__(key))
+ else:
+ return super().__getitem__(key)
+
def __str__(self) -> str:
return self.prettyJson
+ def __repr__(self) -> str:
+ return self.prettyJson
+
-class StateOperatorProgress:
+class StateOperatorProgress(dict):
"""
.. versionadded:: 3.4.0
+ .. versionchanged:: 4.0.0
+ Becomes a subclass of dict
+
Notes
-----
This API is evolving.
@@ -671,20 +708,22 @@ class StateOperatorProgress:
jprogress: Optional["JavaObject"] = None,
jdict: Optional[Dict[str, Any]] = None,
):
+ super().__init__(
+ operatorName=operatorName,
+ numRowsTotal=numRowsTotal,
+ numRowsUpdated=numRowsUpdated,
+ numRowsRemoved=numRowsRemoved,
+ allUpdatesTimeMs=allUpdatesTimeMs,
+ allRemovalsTimeMs=allRemovalsTimeMs,
+ commitTimeMs=commitTimeMs,
+ memoryUsedBytes=memoryUsedBytes,
+ numRowsDroppedByWatermark=numRowsDroppedByWatermark,
+ numShufflePartitions=numShufflePartitions,
+ numStateStoreInstances=numStateStoreInstances,
+ customMetrics=customMetrics,
+ )
self._jprogress: Optional["JavaObject"] = jprogress
self._jdict: Optional[Dict[str, Any]] = jdict
- self._operatorName: str = operatorName
- self._numRowsTotal: int = numRowsTotal
- self._numRowsUpdated: int = numRowsUpdated
- self._numRowsRemoved: int = numRowsRemoved
- self._allUpdatesTimeMs: int = allUpdatesTimeMs
- self._allRemovalsTimeMs: int = allRemovalsTimeMs
- self._commitTimeMs: int = commitTimeMs
- self._memoryUsedBytes: int = memoryUsedBytes
- self._numRowsDroppedByWatermark: int = numRowsDroppedByWatermark
- self._numShufflePartitions: int = numShufflePartitions
- self._numStateStoreInstances: int = numStateStoreInstances
- self._customMetrics: Dict[str, int] = customMetrics
@classmethod
def fromJObject(cls, jprogress: "JavaObject") -> "StateOperatorProgress":
@@ -724,51 +763,51 @@ class StateOperatorProgress:
@property
def operatorName(self) -> str:
- return self._operatorName
+ return self["operatorName"]
@property
def numRowsTotal(self) -> int:
- return self._numRowsTotal
+ return self["numRowsTotal"]
@property
def numRowsUpdated(self) -> int:
- return self._numRowsUpdated
+ return self["numRowsUpdated"]
@property
def allUpdatesTimeMs(self) -> int:
- return self._allUpdatesTimeMs
+ return self["allUpdatesTimeMs"]
@property
def numRowsRemoved(self) -> int:
- return self._numRowsRemoved
+ return self["numRowsRemoved"]
@property
def allRemovalsTimeMs(self) -> int:
- return self._allRemovalsTimeMs
+ return self["allRemovalsTimeMs"]
@property
def commitTimeMs(self) -> int:
- return self._commitTimeMs
+ return self["commitTimeMs"]
@property
def memoryUsedBytes(self) -> int:
- return self._memoryUsedBytes
+ return self["memoryUsedBytes"]
@property
def numRowsDroppedByWatermark(self) -> int:
- return self._numRowsDroppedByWatermark
+ return self["numRowsDroppedByWatermark"]
@property
def numShufflePartitions(self) -> int:
- return self._numShufflePartitions
+ return self["numShufflePartitions"]
@property
def numStateStoreInstances(self) -> int:
- return self._numStateStoreInstances
+ return self["numStateStoreInstances"]
@property
- def customMetrics(self) -> Dict[str, int]:
- return self._customMetrics
+ def customMetrics(self) -> dict:
+ return self["customMetrics"]
@property
def json(self) -> str:
@@ -795,11 +834,17 @@ class StateOperatorProgress:
def __str__(self) -> str:
return self.prettyJson
+ def __repr__(self) -> str:
+ return self.prettyJson
+
-class SourceProgress:
+class SourceProgress(dict):
"""
.. versionadded:: 3.4.0
+ .. versionchanged:: 4.0.0
+ Becomes a subclass of dict
+
Notes
-----
This API is evolving.
@@ -818,16 +863,18 @@ class SourceProgress:
jprogress: Optional["JavaObject"] = None,
jdict: Optional[Dict[str, Any]] = None,
) -> None:
+ super().__init__(
+ description=description,
+ startOffset=startOffset,
+ endOffset=endOffset,
+ latestOffset=latestOffset,
+ numInputRows=numInputRows,
+ inputRowsPerSecond=inputRowsPerSecond,
+ processedRowsPerSecond=processedRowsPerSecond,
+ metrics=metrics,
+ )
self._jprogress: Optional["JavaObject"] = jprogress
self._jdict: Optional[Dict[str, Any]] = jdict
- self._description: str = description
- self._startOffset: str = startOffset
- self._endOffset: str = endOffset
- self._latestOffset: str = latestOffset
- self._numInputRows: int = numInputRows
- self._inputRowsPerSecond: float = inputRowsPerSecond
- self._processedRowsPerSecond: float = processedRowsPerSecond
- self._metrics: Dict[str, str] = metrics
@classmethod
def fromJObject(cls, jprogress: "JavaObject") -> "SourceProgress":
@@ -862,53 +909,53 @@ class SourceProgress:
"""
Description of the source.
"""
- return self._description
+ return self["description"]
@property
def startOffset(self) -> str:
"""
The starting offset for data being read.
"""
- return self._startOffset
+ return self["startOffset"]
@property
def endOffset(self) -> str:
"""
The ending offset for data being read.
"""
- return self._endOffset
+ return self["endOffset"]
@property
def latestOffset(self) -> str:
"""
The latest offset from this source.
"""
- return self._latestOffset
+ return self["latestOffset"]
@property
def numInputRows(self) -> int:
"""
The number of records read from this source.
"""
- return self._numInputRows
+ return self["numInputRows"]
@property
def inputRowsPerSecond(self) -> float:
"""
The rate at which data is arriving from this source.
"""
- return self._inputRowsPerSecond
+ return self["inputRowsPerSecond"]
@property
def processedRowsPerSecond(self) -> float:
"""
The rate at which data from this source is being processed by Spark.
"""
- return self._processedRowsPerSecond
+ return self["processedRowsPerSecond"]
@property
- def metrics(self) -> Dict[str, str]:
- return self._metrics
+ def metrics(self) -> dict:
+ return self["metrics"]
@property
def json(self) -> str:
@@ -935,11 +982,17 @@ class SourceProgress:
def __str__(self) -> str:
return self.prettyJson
+ def __repr__(self) -> str:
+ return self.prettyJson
+
-class SinkProgress:
+class SinkProgress(dict):
"""
.. versionadded:: 3.4.0
+ .. versionchanged:: 4.0.0
+ Becomes a subclass of dict
+
Notes
-----
This API is evolving.
@@ -953,11 +1006,13 @@ class SinkProgress:
jprogress: Optional["JavaObject"] = None,
jdict: Optional[Dict[str, Any]] = None,
) -> None:
+ super().__init__(
+ description=description,
+ numOutputRows=numOutputRows,
+ metrics=metrics,
+ )
self._jprogress: Optional["JavaObject"] = jprogress
self._jdict: Optional[Dict[str, Any]] = jdict
- self._description: str = description
- self._numOutputRows: int = numOutputRows
- self._metrics: Dict[str, str] = metrics
@classmethod
def fromJObject(cls, jprogress: "JavaObject") -> "SinkProgress":
@@ -982,7 +1037,7 @@ class SinkProgress:
"""
Description of the source.
"""
- return self._description
+ return self["description"]
@property
def numOutputRows(self) -> int:
@@ -990,11 +1045,11 @@ class SinkProgress:
Number of rows written to the sink or -1 for Continuous Mode (temporarily)
or Sink V1 (until decommissioned).
"""
- return self._numOutputRows
+ return self["numOutputRows"]
@property
def metrics(self) -> Dict[str, str]:
- return self._metrics
+ return self["metrics"]
@property
def json(self) -> str:
@@ -1021,6 +1076,9 @@ class SinkProgress:
def __str__(self) -> str:
return self.prettyJson
+ def __repr__(self) -> str:
+ return self.prettyJson
+
def _test() -> None:
import sys
diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py
index d3d58da3562b..916f96a5b2c2 100644
--- a/python/pyspark/sql/streaming/query.py
+++ b/python/pyspark/sql/streaming/query.py
@@ -22,7 +22,10 @@ from pyspark.errors import StreamingQueryException, PySparkValueError
from pyspark.errors.exceptions.captured import (
StreamingQueryException as CapturedStreamingQueryException,
)
-from pyspark.sql.streaming.listener import StreamingQueryListener
+from pyspark.sql.streaming.listener import (
+ StreamingQueryListener,
+ StreamingQueryProgress,
+)
if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
@@ -251,7 +254,7 @@ class StreamingQuery:
return json.loads(self._jsq.status().json())
@property
- def recentProgress(self) -> List[Dict[str, Any]]:
+ def recentProgress(self) -> List[StreamingQueryProgress]:
"""
Returns an array of the most recent [[StreamingQueryProgress]] updates for this query.
The number of progress updates retained for each stream is configured by Spark session
@@ -280,10 +283,10 @@ class StreamingQuery:
>>> sq.stop()
"""
- return [json.loads(p.json()) for p in self._jsq.recentProgress()]
+ return [StreamingQueryProgress.fromJObject(p) for p in self._jsq.recentProgress()]
@property
- def lastProgress(self) -> Optional[Dict[str, Any]]:
+ def lastProgress(self) -> Optional[StreamingQueryProgress]:
"""
Returns the most recent :class:`StreamingQueryProgress` update of this streaming query or
None if there were no progress updates
@@ -311,7 +314,7 @@ class StreamingQuery:
"""
lastProgress = self._jsq.lastProgress()
if lastProgress:
- return json.loads(lastProgress.json())
+ return StreamingQueryProgress.fromJObject(lastProgress)
else:
return None
diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py
index e284d052d9ae..00d1fbf53885 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming.py
@@ -29,7 +29,7 @@ from pyspark.errors import PySparkValueError
class StreamingTestsMixin:
def test_streaming_query_functions_basic(self):
- df = self.spark.readStream.format("rate").option("rowsPerSecond",
10).load()
+ df =
self.spark.readStream.format("text").load("python/test_support/sql/streaming")
query = (
df.writeStream.format("memory")
.queryName("test_streaming_query_functions_basic")
@@ -43,8 +43,8 @@ class StreamingTestsMixin:
self.assertEqual(query.exception(), None)
self.assertFalse(query.awaitTermination(1))
query.processAllAvailable()
- recentProgress = query.recentProgress
lastProgress = query.lastProgress
+ recentProgress = query.recentProgress
self.assertEqual(lastProgress["name"], query.name)
self.assertEqual(lastProgress["id"], query.id)
self.assertTrue(any(p == lastProgress for p in recentProgress))
@@ -59,6 +59,46 @@ class StreamingTestsMixin:
finally:
query.stop()
+ def test_streaming_progress(self):
+ """
+ Should be able to access fields using attributes in lastProgress / recentProgress
+ e.g. q.lastProgress.id
+ """
+ df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+ query = df.writeStream.format("noop").start()
+ try:
+ query.processAllAvailable()
+ lastProgress = query.lastProgress
+ recentProgress = query.recentProgress
+ self.assertEqual(lastProgress["name"], query.name)
+ # Return str when accessed using dict get.
+ self.assertEqual(lastProgress["id"], query.id)
+ # SPARK-48567 Use attribute to access fields in q.lastProgress
+ self.assertEqual(lastProgress.name, query.name)
+ # Return uuid when accessed using attribute.
+ self.assertEqual(str(lastProgress.id), query.id)
+ self.assertTrue(any(p == lastProgress for p in recentProgress))
+ self.assertTrue(lastProgress.numInputRows > 0)
+ # Also access source / sink progress with attributes
+ self.assertTrue(len(lastProgress.sources) > 0)
+ self.assertTrue(lastProgress.sources[0].numInputRows > 0)
+ self.assertTrue(lastProgress["sources"][0]["numInputRows"] > 0)
+ self.assertTrue(lastProgress.sink.numOutputRows > 0)
+ self.assertTrue(lastProgress["sink"]["numOutputRows"] > 0)
+ # In Python, for historical reasons, changing field value
+ # in StreamingQueryProgress is allowed.
+ new_name = "myNewQuery"
+ lastProgress["name"] = new_name
+ self.assertEqual(lastProgress.name, new_name)
+
+ except Exception as e:
+ self.fail(
+ "Streaming query functions sanity check shouldn't throw any
error. "
+ "Error message: " + str(e)
+ )
+ finally:
+ query.stop()
+
def test_streaming_query_name_edge_case(self):
# Query name should be None when not specified
q1 = self.spark.readStream.format("rate").load().writeStream.format("noop").start()
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index 762fc335b56a..0f13450849c5 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -227,9 +227,9 @@ class StreamingListenerTestsMixin:
"my_event", count(lit(1)).alias("rc"),
count(col("error")).alias("erc")
)
- q = observed_ds.writeStream.format("console").start()
+ q = observed_ds.writeStream.format("noop").start()
- while q.lastProgress is None or q.lastProgress["batchId"] == 0:
+ while q.lastProgress is None or q.lastProgress.batchId == 0:
q.awaitTermination(0.5)
time.sleep(5)
@@ -241,6 +241,32 @@ class StreamingListenerTestsMixin:
q.stop()
self.spark.streams.removeListener(error_listener)
+ def test_streaming_progress(self):
+ try:
+ # Test a fancier query with stateful operation and observed metrics
+ df = self.spark.readStream.format("rate").option("rowsPerSecond",
10).load()
+ df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+ df_stateful = df_observe.groupBy().count() # make query stateful
+ q = (
+ df_stateful.writeStream.format("noop")
+ .queryName("test")
+ .outputMode("update")
+ .trigger(processingTime="5 seconds")
+ .start()
+ )
+
+ while q.lastProgress is None or q.lastProgress.batchId == 0:
+ q.awaitTermination(0.5)
+
+ q.stop()
+
+ self.check_streaming_query_progress(q.lastProgress, True)
+ for p in q.recentProgress:
+ self.check_streaming_query_progress(p, True)
+
+ finally:
+ q.stop()
+
class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase):
def test_number_of_public_methods(self):
@@ -355,7 +381,7 @@ class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase):
.start()
)
self.assertTrue(q.isActive)
- time.sleep(10)
+ q.awaitTermination(10)
q.stop()
# Make sure all events are empty