This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 311a497224d  [SPARK-44435][SS][CONNECT] Tests for foreachBatch and Listener
311a497224d is described below

commit 311a497224db4a00b5cdf928cba8ef30545ee911
Author: Wei Liu <wei....@databricks.com>
AuthorDate: Thu Aug 24 19:18:43 2023 +0900

    [SPARK-44435][SS][CONNECT] Tests for foreachBatch and Listener

    ### What changes were proposed in this pull request?

    Add several new test cases for streaming foreachBatch and streaming query
    listener events, covering various scenarios.

    ### Why are the changes needed?

    More tests are better.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    Test-only change.

    Closes #42521 from WweiL/SPARK-44435-tests-foreachBatch-listener.

    Authored-by: Wei Liu <wei....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 2d44848f12cf818a0fe54fb03075cd9cca485ecb)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../planner/StreamingForeachBatchHelper.scala      |   1 -
 .../streaming/worker/foreachBatch_worker.py        |   1 -
 .../connect/streaming/worker/listener_worker.py    |   1 -
 .../connect/streaming/test_parity_listener.py      |  57 ++++++-----
 .../tests/streaming/test_streaming_foreachBatch.py | 111 ++++++++++++++++++++-
 5 files changed, 144 insertions(+), 27 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 21e4adb9896..ef7195439f9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -113,7 +113,6 @@ object StreamingForeachBatchHelper extends Logging {
 
   val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
 
-    // TODO(SPARK-44460): Support Auth credentials
     // TODO(SPARK-44462): A new session id pointing to args.df.sparkSession needs to be created.
     //     This is because MicroBatch execution clones the session during start.
     //     The session attached to the foreachBatch dataframe is different from the one the one
diff --git a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
index cf61463cd68..72037f1263d 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
@@ -51,7 +51,6 @@ def main(infile: IO, outfile: IO) -> None:
     spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
     spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
 
-    # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
 
     func = worker.read_command(pickle_ser, infile)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index e1f4678e42f..c026945767d 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -59,7 +59,6 @@ def main(infile: IO, outfile: IO) -> None:
     spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
     spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
 
-    # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
 
     listener = worker.read_command(pickle_ser, infile)
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 4bf58bf7807..5069a76cfdb 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -18,39 +18,31 @@ import unittest
 import time
 
+import pyspark.cloudpickle
 from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
-from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent
-from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.sql.streaming.listener import StreamingQueryListener
+from pyspark.sql.functions import count, lit
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-def get_start_event_schema():
-    return StructType(
-        [
-            StructField("id", StringType(), True),
-            StructField("runId", StringType(), True),
-            StructField("name", StringType(), True),
-            StructField("timestamp", StringType(), True),
-        ]
-    )
-
-
 class TestListener(StreamingQueryListener):
     def onQueryStarted(self, event):
-        df = self.spark.createDataFrame(
-            data=[(str(event.id), str(event.runId), event.name, event.timestamp)],
-            schema=get_start_event_schema(),
-        )
-        df.write.saveAsTable("listener_start_events")
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_start_events")
 
     def onQueryProgress(self, event):
-        pass
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_progress_events")
 
     def onQueryIdle(self, event):
         pass
 
     def onQueryTerminated(self, event):
-        pass
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_terminated_events")
 
 
 class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
@@ -65,17 +57,36 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
             time.sleep(30)
 
             df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-            q = df.writeStream.format("noop").queryName("test").start()
+            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+            df_stateful = df_observe.groupBy().count()  # make query stateful
+            q = (
+                df_stateful.writeStream.format("noop")
+                .queryName("test")
+                .outputMode("complete")
+                .start()
+            )
             self.assertTrue(q.isActive)
             time.sleep(10)
+            self.assertTrue(q.lastProgress["batchId"] > 0)  # ensure at least one batch has run
             q.stop()
+            self.assertFalse(q.isActive)
+
+            start_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_start_events").collect()[0][0]
+            )
+
+            progress_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_progress_events").collect()[0][0]
+            )
 
-            start_event = QueryStartedEvent.fromJson(
-                self.spark.read.table("listener_start_events").collect()[0].asDict()
+            terminated_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_terminated_events").collect()[0][0]
             )
 
             self.check_start_event(start_event)
+            self.check_progress_event(progress_event)
+            self.check_terminated_event(terminated_event)
 
         finally:
             self.spark.streams.removeListener(test_listener)
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
index d4e185c3d85..65a0f6279fb 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
@@ -16,10 +16,14 @@
 #
 
 import time
-
+from pyspark.sql.dataframe import DataFrame
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
+def my_test_function_1():
+    return 1
+
+
 class StreamingTestsForeachBatchMixin:
     def test_streaming_foreachBatch(self):
         q = None
@@ -88,6 +92,111 @@ class StreamingTestsForeachBatchMixin:
             q.stop()
             self.assertIsNone(q.exception(), "No exception has to be propagated.")
 
+    def test_streaming_foreachBatch_spark_session(self):
+        table_name = "testTable_foreachBatch"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            spark = df.sparkSession
+            df1 = spark.createDataFrame([("structured",), ("streaming",)])
+            df1.union(df).write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = (
+            self.spark.read.format("text")
+            .load(path="python/test_support/sql/streaming/")
+            .union(self.spark.createDataFrame([("structured",), ("streaming",)]))
+        )
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+    def test_streaming_foreachBatch_path_access(self):
+        table_name = "testTable_foreachBatch_path"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            spark = df.sparkSession
+            df1 = spark.read.format("text").load("python/test_support/sql/streaming")
+            df1.union(df).write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/")
+        df = df.union(df)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+    # write to delta table?
+
+    @staticmethod
+    def my_test_function_2():
+        return 2
+
+    def test_streaming_foreachBatch_function_calling(self):
+        def my_test_function_3():
+            return 3
+
+        table_name = "testTable_foreachBatch_function"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            spark = df.sparkSession
+            df1 = spark.createDataFrame(
+                [
+                    (my_test_function_1(),),
+                    (StreamingTestsForeachBatchMixin.my_test_function_2(),),
+                    (my_test_function_3(),),
+                ]
+            )
+            df1.write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("rate").load()
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = self.spark.createDataFrame(
+            [
+                (my_test_function_1(),),
+                (StreamingTestsForeachBatchMixin.my_test_function_2(),),
+                (my_test_function_3(),),
+            ]
+        )
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+    def test_streaming_foreachBatch_import(self):
+        import time  # not imported in foreachBatch_worker
+
+        table_name = "testTable_foreachBatch_import"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            time.sleep(1)
+            spark = df.sparkSession
+            df1 = spark.read.format("text").load("python/test_support/sql/streaming")
+            df1.write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("rate").load()
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = self.spark.read.format("text").load("python/test_support/sql/streaming")
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
 
 class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase):
     pass
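
For readers skimming the diff: the two APIs these tests exercise are
DataStreamWriter.foreachBatch and StreamingQueryListener. Below is a minimal,
self-contained sketch of that surface. It is not part of this commit; it
assumes a local (non-Connect) PySpark 3.5+ session, and the names
PrintListener and handle_batch are illustrative only.

# Illustrative sketch (not from the commit): exercise foreachBatch and a
# streaming query listener against a rate source.
from pyspark.sql import SparkSession
from pyspark.sql.streaming.listener import StreamingQueryListener


class PrintListener(StreamingQueryListener):
    # Listener callbacks fire asynchronously on the driver as the query runs.
    def onQueryStarted(self, event):
        print(f"started: id={event.id}, runId={event.runId}")

    def onQueryProgress(self, event):
        print(f"progress: batchId={event.progress.batchId}")

    def onQueryIdle(self, event):
        pass

    def onQueryTerminated(self, event):
        print(f"terminated: id={event.id}")


def handle_batch(df, batch_id):
    # foreachBatch hands each micro-batch to arbitrary Python code;
    # here `df` is a regular (non-streaming) DataFrame.
    print(f"batch {batch_id}: {df.count()} rows")


spark = SparkSession.builder.appName("fb-listener-sketch").getOrCreate()
listener = PrintListener()
spark.streams.addListener(listener)

q = (
    spark.readStream.format("rate")
    .option("rowsPerSecond", 10)
    .load()
    .writeStream.foreachBatch(handle_batch)
    .start()
)
q.processAllAvailable()  # block until data available so far is processed
q.stop()

spark.streams.removeListener(listener)
spark.stop()

The new tests cover this same surface through the shared mixin, additionally
checking that the user function can reach df.sparkSession, use imported
modules, and call functions defined at module, class, and local scope, while
the listener parity test round-trips the start/progress/terminated events
through cloudpickle.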