This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 7be69bf7da03 [SPARK-44971][PYTHON] StreamingQueryProgress event fromJson bug fix
7be69bf7da03 is described below

commit 7be69bf7da036282c2c7c0b62c32e7666fa1b579
Author: Wei Liu <wei....@databricks.com>
AuthorDate: Thu Aug 31 09:28:24 2023 +0900

    [SPARK-44971][PYTHON] StreamingQueryProgress event fromJson bug fix
    
    ### What changes were proposed in this pull request?
    
    The `fromJson` method of `StreamingQueryProgress` expects the field `batchDuration` to be present in the dict.
    That method is used internally to convert a JSON representation of `StreamingQueryProgress` into a Python object; the JSON is commonly created by the Scala-side `json` method of the same object.
    
    But the `batchDuration` field does not exist before https://github.com/apache/spark/pull/42077, which is only merged to 4.0. Therefore we add a fallback there (`j.get("batchDuration", None)`) to prevent this method from failing.
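
    For illustration, a minimal sketch of the failure mode and the fix (the `progress_json` payload here is hypothetical and omits the rest of the progress schema):

    ```python
    # On branch-3.5 the Scala-side json() does not emit "batchDuration",
    # so the dict parsed from that JSON may lack the key.
    progress_json = {"batchId": 3}  # "batchDuration" is absent before 4.0

    # Old code path: direct indexing raises KeyError.
    try:
        duration = progress_json["batchDuration"]
    except KeyError:
        duration = None

    # Fixed code path: dict.get falls back to None when the key is missing.
    duration = progress_json.get("batchDuration", None)
    assert duration is None
    ```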
    
    ### Why are the changes needed?
    
    Necessary bug fix
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added unit test
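
    As a sketch, the parity test round-trips each listener event through `pyspark.cloudpickle` (serialized in the listener callback, written to a table, then read back for assertions); the dict below is a stand-in for a real event object:

    ```python
    import pyspark.cloudpickle

    # Serialize a picklable stand-in event to bytes, as the listener does...
    payload = pyspark.cloudpickle.dumps({"batchId": 3, "batchDuration": None})

    # ...then deserialize it for assertions, as the test body does.
    event = pyspark.cloudpickle.loads(payload)
    assert event["batchDuration"] is None
    ```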
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #42686 from WweiL/SPARK-44971-fromJson-bugfix.
    
    Lead-authored-by: Wei Liu <wei....@databricks.com>
    Co-authored-by: Wei Liu <z920631...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/streaming/listener.py           |  2 +-
 .../connect/streaming/test_parity_listener.py      | 57 +++++++++++++---------
 .../sql/tests/streaming/test_streaming_listener.py |  2 +-
 3 files changed, 36 insertions(+), 25 deletions(-)

diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py
index 225ad6d45afb..16f40396490c 100644
--- a/python/pyspark/sql/streaming/listener.py
+++ b/python/pyspark/sql/streaming/listener.py
@@ -477,7 +477,7 @@ class StreamingQueryProgress:
             name=j["name"],
             timestamp=j["timestamp"],
             batchId=j["batchId"],
-            batchDuration=j["batchDuration"],
+            batchDuration=j.get("batchDuration", None),
             durationMs=dict(j["durationMs"]) if "durationMs" in j else {},
             eventTime=dict(j["eventTime"]) if "eventTime" in j else {},
             stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]],
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 4bf58bf7807b..5069a76cfdb7 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -18,39 +18,31 @@
 import unittest
 import time
 
+import pyspark.cloudpickle
 from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
-from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent
-from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.sql.streaming.listener import StreamingQueryListener
+from pyspark.sql.functions import count, lit
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-def get_start_event_schema():
-    return StructType(
-        [
-            StructField("id", StringType(), True),
-            StructField("runId", StringType(), True),
-            StructField("name", StringType(), True),
-            StructField("timestamp", StringType(), True),
-        ]
-    )
-
-
 class TestListener(StreamingQueryListener):
     def onQueryStarted(self, event):
-        df = self.spark.createDataFrame(
-            data=[(str(event.id), str(event.runId), event.name, event.timestamp)],
-            schema=get_start_event_schema(),
-        )
-        df.write.saveAsTable("listener_start_events")
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_start_events")
 
     def onQueryProgress(self, event):
-        pass
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_progress_events")
 
     def onQueryIdle(self, event):
         pass
 
     def onQueryTerminated(self, event):
-        pass
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_terminated_events")
 
 
 class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
@@ -65,17 +57,36 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
             time.sleep(30)
 
             df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-            q = df.writeStream.format("noop").queryName("test").start()
+            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+            df_stateful = df_observe.groupBy().count()  # make query stateful
+            q = (
+                df_stateful.writeStream.format("noop")
+                .queryName("test")
+                .outputMode("complete")
+                .start()
+            )
 
             self.assertTrue(q.isActive)
             time.sleep(10)
+            self.assertTrue(q.lastProgress["batchId"] > 0)  # ensure at least one batch has run
             q.stop()
+            self.assertFalse(q.isActive)
+
+            start_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_start_events").collect()[0][0]
+            )
+
+            progress_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_progress_events").collect()[0][0]
+            )
 
-            start_event = QueryStartedEvent.fromJson(
-                self.spark.read.table("listener_start_events").collect()[0].asDict()
+            terminated_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_terminated_events").collect()[0][0]
             )
 
             self.check_start_event(start_event)
+            self.check_progress_event(progress_event)
+            self.check_terminated_event(terminated_event)
 
         finally:
             self.spark.streams.removeListener(test_listener)
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index cbbdc2955e59..87d0dae00d8b 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -88,7 +88,7 @@ class StreamingListenerTestsMixin:
         except Exception:
             self.fail("'%s' is not in ISO 8601 format.")
         self.assertTrue(isinstance(progress.batchId, int))
-        self.assertTrue(isinstance(progress.batchDuration, int))
+        self.assertTrue(progress.batchDuration is None or isinstance(progress.batchDuration, int))
         self.assertTrue(isinstance(progress.durationMs, dict))
         self.assertTrue(
             set(progress.durationMs.keys()).issubset(

