This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 7be69bf7da03 [SPARK-44971][PYTHON] StreamingQueryProgress event fromJson bug fix 7be69bf7da03 is described below commit 7be69bf7da036282c2c7c0b62c32e7666fa1b579 Author: Wei Liu <wei....@databricks.com> AuthorDate: Thu Aug 31 09:28:24 2023 +0900 [SPARK-44971][PYTHON] StreamingQueryProgress event fromJson bug fix ### What changes were proposed in this pull request? The `fromJson` method for `StreamingQueryProgress` expects the field `batchDuration` to be in the dict. That method is used internally for converting a json representation of `StreamingQueryProgress` into a Python object, commonly created in the Scala side `json` method of the same object. But the `batchDuration` field is not there before https://github.com/apache/spark/pull/42077, which is only merged to 4.0. Therefore we add a catch there to prevent this method from failing. ### Why are the changes needed? Necessary bug fix ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added unit test ### Was this patch authored or co-authored using generative AI tooling? No Closes #42686 from WweiL/SPARK-44971-fromJson-bugfix. 
Lead-authored-by: Wei Liu <wei....@databricks.com> Co-authored-by: Wei Liu <z920631...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/streaming/listener.py | 2 +- .../connect/streaming/test_parity_listener.py | 57 +++++++++++++--------- .../sql/tests/streaming/test_streaming_listener.py | 2 +- 3 files changed, 36 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 225ad6d45afb..16f40396490c 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -477,7 +477,7 @@ class StreamingQueryProgress: name=j["name"], timestamp=j["timestamp"], batchId=j["batchId"], - batchDuration=j["batchDuration"], + batchDuration=j.get("batchDuration", None), durationMs=dict(j["durationMs"]) if "durationMs" in j else {}, eventTime=dict(j["eventTime"]) if "eventTime" in j else {}, stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]], diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 4bf58bf7807b..5069a76cfdb7 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -18,39 +18,31 @@ import unittest import time +import pyspark.cloudpickle from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin -from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent -from pyspark.sql.types import StructType, StructField, StringType +from pyspark.sql.streaming.listener import StreamingQueryListener +from pyspark.sql.functions import count, lit from pyspark.testing.connectutils import ReusedConnectTestCase -def get_start_event_schema(): - return StructType( - [ - StructField("id", StringType(), True), - StructField("runId", StringType(), True), - 
StructField("name", StringType(), True), - StructField("timestamp", StringType(), True), - ] - ) - - class TestListener(StreamingQueryListener): def onQueryStarted(self, event): - df = self.spark.createDataFrame( - data=[(str(event.id), str(event.runId), event.name, event.timestamp)], - schema=get_start_event_schema(), - ) - df.write.saveAsTable("listener_start_events") + e = pyspark.cloudpickle.dumps(event) + df = self.spark.createDataFrame(data=[(e,)]) + df.write.mode("append").saveAsTable("listener_start_events") def onQueryProgress(self, event): - pass + e = pyspark.cloudpickle.dumps(event) + df = self.spark.createDataFrame(data=[(e,)]) + df.write.mode("append").saveAsTable("listener_progress_events") def onQueryIdle(self, event): pass def onQueryTerminated(self, event): - pass + e = pyspark.cloudpickle.dumps(event) + df = self.spark.createDataFrame(data=[(e,)]) + df.write.mode("append").saveAsTable("listener_terminated_events") class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase): @@ -65,17 +57,36 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes time.sleep(30) df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() - q = df.writeStream.format("noop").queryName("test").start() + df_observe = df.observe("my_event", count(lit(1)).alias("rc")) + df_stateful = df_observe.groupBy().count() # make query stateful + q = ( + df_stateful.writeStream.format("noop") + .queryName("test") + .outputMode("complete") + .start() + ) self.assertTrue(q.isActive) time.sleep(10) + self.assertTrue(q.lastProgress["batchId"] > 0) # ensure at least one batch is ran q.stop() + self.assertFalse(q.isActive) + + start_event = pyspark.cloudpickle.loads( + self.spark.read.table("listener_start_events").collect()[0][0] + ) + + progress_event = pyspark.cloudpickle.loads( + self.spark.read.table("listener_progress_events").collect()[0][0] + ) - start_event = QueryStartedEvent.fromJson( - 
self.spark.read.table("listener_start_events").collect()[0].asDict() + terminated_event = pyspark.cloudpickle.loads( + self.spark.read.table("listener_terminated_events").collect()[0][0] ) self.check_start_event(start_event) + self.check_progress_event(progress_event) + self.check_terminated_event(terminated_event) finally: self.spark.streams.removeListener(test_listener) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index cbbdc2955e59..87d0dae00d8b 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -88,7 +88,7 @@ class StreamingListenerTestsMixin: except Exception: self.fail("'%s' is not in ISO 8601 format.") self.assertTrue(isinstance(progress.batchId, int)) - self.assertTrue(isinstance(progress.batchDuration, int)) + self.assertTrue(progress.batchDuration is None or isinstance(progress.batchDuration, int)) self.assertTrue(isinstance(progress.durationMs, dict)) self.assertTrue( set(progress.durationMs.keys()).issubset( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org