This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new 7be69bf7da03 [SPARK-44971][PYTHON] StreamingQueryProgress event fromJson bug fix
7be69bf7da03 is described below
commit 7be69bf7da036282c2c7c0b62c32e7666fa1b579
Author: Wei Liu <[email protected]>
AuthorDate: Thu Aug 31 09:28:24 2023 +0900
[SPARK-44971][PYTHON] StreamingQueryProgress event fromJson bug fix
### What changes were proposed in this pull request?
The `fromJson` method of `StreamingQueryProgress` expects the field
`batchDuration` to be present in the dict.
That method is used internally to convert a JSON representation of
`StreamingQueryProgress` into a Python object; the JSON is commonly created
by the `json` method of the same object on the Scala side.
But the `batchDuration` field does not exist before
https://github.com/apache/spark/pull/42077, which was only merged to 4.0.
Therefore we fall back to `None` when the key is missing, to prevent this
method from failing.
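As a minimal sketch of the lookup pattern before and after the fix (the
payload below is a hypothetical, abbreviated example, not real Spark output):

```python
# Hypothetical progress payload as serialized by the Scala-side `json`
# method on branch-3.5: note there is no "batchDuration" key.
j = {"id": "a-uuid", "runId": "a-uuid", "name": None, "batchId": 3}

# Before the fix, a plain dict lookup fails on such payloads:
# j["batchDuration"]  -> KeyError: 'batchDuration'

# After the fix, dict.get falls back to None when the key is absent:
batch_duration = j.get("batchDuration", None)
assert batch_duration is None
```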
### Why are the changes needed?
Necessary bug fix
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit test
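As a rough sketch of the round-trip the reworked connect parity test relies
on (the event class here is a hypothetical stand-in, not the real listener
event): each listener callback cloudpickles the event and appends the bytes
to a table, and the test later reads them back and unpickles the event
before asserting on it.

```python
import pyspark.cloudpickle


class FakeEvent:
    """Hypothetical stand-in for a StreamingQueryListener event."""

    def __init__(self, batch_id):
        self.batchId = batch_id


# Inside a listener callback: serialize the event to bytes (these would be
# appended to a table such as "listener_progress_events").
blob = pyspark.cloudpickle.dumps(FakeEvent(batch_id=3))

# Inside the test, after reading the bytes back from the table: deserialize
# and assert on the reconstructed event.
event = pyspark.cloudpickle.loads(blob)
assert event.batchId == 3
```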
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #42686 from WweiL/SPARK-44971-fromJson-bugfix.
Lead-authored-by: Wei Liu <[email protected]>
Co-authored-by: Wei Liu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/streaming/listener.py | 2 +-
.../connect/streaming/test_parity_listener.py | 57 +++++++++++++---------
.../sql/tests/streaming/test_streaming_listener.py | 2 +-
3 files changed, 36 insertions(+), 25 deletions(-)
diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py
index 225ad6d45afb..16f40396490c 100644
--- a/python/pyspark/sql/streaming/listener.py
+++ b/python/pyspark/sql/streaming/listener.py
@@ -477,7 +477,7 @@ class StreamingQueryProgress:
name=j["name"],
timestamp=j["timestamp"],
batchId=j["batchId"],
- batchDuration=j["batchDuration"],
+ batchDuration=j.get("batchDuration", None),
durationMs=dict(j["durationMs"]) if "durationMs" in j else {},
eventTime=dict(j["eventTime"]) if "eventTime" in j else {},
stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]],
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 4bf58bf7807b..5069a76cfdb7 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -18,39 +18,31 @@
import unittest
import time
+import pyspark.cloudpickle
from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
-from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent
-from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.sql.streaming.listener import StreamingQueryListener
+from pyspark.sql.functions import count, lit
from pyspark.testing.connectutils import ReusedConnectTestCase
-def get_start_event_schema():
- return StructType(
- [
- StructField("id", StringType(), True),
- StructField("runId", StringType(), True),
- StructField("name", StringType(), True),
- StructField("timestamp", StringType(), True),
- ]
- )
-
-
class TestListener(StreamingQueryListener):
def onQueryStarted(self, event):
- df = self.spark.createDataFrame(
- data=[(str(event.id), str(event.runId), event.name, event.timestamp)],
- schema=get_start_event_schema(),
- )
- df.write.saveAsTable("listener_start_events")
+ e = pyspark.cloudpickle.dumps(event)
+ df = self.spark.createDataFrame(data=[(e,)])
+ df.write.mode("append").saveAsTable("listener_start_events")
def onQueryProgress(self, event):
- pass
+ e = pyspark.cloudpickle.dumps(event)
+ df = self.spark.createDataFrame(data=[(e,)])
+ df.write.mode("append").saveAsTable("listener_progress_events")
def onQueryIdle(self, event):
pass
def onQueryTerminated(self, event):
- pass
+ e = pyspark.cloudpickle.dumps(event)
+ df = self.spark.createDataFrame(data=[(e,)])
+ df.write.mode("append").saveAsTable("listener_terminated_events")
class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
@@ -65,17 +57,36 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
time.sleep(30)
df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
- q = df.writeStream.format("noop").queryName("test").start()
+ df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+ df_stateful = df_observe.groupBy().count() # make query stateful
+ q = (
+ df_stateful.writeStream.format("noop")
+ .queryName("test")
+ .outputMode("complete")
+ .start()
+ )
self.assertTrue(q.isActive)
time.sleep(10)
+ self.assertTrue(q.lastProgress["batchId"] > 0)  # ensure at least one batch is run
q.stop()
+ self.assertFalse(q.isActive)
+
+ start_event = pyspark.cloudpickle.loads(
+ self.spark.read.table("listener_start_events").collect()[0][0]
+ )
+
+ progress_event = pyspark.cloudpickle.loads(
+ self.spark.read.table("listener_progress_events").collect()[0][0]
+ )
- start_event = QueryStartedEvent.fromJson(
- self.spark.read.table("listener_start_events").collect()[0].asDict()
+ terminated_event = pyspark.cloudpickle.loads(
+ self.spark.read.table("listener_terminated_events").collect()[0][0]
)
self.check_start_event(start_event)
+ self.check_progress_event(progress_event)
+ self.check_terminated_event(terminated_event)
finally:
self.spark.streams.removeListener(test_listener)
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index cbbdc2955e59..87d0dae00d8b 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -88,7 +88,7 @@ class StreamingListenerTestsMixin:
except Exception:
self.fail("'%s' is not in ISO 8601 format.")
self.assertTrue(isinstance(progress.batchId, int))
- self.assertTrue(isinstance(progress.batchDuration, int))
+ self.assertTrue(progress.batchDuration is None or isinstance(progress.batchDuration, int))
self.assertTrue(isinstance(progress.durationMs, dict))
self.assertTrue(
set(progress.durationMs.keys()).issubset(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]