This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a58362ecbbf7 [SPARK-46722][CONNECT] Add a test regarding backward compatibility check for StreamingQueryListener in Spark Connect (Scala/PySpark)
a58362ecbbf7 is described below
commit a58362ecbbf7c4e5d5f848411834cf2a9ef298b3
Author: Jungtaek Lim <[email protected]>
AuthorDate: Tue Jan 16 12:23:02 2024 +0900
[SPARK-46722][CONNECT] Add a test regarding backward compatibility check for StreamingQueryListener in Spark Connect (Scala/PySpark)
### What changes were proposed in this pull request?
This PR adds a backward compatibility test for StreamingQueryListener in Spark Connect (both Scala and PySpark), specifically covering listener implementations that do and do not implement onQueryIdle.
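To illustrate the two listener shapes the test exercises, here is a minimal sketch (class names are hypothetical; this relies on `onQueryIdle` having a default no-op body in the base class, which is what keeps pre-3.5 listeners source-compatible):

```scala
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

// Pre-3.5 ("V1") shape: no onQueryIdle override. It must keep working
// unchanged because the base class supplies a no-op default for onQueryIdle.
class LegacyListener extends StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit = {}
  override def onQueryProgress(event: QueryProgressEvent): Unit = {}
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {}
}

// 3.5+ ("V2") shape: additionally overrides onQueryIdle.
class CurrentListener extends StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit = {}
  override def onQueryProgress(event: QueryProgressEvent): Unit = {}
  override def onQueryIdle(event: QueryIdleEvent): Unit = {}
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {}
}
```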
### Why are the changes needed?
We missed adding a backward compatibility test when introducing onQueryIdle, which led to an issue in PySpark (https://issues.apache.org/jira/browse/SPARK-45631). We added the compatibility test in PySpark but did not add it for Spark Connect.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Modified UTs.
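For reference, something like the following should run the touched suites locally (the sbt module name is assumed; `python/run-tests --testnames` is the usual PySpark test entry point):

```
build/sbt "connect-client-jvm/testOnly org.apache.spark.sql.streaming.ClientStreamingQuerySuite"
python/run-tests --testnames pyspark.sql.tests.connect.streaming.test_parity_listener
```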
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44736 from HeartSaVioR/SPARK-46722.
Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../sql/streaming/ClientStreamingQuerySuite.scala | 88 ++++++++++----
.../connect/streaming/test_parity_listener.py | 133 +++++++++++++--------
2 files changed, 142 insertions(+), 79 deletions(-)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index 91c562c0f98b..fd989b5da35c 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.streaming
import java.io.{File, FileWriter}
import java.util.concurrent.TimeUnit
-import scala.collection.mutable
import scala.jdk.CollectionConverters._
import org.scalatest.concurrent.Eventually.eventually
@@ -32,7 +31,7 @@ import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
import org.apache.spark.sql.functions.{col, udf, window}
-import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryStartedEvent, QueryTerminatedEvent}
+import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
import org.apache.spark.sql.test.{QueryTest, SQLHelper}
import org.apache.spark.util.SparkFileUtils
@@ -354,9 +353,15 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
}
test("streaming query listener") {
+ testStreamingQueryListener(new EventCollectorV1, "_v1")
+ testStreamingQueryListener(new EventCollectorV2, "_v2")
+ }
+
+ private def testStreamingQueryListener(
+ listener: StreamingQueryListener,
+ tablePostfix: String): Unit = {
assert(spark.streams.listListeners().length == 0)
- val listener = new EventCollector
spark.streams.addListener(listener)
val q = spark.readStream
@@ -370,11 +375,21 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
q.processAllAvailable()
eventually(timeout(30.seconds)) {
assert(q.isActive)
- checkAnswer(spark.table("my_listener_table").toDF(), Seq(Row(1, 2), Row(4, 5)))
+
+ assert(!spark.table(s"listener_start_events$tablePostfix").toDF().isEmpty)
+ assert(!spark.table(s"listener_progress_events$tablePostfix").toDF().isEmpty)
}
} finally {
q.stop()
- spark.sql("DROP TABLE IF EXISTS my_listener_table")
+
+ eventually(timeout(30.seconds)) {
+ assert(!q.isActive)
+ assert(!spark.table(s"listener_terminated_events$tablePostfix").toDF().isEmpty)
+ }
+
+ spark.sql(s"DROP TABLE IF EXISTS listener_start_events$tablePostfix")
+ spark.sql(s"DROP TABLE IF EXISTS listener_progress_events$tablePostfix")
+ spark.sql(s"DROP TABLE IF EXISTS
listener_terminated_events$tablePostfix")
}
// List listeners after adding a new listener, length should be 1.
@@ -382,7 +397,7 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
assert(listeners.length == 1)
// Add listener1 as another instance of EventCollector and validate
- val listener1 = new EventCollector
+ val listener1 = new EventCollectorV2
spark.streams.addListener(listener1)
assert(spark.streams.listListeners().length == 2)
spark.streams.removeListener(listener1)
@@ -462,35 +477,56 @@ case class TestClass(value: Int) {
override def toString: String = value.toString
}
-class EventCollector extends StreamingQueryListener {
- @volatile var startEvent: QueryStartedEvent = null
- @volatile var terminationEvent: QueryTerminatedEvent = null
- @volatile var idleEvent: QueryIdleEvent = null
+abstract class EventCollector extends StreamingQueryListener {
+ private lazy val spark = SparkSession.builder().getOrCreate()
- private val _progressEvents = new mutable.Queue[StreamingQueryProgress]
+ protected def tablePostfix: String
- def progressEvents: Seq[StreamingQueryProgress] = _progressEvents.synchronized {
- _progressEvents.clone().toSeq
+ protected def handleOnQueryStarted(event: QueryStartedEvent): Unit = {
+ val df = spark.createDataFrame(Seq((event.json, 0)))
+ df.write.mode("append").saveAsTable(s"listener_start_events$tablePostfix")
}
- override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
- startEvent = event
- val spark = SparkSession.builder().getOrCreate()
- val df = spark.createDataFrame(Seq((1, 2), (4, 5)))
- df.write.saveAsTable("my_listener_table")
+ protected def handleOnQueryProgress(event: QueryProgressEvent): Unit = {
+ val df = spark.createDataFrame(Seq((event.json, 0)))
+ df.write.mode("append").saveAsTable(s"listener_progress_events$tablePostfix")
}
- override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
- _progressEvents += event.progress
+ protected def handleOnQueryTerminated(event: QueryTerminatedEvent): Unit = {
+ val df = spark.createDataFrame(Seq((event.json, 0)))
+ df.write.mode("append").saveAsTable(s"listener_terminated_events$tablePostfix")
}
+}
- override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
- idleEvent = event
- }
+/**
+ * V1: Initial interface of StreamingQueryListener containing methods
`onQueryStarted`,
+ * `onQueryProgress`, `onQueryTerminated`. It is prior to Spark 3.5.
+ */
+class EventCollectorV1 extends EventCollector {
+ override protected def tablePostfix: String = "_v1"
- override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
- terminationEvent = event
- }
+ override def onQueryStarted(event: QueryStartedEvent): Unit = handleOnQueryStarted(event)
+
+ override def onQueryProgress(event: QueryProgressEvent): Unit = handleOnQueryProgress(event)
+
+ override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
+ handleOnQueryTerminated(event)
+}
+
+/**
+ * V2: The interface after the method `onQueryIdle` is added. It is Spark 3.5+.
+ */
+class EventCollectorV2 extends EventCollector {
+ override protected def tablePostfix: String = "_v2"
+
+ override def onQueryStarted(event: QueryStartedEvent): Unit = handleOnQueryStarted(event)
+
+ override def onQueryProgress(event: QueryProgressEvent): Unit = handleOnQueryProgress(event)
+
+ override def onQueryIdle(event: QueryIdleEvent): Unit = {}
+
+ override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
+ handleOnQueryTerminated(event)
}
class ForeachBatchFn(val viewName: String)
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 4fc040642bed..412f49a3960b 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -26,16 +26,36 @@ from pyspark.sql.functions import count, lit
from pyspark.testing.connectutils import ReusedConnectTestCase
-class TestListener(StreamingQueryListener):
+# V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`,
+# `onQueryProgress`, `onQueryTerminated`. It is prior to Spark 3.5.
+class TestListenerV1(StreamingQueryListener):
def onQueryStarted(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
- df.write.mode("append").saveAsTable("listener_start_events")
+ df.write.mode("append").saveAsTable("listener_start_events_v1")
def onQueryProgress(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
- df.write.mode("append").saveAsTable("listener_progress_events")
+ df.write.mode("append").saveAsTable("listener_progress_events_v1")
+
+ def onQueryTerminated(self, event):
+ e = pyspark.cloudpickle.dumps(event)
+ df = self.spark.createDataFrame(data=[(e,)])
+ df.write.mode("append").saveAsTable("listener_terminated_events_v1")
+
+
+# V2: The interface after the method `onQueryIdle` is added. It is Spark 3.5+.
+class TestListenerV2(StreamingQueryListener):
+ def onQueryStarted(self, event):
+ e = pyspark.cloudpickle.dumps(event)
+ df = self.spark.createDataFrame(data=[(e,)])
+ df.write.mode("append").saveAsTable("listener_start_events_v2")
+
+ def onQueryProgress(self, event):
+ e = pyspark.cloudpickle.dumps(event)
+ df = self.spark.createDataFrame(data=[(e,)])
+ df.write.mode("append").saveAsTable("listener_progress_events_v2")
def onQueryIdle(self, event):
pass
@@ -43,60 +63,67 @@ class TestListener(StreamingQueryListener):
def onQueryTerminated(self, event):
e = pyspark.cloudpickle.dumps(event)
df = self.spark.createDataFrame(data=[(e,)])
- df.write.mode("append").saveAsTable("listener_terminated_events")
+ df.write.mode("append").saveAsTable("listener_terminated_events_v2")
class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
def test_listener_events(self):
- test_listener = TestListener()
-
- try:
- self.spark.streams.addListener(test_listener)
-
- # This ensures the read socket on the server won't crash (i.e. because of timeout)
- # when there hasn't been a new event for a long time
- time.sleep(30)
-
- df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
- df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
- df_stateful = df_observe.groupBy().count() # make query stateful
- q = (
- df_stateful.writeStream.format("noop")
- .queryName("test")
- .outputMode("complete")
- .start()
- )
-
- self.assertTrue(q.isActive)
- # ensure at least one batch is ran
- while q.lastProgress is None or q.lastProgress["batchId"] == 0:
- time.sleep(5)
- q.stop()
- self.assertFalse(q.isActive)
-
- time.sleep(60) # Sleep to make sure listener_terminated_events is written successfully
-
- start_event = pyspark.cloudpickle.loads(
- self.spark.read.table("listener_start_events").collect()[0][0]
- )
-
- progress_event = pyspark.cloudpickle.loads(
- self.spark.read.table("listener_progress_events").collect()[0][0]
- )
-
- terminated_event = pyspark.cloudpickle.loads(
- self.spark.read.table("listener_terminated_events").collect()[0][0]
- )
-
- self.check_start_event(start_event)
- self.check_progress_event(progress_event)
- self.check_terminated_event(terminated_event)
-
- finally:
- self.spark.streams.removeListener(test_listener)
-
- # Remove again to verify this won't throw any error
- self.spark.streams.removeListener(test_listener)
+ def verify(test_listener, table_postfix):
+ try:
+ self.spark.streams.addListener(test_listener)
+
+ # This ensures the read socket on the server won't crash (i.e. because of timeout)
+ # when there hasn't been a new event for a long time
+ time.sleep(30)
+
+ df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+ df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+ df_stateful = df_observe.groupBy().count() # make query stateful
+ q = (
+ df_stateful.writeStream.format("noop")
+ .queryName("test")
+ .outputMode("complete")
+ .start()
+ )
+
+ self.assertTrue(q.isActive)
+ # ensure at least one batch is ran
+ while q.lastProgress is None or q.lastProgress["batchId"] == 0:
+ time.sleep(5)
+ q.stop()
+ self.assertFalse(q.isActive)
+
+ # Sleep to make sure listener_terminated_events is written successfully
+ time.sleep(60)
+
+ start_table_name = "listener_start_events" + table_postfix
+ progress_tbl_name = "listener_progress_events" + table_postfix
+ terminated_tbl_name = "listener_terminated_events" + table_postfix
+
+ start_event = pyspark.cloudpickle.loads(
+ self.spark.read.table(start_table_name).collect()[0][0]
+ )
+
+ progress_event = pyspark.cloudpickle.loads(
+ self.spark.read.table(progress_tbl_name).collect()[0][0]
+ )
+
+ terminated_event = pyspark.cloudpickle.loads(
+ self.spark.read.table(terminated_tbl_name).collect()[0][0]
+ )
+
+ self.check_start_event(start_event)
+ self.check_progress_event(progress_event)
+ self.check_terminated_event(terminated_event)
+
+ finally:
+ self.spark.streams.removeListener(test_listener)
+
+ # Remove again to verify this won't throw any error
+ self.spark.streams.removeListener(test_listener)
+
+ verify(TestListenerV1(), "_v1")
+ verify(TestListenerV2(), "_v2")
def test_accessing_spark_session(self):
spark = self.spark
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]