This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new a58362ecbbf7 [SPARK-46722][CONNECT] Add a test regarding to backward compatibility check for StreamingQueryListener in Spark Connect (Scala/PySpark)
a58362ecbbf7 is described below

commit a58362ecbbf7c4e5d5f848411834cf2a9ef298b3
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Tue Jan 16 12:23:02 2024 +0900

    [SPARK-46722][CONNECT] Add a test regarding to backward compatibility check for StreamingQueryListener in Spark Connect (Scala/PySpark)

    ### What changes were proposed in this pull request?

    This PR adds a backward compatibility check for StreamingQueryListener in Spark Connect (both Scala and PySpark), specifically covering listeners that do and do not implement onQueryIdle.

    ### Why are the changes needed?

    We missed adding a backward compatibility test when introducing onQueryIdle, which led to an issue in PySpark (https://issues.apache.org/jira/browse/SPARK-45631). We added the compatibility test in PySpark but did not add it in Spark Connect.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Modified UTs.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #44736 from HeartSaVioR/SPARK-46722.

    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../sql/streaming/ClientStreamingQuerySuite.scala |  88 ++++++++++----
 .../connect/streaming/test_parity_listener.py     | 133 +++++++++++++--------
 2 files changed, 142 insertions(+), 79 deletions(-)
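For context on the compatibility concern: a listener written against the pre-3.5 interface defines only onQueryStarted, onQueryProgress, and onQueryTerminated, so whatever delivers events must treat onQueryIdle as optional. The snippet below is a plain-Python sketch of that failure mode and a defensive dispatch (hypothetical names, not Spark's actual dispatch code):

    class LegacyListener:
        # Pre-3.5 shape: onQueryIdle was never part of this interface.
        def onQueryStarted(self, event): ...
        def onQueryProgress(self, event): ...
        def onQueryTerminated(self, event): ...

    def dispatch_idle(listener, event):
        # A naive `listener.onQueryIdle(event)` raises AttributeError for
        # LegacyListener; fall back to a no-op when the method is absent.
        handler = getattr(listener, "onQueryIdle", None)
        if callable(handler):
            handler(event)

    dispatch_idle(LegacyListener(), event=None)  # safe: silently skipped

The V1/V2 listener pairs in the diffs below exercise exactly these two interface shapes end to end through the Spark Connect client.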
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index 91c562c0f98b..fd989b5da35c 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.streaming
 import java.io.{File, FileWriter}
 import java.util.concurrent.TimeUnit
 
-import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
 import org.scalatest.concurrent.Eventually.eventually
@@ -32,7 +31,7 @@ import org.apache.spark.api.java.function.VoidFunction2
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
 import org.apache.spark.sql.functions.{col, udf, window}
-import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryStartedEvent, QueryTerminatedEvent}
+import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
 import org.apache.spark.sql.test.{QueryTest, SQLHelper}
 import org.apache.spark.util.SparkFileUtils
@@ -354,9 +353,15 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
   }
 
   test("streaming query listener") {
+    testStreamingQueryListener(new EventCollectorV1, "_v1")
+    testStreamingQueryListener(new EventCollectorV2, "_v2")
+  }
+
+  private def testStreamingQueryListener(
+      listener: StreamingQueryListener,
+      tablePostfix: String): Unit = {
     assert(spark.streams.listListeners().length == 0)
 
-    val listener = new EventCollector
     spark.streams.addListener(listener)
 
     val q = spark.readStream
@@ -370,11 +375,21 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
       q.processAllAvailable()
       eventually(timeout(30.seconds)) {
         assert(q.isActive)
-        checkAnswer(spark.table("my_listener_table").toDF(), Seq(Row(1, 2), Row(4, 5)))
+
+        assert(!spark.table(s"listener_start_events$tablePostfix").toDF().isEmpty)
+        assert(!spark.table(s"listener_progress_events$tablePostfix").toDF().isEmpty)
       }
     } finally {
       q.stop()
-      spark.sql("DROP TABLE IF EXISTS my_listener_table")
+
+      eventually(timeout(30.seconds)) {
+        assert(!q.isActive)
+        assert(!spark.table(s"listener_terminated_events$tablePostfix").toDF().isEmpty)
+      }
+
+      spark.sql(s"DROP TABLE IF EXISTS listener_start_events$tablePostfix")
+      spark.sql(s"DROP TABLE IF EXISTS listener_progress_events$tablePostfix")
+      spark.sql(s"DROP TABLE IF EXISTS listener_terminated_events$tablePostfix")
     }
 
     // List listeners after adding a new listener, length should be 1.
@@ -382,7 +397,7 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
     assert(listeners.length == 1)
 
     // Add listener1 as another instance of EventCollector and validate
-    val listener1 = new EventCollector
+    val listener1 = new EventCollectorV2
     spark.streams.addListener(listener1)
     assert(spark.streams.listListeners().length == 2)
     spark.streams.removeListener(listener1)
@@ -462,35 +477,56 @@ case class TestClass(value: Int) {
   override def toString: String = value.toString
 }
 
-class EventCollector extends StreamingQueryListener {
-  @volatile var startEvent: QueryStartedEvent = null
-  @volatile var terminationEvent: QueryTerminatedEvent = null
-  @volatile var idleEvent: QueryIdleEvent = null
+abstract class EventCollector extends StreamingQueryListener {
+  private lazy val spark = SparkSession.builder().getOrCreate()
 
-  private val _progressEvents = new mutable.Queue[StreamingQueryProgress]
+  protected def tablePostfix: String
 
-  def progressEvents: Seq[StreamingQueryProgress] = _progressEvents.synchronized {
-    _progressEvents.clone().toSeq
+  protected def handleOnQueryStarted(event: QueryStartedEvent): Unit = {
+    val df = spark.createDataFrame(Seq((event.json, 0)))
+    df.write.mode("append").saveAsTable(s"listener_start_events$tablePostfix")
   }
 
-  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
-    startEvent = event
-    val spark = SparkSession.builder().getOrCreate()
-    val df = spark.createDataFrame(Seq((1, 2), (4, 5)))
-    df.write.saveAsTable("my_listener_table")
+  protected def handleOnQueryProgress(event: QueryProgressEvent): Unit = {
+    val df = spark.createDataFrame(Seq((event.json, 0)))
+    df.write.mode("append").saveAsTable(s"listener_progress_events$tablePostfix")
   }
 
-  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
-    _progressEvents += event.progress
+  protected def handleOnQueryTerminated(event: QueryTerminatedEvent): Unit = {
+    val df = spark.createDataFrame(Seq((event.json, 0)))
+    df.write.mode("append").saveAsTable(s"listener_terminated_events$tablePostfix")
   }
+}
 
-  override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
-    idleEvent = event
-  }
+/**
+ * V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`,
+ * `onQueryProgress`, `onQueryTerminated`. It is prior to Spark 3.5.
+ */
+class EventCollectorV1 extends EventCollector {
+  override protected def tablePostfix: String = "_v1"
 
-  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
-    terminationEvent = event
-  }
+  override def onQueryStarted(event: QueryStartedEvent): Unit = handleOnQueryStarted(event)
+
+  override def onQueryProgress(event: QueryProgressEvent): Unit = handleOnQueryProgress(event)
+
+  override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
+    handleOnQueryTerminated(event)
+}
+
+/**
+ * V2: The interface after the method `onQueryIdle` is added. It is Spark 3.5+.
+ */
+class EventCollectorV2 extends EventCollector {
+  override protected def tablePostfix: String = "_v2"
+
+  override def onQueryStarted(event: QueryStartedEvent): Unit = handleOnQueryStarted(event)
+
+  override def onQueryProgress(event: QueryProgressEvent): Unit = handleOnQueryProgress(event)
+
+  override def onQueryIdle(event: QueryIdleEvent): Unit = {}
+
+  override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
+    handleOnQueryTerminated(event)
 }
 
 class ForeachBatchFn(val viewName: String)
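The same recording pattern carries over to the PySpark parity test below. Because the listener callbacks do not run inside the test process (note how EventCollector above obtains its own SparkSession via builder().getOrCreate() within each callback), both suites persist events by appending rows to per-version tables and assert on those tables from the client. Outside of CI, the round trip can be checked by hand; the sketch below is illustrative only (it assumes a local Spark Connect server on the default sc://localhost:15002 endpoint and uses a made-up table name) and is not part of the patch:

    import time

    from pyspark.sql import SparkSession
    from pyspark.sql.streaming import StreamingQueryListener

    class LegacyStyleListener(StreamingQueryListener):
        # Pre-3.5 shape: onQueryIdle is intentionally not implemented.
        def onQueryStarted(self, event):
            # The callback runs outside the client process, so record what it
            # saw in a table that the client can read back afterwards.
            df = self.spark.createDataFrame([(str(event.id),)], ["query_id"])
            df.write.mode("append").saveAsTable("legacy_listener_events")

        def onQueryProgress(self, event):
            pass

        def onQueryTerminated(self, event):
            pass

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    listener = LegacyStyleListener()
    spark.streams.addListener(listener)
    try:
        q = spark.readStream.format("rate").load().writeStream.format("noop").start()
        while q.lastProgress is None:  # wait until at least one batch has run
            time.sleep(1)
        q.stop()
        time.sleep(10)  # give the server-side callback time to materialize the table
        assert spark.read.table("legacy_listener_events").count() > 0
    finally:
        spark.streams.removeListener(listener)
        spark.sql("DROP TABLE IF EXISTS legacy_listener_events")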
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 4fc040642bed..412f49a3960b 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -26,16 +26,36 @@ from pyspark.sql.functions import count, lit
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-class TestListener(StreamingQueryListener):
+# V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`,
+# `onQueryProgress`, `onQueryTerminated`. It is prior to Spark 3.5.
+class TestListenerV1(StreamingQueryListener):
     def onQueryStarted(self, event):
         e = pyspark.cloudpickle.dumps(event)
         df = self.spark.createDataFrame(data=[(e,)])
-        df.write.mode("append").saveAsTable("listener_start_events")
+        df.write.mode("append").saveAsTable("listener_start_events_v1")
 
     def onQueryProgress(self, event):
         e = pyspark.cloudpickle.dumps(event)
         df = self.spark.createDataFrame(data=[(e,)])
-        df.write.mode("append").saveAsTable("listener_progress_events")
+        df.write.mode("append").saveAsTable("listener_progress_events_v1")
+
+    def onQueryTerminated(self, event):
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_terminated_events_v1")
+
+
+# V2: The interface after the method `onQueryIdle` is added. It is Spark 3.5+.
+class TestListenerV2(StreamingQueryListener):
+    def onQueryStarted(self, event):
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_start_events_v2")
+
+    def onQueryProgress(self, event):
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_progress_events_v2")
 
     def onQueryIdle(self, event):
         pass
@@ -43,60 +63,67 @@ class TestListener(StreamingQueryListener):
     def onQueryTerminated(self, event):
         e = pyspark.cloudpickle.dumps(event)
         df = self.spark.createDataFrame(data=[(e,)])
-        df.write.mode("append").saveAsTable("listener_terminated_events")
+        df.write.mode("append").saveAsTable("listener_terminated_events_v2")
 
 
 class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
     def test_listener_events(self):
-        test_listener = TestListener()
-
-        try:
-            self.spark.streams.addListener(test_listener)
-
-            # This ensures the read socket on the server won't crash (i.e. because of timeout)
-            # when there hasn't been a new event for a long time
-            time.sleep(30)
-
-            df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
-            df_stateful = df_observe.groupBy().count()  # make query stateful
-            q = (
-                df_stateful.writeStream.format("noop")
-                .queryName("test")
-                .outputMode("complete")
-                .start()
-            )
-
-            self.assertTrue(q.isActive)
-            # ensure at least one batch is ran
-            while q.lastProgress is None or q.lastProgress["batchId"] == 0:
-                time.sleep(5)
-            q.stop()
-            self.assertFalse(q.isActive)
-
-            time.sleep(60)  # Sleep to make sure listener_terminated_events is written successfully
-
-            start_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_start_events").collect()[0][0]
-            )
-
-            progress_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_progress_events").collect()[0][0]
-            )
-
-            terminated_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_terminated_events").collect()[0][0]
-            )
-
-            self.check_start_event(start_event)
-            self.check_progress_event(progress_event)
-            self.check_terminated_event(terminated_event)
-
-        finally:
-            self.spark.streams.removeListener(test_listener)
-
-            # Remove again to verify this won't throw any error
-            self.spark.streams.removeListener(test_listener)
+        def verify(test_listener, table_postfix):
+            try:
+                self.spark.streams.addListener(test_listener)
+
+                # This ensures the read socket on the server won't crash (i.e. because of timeout)
+                # when there hasn't been a new event for a long time
+                time.sleep(30)
+
+                df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+                df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+                df_stateful = df_observe.groupBy().count()  # make query stateful
+                q = (
+                    df_stateful.writeStream.format("noop")
+                    .queryName("test")
+                    .outputMode("complete")
+                    .start()
+                )
+
+                self.assertTrue(q.isActive)
+                # ensure at least one batch is ran
+                while q.lastProgress is None or q.lastProgress["batchId"] == 0:
+                    time.sleep(5)
+                q.stop()
+                self.assertFalse(q.isActive)
+
+                # Sleep to make sure listener_terminated_events is written successfully
+                time.sleep(60)
+
+                start_table_name = "listener_start_events" + table_postfix
+                progress_tbl_name = "listener_progress_events" + table_postfix
+                terminated_tbl_name = "listener_terminated_events" + table_postfix
+
+                start_event = pyspark.cloudpickle.loads(
+                    self.spark.read.table(start_table_name).collect()[0][0]
+                )
+
+                progress_event = pyspark.cloudpickle.loads(
+                    self.spark.read.table(progress_tbl_name).collect()[0][0]
+                )
+
+                terminated_event = pyspark.cloudpickle.loads(
+                    self.spark.read.table(terminated_tbl_name).collect()[0][0]
+                )
+
+                self.check_start_event(start_event)
+                self.check_progress_event(progress_event)
+                self.check_terminated_event(terminated_event)
+
+            finally:
+                self.spark.streams.removeListener(test_listener)
+
+                # Remove again to verify this won't throw any error
+                self.spark.streams.removeListener(test_listener)
+
+        verify(TestListenerV1(), "_v1")
+        verify(TestListenerV2(), "_v2")
 
     def test_accessing_spark_session(self):
         spark = self.spark

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org