This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 042804ad545c [SPARK-48567][SS] StreamingQuery.lastProgress should return the actual StreamingQueryProgress
042804ad545c is described below

commit 042804ad545c88afe69c149b25baea00fc213708
Author: Wei Liu <[email protected]>
AuthorDate: Tue Jun 18 08:44:44 2024 +0900

    [SPARK-48567][SS] StreamingQuery.lastProgress should return the actual StreamingQueryProgress
    
    ### What changes were proposed in this pull request?
    
    This PR was created after the discussion in this closed one:
    https://github.com/apache/spark/pull/46886. I was trying to fix a bug
    (in Spark Connect, query.lastProgress doesn't have `numInputRows`,
    `inputRowsPerSecond`, or `processedRowsPerSecond`), and we reached the
    conclusion that what is proposed in this PR should be the ultimate fix.
    
    In Python, for both classic Spark and Spark Connect, the return type of
    `lastProgress` is `Dict` (and `recentProgress` is `List[Dict]`), but in
    Scala it is the actual `StreamingQueryProgress` object:
    
https://github.com/apache/spark/blob/1a5d22aa2ffe769435be4aa6102ef961c55b9593/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala#L94-L101
    
    This API discrepancy causes confusion: in Scala, users can write
    `query.lastProgress.batchId`, while in Python they have to write
    `query.lastProgress["batchId"]`.
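
    For illustration, a minimal sketch of the pre-change Python behavior
    (assuming a running query `q`; this snippet is illustrative, not part of
    the patch):

        progress = q.lastProgress    # a plain dict before this change
        progress["batchId"]          # works
        progress.batchId             # AttributeError: 'dict' object has no attribute 'batchId'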
    
    This PR makes `StreamingQuery.lastProgress` return the actual
    `StreamingQueryProgress` (and `StreamingQuery.recentProgress` return
    `List[StreamingQueryProgress]`).
    
    To prevent a breaking change, we extend `StreamingQueryProgress` to be a
    subclass of `dict`, so existing code that accesses fields with dictionary
    methods (e.g. `query.lastProgress["id"]`) still works.
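
    A quick sketch of the behavior after this change (again assuming a
    running query `q`; illustrative only, not part of the patch):

        progress = q.lastProgress    # StreamingQueryProgress, a dict subclass
        progress.batchId             # attribute access, as in Scala
        progress["batchId"]          # dictionary access keeps working
        progress.id                  # uuid.UUID via attribute (see comments in listener.py below)
        progress["id"]               # str via __getitem__, matching pre-4.0 behavior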
    
    ### Why are the changes needed?
    
    API parity
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, now `StreamingQuery.lastProgress` returns the actual
    `StreamingQueryProgress` (and `StreamingQuery.recentProgress` returns
    `List[StreamingQueryProgress]`).
    
    ### How was this patch tested?
    
    Added unit tests.
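
    For reference, the new tests should be runnable from a Spark dev checkout
    with the standard runner, e.g.
    `python/run-tests --testnames 'pyspark.sql.tests.streaming.test_streaming'`
    (module name assumed from the file path in the diff).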
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #46921 from WweiL/SPARK-48567-lastProgress.
    
    Authored-by: Wei Liu <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/streaming/query.py      |   9 +-
 python/pyspark/sql/streaming/listener.py           | 228 +++++++++++++--------
 python/pyspark/sql/streaming/query.py              |  13 +-
 .../pyspark/sql/tests/streaming/test_streaming.py  |  44 +++-
 .../sql/tests/streaming/test_streaming_listener.py |  32 ++-
 5 files changed, 227 insertions(+), 99 deletions(-)

diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py
index 98ecdc4966c7..cc1e2e220188 100644
--- a/python/pyspark/sql/connect/streaming/query.py
+++ b/python/pyspark/sql/connect/streaming/query.py
@@ -33,6 +33,7 @@ from pyspark.sql.streaming.listener import (
     QueryProgressEvent,
     QueryIdleEvent,
     QueryTerminatedEvent,
+    StreamingQueryProgress,
 )
 from pyspark.sql.streaming.query import (
     StreamingQuery as PySparkStreamingQuery,
@@ -110,21 +111,21 @@ class StreamingQuery:
     status.__doc__ = PySparkStreamingQuery.status.__doc__
 
     @property
-    def recentProgress(self) -> List[Dict[str, Any]]:
+    def recentProgress(self) -> List[StreamingQueryProgress]:
         cmd = pb2.StreamingQueryCommand()
         cmd.recent_progress = True
        progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
-        return [json.loads(p) for p in progress]
+        return [StreamingQueryProgress.fromJson(json.loads(p)) for p in progress]
 
     recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__
 
     @property
-    def lastProgress(self) -> Optional[Dict[str, Any]]:
+    def lastProgress(self) -> Optional[StreamingQueryProgress]:
         cmd = pb2.StreamingQueryCommand()
         cmd.last_progress = True
        progress = self._execute_streaming_query_cmd(cmd).recent_progress.recent_progress_json
         if len(progress) > 0:
-            return json.loads(progress[-1])
+            return StreamingQueryProgress.fromJson(json.loads(progress[-1]))
         else:
             return None
 
diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py
index 2aa63cdb91ab..6cc2cc3fa2b8 100644
--- a/python/pyspark/sql/streaming/listener.py
+++ b/python/pyspark/sql/streaming/listener.py
@@ -397,10 +397,13 @@ class QueryTerminatedEvent:
         return self._errorClassOnException
 
 
-class StreamingQueryProgress:
+class StreamingQueryProgress(dict):
     """
     .. versionadded:: 3.4.0
 
+    .. versionchanged:: 4.0.0
+        Becomes a subclass of dict
+
     Notes
     -----
     This API is evolving.
@@ -426,23 +429,25 @@ class StreamingQueryProgress:
         jprogress: Optional["JavaObject"] = None,
         jdict: Optional[Dict[str, Any]] = None,
     ):
+        super().__init__(
+            id=id,
+            runId=runId,
+            name=name,
+            timestamp=timestamp,
+            batchId=batchId,
+            batchDuration=batchDuration,
+            durationMs=durationMs,
+            eventTime=eventTime,
+            stateOperators=stateOperators,
+            sources=sources,
+            sink=sink,
+            numInputRows=numInputRows,
+            inputRowsPerSecond=inputRowsPerSecond,
+            processedRowsPerSecond=processedRowsPerSecond,
+            observedMetrics=observedMetrics,
+        )
         self._jprogress: Optional["JavaObject"] = jprogress
         self._jdict: Optional[Dict[str, Any]] = jdict
-        self._id: uuid.UUID = id
-        self._runId: uuid.UUID = runId
-        self._name: Optional[str] = name
-        self._timestamp: str = timestamp
-        self._batchId: int = batchId
-        self._batchDuration: int = batchDuration
-        self._durationMs: Dict[str, int] = durationMs
-        self._eventTime: Dict[str, str] = eventTime
-        self._stateOperators: List[StateOperatorProgress] = stateOperators
-        self._sources: List[SourceProgress] = sources
-        self._sink: SinkProgress = sink
-        self._numInputRows: int = numInputRows
-        self._inputRowsPerSecond: float = inputRowsPerSecond
-        self._processedRowsPerSecond: float = processedRowsPerSecond
-        self._observedMetrics: Dict[str, Row] = observedMetrics
 
     @classmethod
     def fromJObject(cls, jprogress: "JavaObject") -> "StreamingQueryProgress":
@@ -489,9 +494,11 @@ class StreamingQueryProgress:
             stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]],
             sources=[SourceProgress.fromJson(s) for s in j["sources"]],
             sink=SinkProgress.fromJson(j["sink"]),
-            numInputRows=j["numInputRows"],
-            inputRowsPerSecond=j["inputRowsPerSecond"],
-            processedRowsPerSecond=j["processedRowsPerSecond"],
+            numInputRows=j["numInputRows"] if "numInputRows" in j else None,
+            inputRowsPerSecond=j["inputRowsPerSecond"] if "inputRowsPerSecond" in j else None,
+            processedRowsPerSecond=j["processedRowsPerSecond"]
+            if "processedRowsPerSecond" in j
+            else None,
             observedMetrics={
                k: Row(*row_dict.keys())(*row_dict.values())  # Assume no nested rows
                 for k, row_dict in j["observedMetrics"].items()
@@ -506,7 +513,10 @@ class StreamingQueryProgress:
         A unique query id that persists across restarts. See
         py:meth:`~pyspark.sql.streaming.StreamingQuery.id`.
         """
-        return self._id
+        # Before Spark 4.0, StreamingQuery.lastProgress returned a dict in which id and runId
+        # were cast to strings. But here they are UUIDs. To prevent a breaking change,
+        # do not cast them to strings when accessed as attributes.
+        return super().__getitem__("id")
 
     @property
     def runId(self) -> uuid.UUID:
@@ -514,21 +524,24 @@ class StreamingQueryProgress:
         A query id that is unique for every start/restart. See
         py:meth:`~pyspark.sql.streaming.StreamingQuery.runId`.
         """
-        return self._runId
+        # Before Spark 4.0, StreamingQuery.lastProgress returned a dict in which id and runId
+        # were cast to strings. But here they are UUIDs. To prevent a breaking change,
+        # do not cast them to strings when accessed as attributes.
+        return super().__getitem__("runId")
 
     @property
     def name(self) -> Optional[str]:
         """
         User-specified name of the query, `None` if not specified.
         """
-        return self._name
+        return self["name"]
 
     @property
     def timestamp(self) -> str:
         """
         The timestamp to start a query.
         """
-        return self._timestamp
+        return self["timestamp"]
 
     @property
     def batchId(self) -> int:
@@ -538,21 +551,21 @@ class StreamingQueryProgress:
         Similarly, when there is no data to be processed, the batchId will not be
         incremented.
         """
-        return self._batchId
+        return self["batchId"]
 
     @property
     def batchDuration(self) -> int:
         """
         The process duration of each batch.
         """
-        return self._batchDuration
+        return self["batchDuration"]
 
     @property
     def durationMs(self) -> Dict[str, int]:
         """
         The amount of time taken to perform various operations in milliseconds.
         """
-        return self._durationMs
+        return self["durationMs"]
 
     @property
     def eventTime(self) -> Dict[str, str]:
@@ -570,21 +583,21 @@ class StreamingQueryProgress:
 
         All timestamps are in ISO8601 format, i.e. UTC timestamps.
         """
-        return self._eventTime
+        return self["eventTime"]
 
     @property
     def stateOperators(self) -> List["StateOperatorProgress"]:
         """
         Information about operators in the query that store state.
         """
-        return self._stateOperators
+        return self["stateOperators"]
 
     @property
     def sources(self) -> List["SourceProgress"]:
         """
         Detailed statistics on data being read from each of the streaming sources.
         """
-        return self._sources
+        return self["sources"]
 
     @property
     def sink(self) -> "SinkProgress":
@@ -592,32 +605,41 @@ class StreamingQueryProgress:
         Information about progress made for the sink.
         """
-        return self._sink
+        return self["sink"]
 
     @property
     def observedMetrics(self) -> Dict[str, Row]:
-        return self._observedMetrics
+        return self["observedMetrics"]
 
     @property
     def numInputRows(self) -> int:
         """
         The aggregate (across all sources) number of records processed in a trigger.
         """
-        return self._numInputRows
+        if self["numInputRows"] is not None:
+            return self["numInputRows"]
+        else:
+            return sum(s.numInputRows for s in self.sources)
 
     @property
     def inputRowsPerSecond(self) -> float:
         """
         The aggregate (across all sources) rate of data arriving.
         """
-        return self._inputRowsPerSecond
+        if self["inputRowsPerSecond"] is not None:
+            return self["inputRowsPerSecond"]
+        else:
+            return sum(s.inputRowsPerSecond for s in self.sources)
 
     @property
     def processedRowsPerSecond(self) -> float:
         """
         The aggregate (across all sources) rate at which Spark is processing data.
         """
-        return self._processedRowsPerSecond
+        if self["processedRowsPerSecond"] is not None:
+            return self["processedRowsPerSecond"]
+        else:
+            return sum(s.processedRowsPerSecond for s in self.sources)
 
     @property
     def json(self) -> str:
@@ -641,14 +663,29 @@ class StreamingQueryProgress:
         else:
             return json.dumps(self._jdict, indent=4)
 
+    def __getitem__(self, key: str) -> Any:
+        # Before Spark 4.0, StreamingQuery.lastProgress returned a dict in which id and runId
+        # were cast to strings. But here they are UUIDs. To prevent a breaking change,
+        # also cast them to strings when accessed via __getitem__.
+        if key == "id" or key == "runId":
+            return str(super().__getitem__(key))
+        else:
+            return super().__getitem__(key)
+
     def __str__(self) -> str:
         return self.prettyJson
 
+    def __repr__(self) -> str:
+        return self.prettyJson
+
 
-class StateOperatorProgress:
+class StateOperatorProgress(dict):
     """
     .. versionadded:: 3.4.0
 
+    .. versionchanged:: 4.0.0
+        Becomes a subclass of dict
+
     Notes
     -----
     This API is evolving.
@@ -671,20 +708,22 @@ class StateOperatorProgress:
         jprogress: Optional["JavaObject"] = None,
         jdict: Optional[Dict[str, Any]] = None,
     ):
+        super().__init__(
+            operatorName=operatorName,
+            numRowsTotal=numRowsTotal,
+            numRowsUpdated=numRowsUpdated,
+            numRowsRemoved=numRowsRemoved,
+            allUpdatesTimeMs=allUpdatesTimeMs,
+            allRemovalsTimeMs=allRemovalsTimeMs,
+            commitTimeMs=commitTimeMs,
+            memoryUsedBytes=memoryUsedBytes,
+            numRowsDroppedByWatermark=numRowsDroppedByWatermark,
+            numShufflePartitions=numShufflePartitions,
+            numStateStoreInstances=numStateStoreInstances,
+            customMetrics=customMetrics,
+        )
         self._jprogress: Optional["JavaObject"] = jprogress
         self._jdict: Optional[Dict[str, Any]] = jdict
-        self._operatorName: str = operatorName
-        self._numRowsTotal: int = numRowsTotal
-        self._numRowsUpdated: int = numRowsUpdated
-        self._numRowsRemoved: int = numRowsRemoved
-        self._allUpdatesTimeMs: int = allUpdatesTimeMs
-        self._allRemovalsTimeMs: int = allRemovalsTimeMs
-        self._commitTimeMs: int = commitTimeMs
-        self._memoryUsedBytes: int = memoryUsedBytes
-        self._numRowsDroppedByWatermark: int = numRowsDroppedByWatermark
-        self._numShufflePartitions: int = numShufflePartitions
-        self._numStateStoreInstances: int = numStateStoreInstances
-        self._customMetrics: Dict[str, int] = customMetrics
 
     @classmethod
     def fromJObject(cls, jprogress: "JavaObject") -> "StateOperatorProgress":
@@ -724,51 +763,51 @@ class StateOperatorProgress:
 
     @property
     def operatorName(self) -> str:
-        return self._operatorName
+        return self["operatorName"]
 
     @property
     def numRowsTotal(self) -> int:
-        return self._numRowsTotal
+        return self["numRowsTotal"]
 
     @property
     def numRowsUpdated(self) -> int:
-        return self._numRowsUpdated
+        return self["numRowsUpdated"]
 
     @property
     def allUpdatesTimeMs(self) -> int:
-        return self._allUpdatesTimeMs
+        return self["allUpdatesTimeMs"]
 
     @property
     def numRowsRemoved(self) -> int:
-        return self._numRowsRemoved
+        return self["numRowsRemoved"]
 
     @property
     def allRemovalsTimeMs(self) -> int:
-        return self._allRemovalsTimeMs
+        return self["allRemovalsTimeMs"]
 
     @property
     def commitTimeMs(self) -> int:
-        return self._commitTimeMs
+        return self["commitTimeMs"]
 
     @property
     def memoryUsedBytes(self) -> int:
-        return self._memoryUsedBytes
+        return self["memoryUsedBytes"]
 
     @property
     def numRowsDroppedByWatermark(self) -> int:
-        return self._numRowsDroppedByWatermark
+        return self["numRowsDroppedByWatermark"]
 
     @property
     def numShufflePartitions(self) -> int:
-        return self._numShufflePartitions
+        return self["numShufflePartitions"]
 
     @property
     def numStateStoreInstances(self) -> int:
-        return self._numStateStoreInstances
+        return self["numStateStoreInstances"]
 
     @property
-    def customMetrics(self) -> Dict[str, int]:
-        return self._customMetrics
+    def customMetrics(self) -> dict:
+        return self["customMetrics"]
 
     @property
     def json(self) -> str:
@@ -795,11 +834,17 @@ class StateOperatorProgress:
     def __str__(self) -> str:
         return self.prettyJson
 
+    def __repr__(self) -> str:
+        return self.prettyJson
+
 
-class SourceProgress:
+class SourceProgress(dict):
     """
     .. versionadded:: 3.4.0
 
+    .. versionchanged:: 4.0.0
+        Becomes a subclass of dict
+
     Notes
     -----
     This API is evolving.
@@ -818,16 +863,18 @@ class SourceProgress:
         jprogress: Optional["JavaObject"] = None,
         jdict: Optional[Dict[str, Any]] = None,
     ) -> None:
+        super().__init__(
+            description=description,
+            startOffset=startOffset,
+            endOffset=endOffset,
+            latestOffset=latestOffset,
+            numInputRows=numInputRows,
+            inputRowsPerSecond=inputRowsPerSecond,
+            processedRowsPerSecond=processedRowsPerSecond,
+            metrics=metrics,
+        )
         self._jprogress: Optional["JavaObject"] = jprogress
         self._jdict: Optional[Dict[str, Any]] = jdict
-        self._description: str = description
-        self._startOffset: str = startOffset
-        self._endOffset: str = endOffset
-        self._latestOffset: str = latestOffset
-        self._numInputRows: int = numInputRows
-        self._inputRowsPerSecond: float = inputRowsPerSecond
-        self._processedRowsPerSecond: float = processedRowsPerSecond
-        self._metrics: Dict[str, str] = metrics
 
     @classmethod
     def fromJObject(cls, jprogress: "JavaObject") -> "SourceProgress":
@@ -862,53 +909,53 @@ class SourceProgress:
         """
         Description of the source.
         """
-        return self._description
+        return self["description"]
 
     @property
     def startOffset(self) -> str:
         """
         The starting offset for data being read.
         """
-        return self._startOffset
+        return self["startOffset"]
 
     @property
     def endOffset(self) -> str:
         """
         The ending offset for data being read.
         """
-        return self._endOffset
+        return self["endOffset"]
 
     @property
     def latestOffset(self) -> str:
         """
         The latest offset from this source.
         """
-        return self._latestOffset
+        return self["latestOffset"]
 
     @property
     def numInputRows(self) -> int:
         """
         The number of records read from this source.
         """
-        return self._numInputRows
+        return self["numInputRows"]
 
     @property
     def inputRowsPerSecond(self) -> float:
         """
         The rate at which data is arriving from this source.
         """
-        return self._inputRowsPerSecond
+        return self["inputRowsPerSecond"]
 
     @property
     def processedRowsPerSecond(self) -> float:
         """
         The rate at which data from this source is being processed by Spark.
         """
-        return self._processedRowsPerSecond
+        return self["processedRowsPerSecond"]
 
     @property
-    def metrics(self) -> Dict[str, str]:
-        return self._metrics
+    def metrics(self) -> dict:
+        return self["metrics"]
 
     @property
     def json(self) -> str:
@@ -935,11 +982,17 @@ class SourceProgress:
     def __str__(self) -> str:
         return self.prettyJson
 
+    def __repr__(self) -> str:
+        return self.prettyJson
+
 
-class SinkProgress:
+class SinkProgress(dict):
     """
     .. versionadded:: 3.4.0
 
+    .. versionchanged:: 4.0.0
+        Becomes a subclass of dict
+
     Notes
     -----
     This API is evolving.
@@ -953,11 +1006,13 @@ class SinkProgress:
         jprogress: Optional["JavaObject"] = None,
         jdict: Optional[Dict[str, Any]] = None,
     ) -> None:
+        super().__init__(
+            description=description,
+            numOutputRows=numOutputRows,
+            metrics=metrics,
+        )
         self._jprogress: Optional["JavaObject"] = jprogress
         self._jdict: Optional[Dict[str, Any]] = jdict
-        self._description: str = description
-        self._numOutputRows: int = numOutputRows
-        self._metrics: Dict[str, str] = metrics
 
     @classmethod
     def fromJObject(cls, jprogress: "JavaObject") -> "SinkProgress":
@@ -982,7 +1037,7 @@ class SinkProgress:
         """
        Description of the sink.
         """
-        return self._description
+        return self["description"]
 
     @property
     def numOutputRows(self) -> int:
@@ -990,11 +1045,11 @@ class SinkProgress:
         Number of rows written to the sink or -1 for Continuous Mode (temporarily)
         or Sink V1 (until decommissioned).
         """
-        return self._numOutputRows
+        return self["numOutputRows"]
 
     @property
     def metrics(self) -> Dict[str, str]:
-        return self._metrics
+        return self["metrics"]
 
     @property
     def json(self) -> str:
@@ -1021,6 +1076,9 @@ class SinkProgress:
     def __str__(self) -> str:
         return self.prettyJson
 
+    def __repr__(self) -> str:
+        return self.prettyJson
+
 
 def _test() -> None:
     import sys
diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py
index d3d58da3562b..916f96a5b2c2 100644
--- a/python/pyspark/sql/streaming/query.py
+++ b/python/pyspark/sql/streaming/query.py
@@ -22,7 +22,10 @@ from pyspark.errors import StreamingQueryException, PySparkValueError
 from pyspark.errors.exceptions.captured import (
     StreamingQueryException as CapturedStreamingQueryException,
 )
-from pyspark.sql.streaming.listener import StreamingQueryListener
+from pyspark.sql.streaming.listener import (
+    StreamingQueryListener,
+    StreamingQueryProgress,
+)
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
@@ -251,7 +254,7 @@ class StreamingQuery:
         return json.loads(self._jsq.status().json())
 
     @property
-    def recentProgress(self) -> List[Dict[str, Any]]:
+    def recentProgress(self) -> List[StreamingQueryProgress]:
         """
         Returns an array of the most recent [[StreamingQueryProgress]] updates for this query.
         The number of progress updates retained for each stream is configured by Spark session
@@ -280,10 +283,10 @@ class StreamingQuery:
 
         >>> sq.stop()
         """
-        return [json.loads(p.json()) for p in self._jsq.recentProgress()]
+        return [StreamingQueryProgress.fromJObject(p) for p in self._jsq.recentProgress()]
 
     @property
-    def lastProgress(self) -> Optional[Dict[str, Any]]:
+    def lastProgress(self) -> Optional[StreamingQueryProgress]:
         """
         Returns the most recent :class:`StreamingQueryProgress` update of this streaming query or
         None if there were no progress updates
@@ -311,7 +314,7 @@ class StreamingQuery:
         """
         lastProgress = self._jsq.lastProgress()
         if lastProgress:
-            return json.loads(lastProgress.json())
+            return StreamingQueryProgress.fromJObject(lastProgress)
         else:
             return None
 
diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py
index e284d052d9ae..00d1fbf53885 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming.py
@@ -29,7 +29,7 @@ from pyspark.errors import PySparkValueError
 
 class StreamingTestsMixin:
     def test_streaming_query_functions_basic(self):
-        df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
         query = (
             df.writeStream.format("memory")
             .queryName("test_streaming_query_functions_basic")
@@ -43,8 +43,8 @@ class StreamingTestsMixin:
             self.assertEqual(query.exception(), None)
             self.assertFalse(query.awaitTermination(1))
             query.processAllAvailable()
-            recentProgress = query.recentProgress
             lastProgress = query.lastProgress
+            recentProgress = query.recentProgress
             self.assertEqual(lastProgress["name"], query.name)
             self.assertEqual(lastProgress["id"], query.id)
             self.assertTrue(any(p == lastProgress for p in recentProgress))
@@ -59,6 +59,46 @@ class StreamingTestsMixin:
         finally:
             query.stop()
 
+    def test_streaming_progress(self):
+        """
+        Should be able to access fields using attributes in lastProgress / recentProgress,
+        e.g. q.lastProgress.id
+        """
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        query = df.writeStream.format("noop").start()
+        try:
+            query.processAllAvailable()
+            lastProgress = query.lastProgress
+            recentProgress = query.recentProgress
+            self.assertEqual(lastProgress["name"], query.name)
+            # Return str when accessed using dict get.
+            self.assertEqual(lastProgress["id"], query.id)
+            # SPARK-48567 Use attribute to access fields in q.lastProgress
+            self.assertEqual(lastProgress.name, query.name)
+            # Return uuid when accessed using attribute.
+            self.assertEqual(str(lastProgress.id), query.id)
+            self.assertTrue(any(p == lastProgress for p in recentProgress))
+            self.assertTrue(lastProgress.numInputRows > 0)
+            # Also access source / sink progress with attributes
+            self.assertTrue(len(lastProgress.sources) > 0)
+            self.assertTrue(lastProgress.sources[0].numInputRows > 0)
+            self.assertTrue(lastProgress["sources"][0]["numInputRows"] > 0)
+            self.assertTrue(lastProgress.sink.numOutputRows > 0)
+            self.assertTrue(lastProgress["sink"]["numOutputRows"] > 0)
+            # In Python, for historical reasons, changing field value
+            # in StreamingQueryProgress is allowed.
+            new_name = "myNewQuery"
+            lastProgress["name"] = new_name
+            self.assertEqual(lastProgress.name, new_name)
+
+        except Exception as e:
+            self.fail(
+                "Streaming query functions sanity check shouldn't throw any 
error. "
+                "Error message: " + str(e)
+            )
+        finally:
+            query.stop()
+
     def test_streaming_query_name_edge_case(self):
         # Query name should be None when not specified
        q1 = self.spark.readStream.format("rate").load().writeStream.format("noop").start()
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index 762fc335b56a..0f13450849c5 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -227,9 +227,9 @@ class StreamingListenerTestsMixin:
                 "my_event", count(lit(1)).alias("rc"), 
count(col("error")).alias("erc")
             )
 
-            q = observed_ds.writeStream.format("console").start()
+            q = observed_ds.writeStream.format("noop").start()
 
-            while q.lastProgress is None or q.lastProgress["batchId"] == 0:
+            while q.lastProgress is None or q.lastProgress.batchId == 0:
                 q.awaitTermination(0.5)
 
             time.sleep(5)
@@ -241,6 +241,32 @@ class StreamingListenerTestsMixin:
             q.stop()
             self.spark.streams.removeListener(error_listener)
 
+    def test_streaming_progress(self):
+        try:
+            # Test a fancier query with stateful operation and observed metrics
+            df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+            df_stateful = df_observe.groupBy().count()  # make query stateful
+            q = (
+                df_stateful.writeStream.format("noop")
+                .queryName("test")
+                .outputMode("update")
+                .trigger(processingTime="5 seconds")
+                .start()
+            )
+
+            while q.lastProgress is None or q.lastProgress.batchId == 0:
+                q.awaitTermination(0.5)
+
+            q.stop()
+
+            self.check_streaming_query_progress(q.lastProgress, True)
+            for p in q.recentProgress:
+                self.check_streaming_query_progress(p, True)
+
+        finally:
+            q.stop()
+
 
 class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase):
     def test_number_of_public_methods(self):
@@ -355,7 +381,7 @@ class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase):
                     .start()
                 )
                 self.assertTrue(q.isActive)
-                time.sleep(10)
+                q.awaitTermination(10)
                 q.stop()
 
                 # Make sure all events are empty

