This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 311a497224d [SPARK-44435][SS][CONNECT] Tests for foreachBatch and Listener
311a497224d is described below

commit 311a497224db4a00b5cdf928cba8ef30545ee911
Author: Wei Liu <wei....@databricks.com>
AuthorDate: Thu Aug 24 19:18:43 2023 +0900

    [SPARK-44435][SS][CONNECT] Tests for foreachBatch and Listener
    
    ### What changes were proposed in this pull request?
    
    Add several new test cases for streaming foreachBatch and streaming query listener events to test various scenarios.
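
    As a rough illustration (not part of this change), the listener tests follow this
    pattern: a StreamingQueryListener pickles each event with pyspark.cloudpickle and
    appends it to a table, so the test can read the events back and verify them. The
    class and table names below are made up for the sketch; `spark` is assumed to be
    an active session.

        import pyspark.cloudpickle
        from pyspark.sql.streaming.listener import StreamingQueryListener

        class RecordingListener(StreamingQueryListener):
            def _record(self, event, table):
                # Serialize the event so the test can deserialize and inspect it later.
                data = pyspark.cloudpickle.dumps(event)
                self.spark.createDataFrame([(data,)]).write.mode("append").saveAsTable(table)

            def onQueryStarted(self, event):
                self._record(event, "start_events")

            def onQueryProgress(self, event):
                self._record(event, "progress_events")

            def onQueryIdle(self, event):
                pass

            def onQueryTerminated(self, event):
                self._record(event, "terminated_events")

        spark.streams.addListener(RecordingListener())

    The foreachBatch tests have the same shape: the sink function receives the
    micro-batch DataFrame and a batch id, and reaches the session via df.sparkSession.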
    
    ### Why are the changes needed?
    
    More tests are better
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Test-only change
    
    Closes #42521 from WweiL/SPARK-44435-tests-foreachBatch-listener.
    
    Authored-by: Wei Liu <wei....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 2d44848f12cf818a0fe54fb03075cd9cca485ecb)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../planner/StreamingForeachBatchHelper.scala      |   1 -
 .../streaming/worker/foreachBatch_worker.py        |   1 -
 .../connect/streaming/worker/listener_worker.py    |   1 -
 .../connect/streaming/test_parity_listener.py      |  57 ++++++-----
 .../tests/streaming/test_streaming_foreachBatch.py | 111 ++++++++++++++++++++-
 5 files changed, 144 insertions(+), 27 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 21e4adb9896..ef7195439f9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -113,7 +113,6 @@ object StreamingForeachBatchHelper extends Logging {
 
     val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
 
-      // TODO(SPARK-44460): Support Auth credentials
       // TODO(SPARK-44462): A new session id pointing to args.df.sparkSession needs to be created.
       //     This is because MicroBatch execution clones the session during start.
       //     The session attached to the foreachBatch dataframe is different from the one the one
diff --git a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
index cf61463cd68..72037f1263d 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
@@ -51,7 +51,6 @@ def main(infile: IO, outfile: IO) -> None:
     spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
     spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
 
-    # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
 
     func = worker.read_command(pickle_ser, infile)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index e1f4678e42f..c026945767d 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -59,7 +59,6 @@ def main(infile: IO, outfile: IO) -> None:
     spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
     spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
 
-    # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
 
     listener = worker.read_command(pickle_ser, infile)
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 4bf58bf7807..5069a76cfdb 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -18,39 +18,31 @@
 import unittest
 import time
 
+import pyspark.cloudpickle
 from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
-from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent
-from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.sql.streaming.listener import StreamingQueryListener
+from pyspark.sql.functions import count, lit
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-def get_start_event_schema():
-    return StructType(
-        [
-            StructField("id", StringType(), True),
-            StructField("runId", StringType(), True),
-            StructField("name", StringType(), True),
-            StructField("timestamp", StringType(), True),
-        ]
-    )
-
-
 class TestListener(StreamingQueryListener):
     def onQueryStarted(self, event):
-        df = self.spark.createDataFrame(
-            data=[(str(event.id), str(event.runId), event.name, event.timestamp)],
-            schema=get_start_event_schema(),
-        )
-        df.write.saveAsTable("listener_start_events")
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_start_events")
 
     def onQueryProgress(self, event):
-        pass
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_progress_events")
 
     def onQueryIdle(self, event):
         pass
 
     def onQueryTerminated(self, event):
-        pass
+        e = pyspark.cloudpickle.dumps(event)
+        df = self.spark.createDataFrame(data=[(e,)])
+        df.write.mode("append").saveAsTable("listener_terminated_events")
 
 
 class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
@@ -65,17 +57,36 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
             time.sleep(30)
 
             df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-            q = df.writeStream.format("noop").queryName("test").start()
+            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+            df_stateful = df_observe.groupBy().count()  # make query stateful
+            q = (
+                df_stateful.writeStream.format("noop")
+                .queryName("test")
+                .outputMode("complete")
+                .start()
+            )
 
             self.assertTrue(q.isActive)
             time.sleep(10)
+            self.assertTrue(q.lastProgress["batchId"] > 0)  # ensure at least one batch is ran
             q.stop()
+            self.assertFalse(q.isActive)
+
+            start_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_start_events").collect()[0][0]
+            )
+
+            progress_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_progress_events").collect()[0][0]
+            )
 
-            start_event = QueryStartedEvent.fromJson(
-                self.spark.read.table("listener_start_events").collect()[0].asDict()
+            terminated_event = pyspark.cloudpickle.loads(
+                self.spark.read.table("listener_terminated_events").collect()[0][0]
             )
 
             self.check_start_event(start_event)
+            self.check_progress_event(progress_event)
+            self.check_terminated_event(terminated_event)
 
         finally:
             self.spark.streams.removeListener(test_listener)
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
index d4e185c3d85..65a0f6279fb 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
@@ -16,10 +16,14 @@
 #
 
 import time
-
+from pyspark.sql.dataframe import DataFrame
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
+def my_test_function_1():
+    return 1
+
+
 class StreamingTestsForeachBatchMixin:
     def test_streaming_foreachBatch(self):
         q = None
@@ -88,6 +92,111 @@ class StreamingTestsForeachBatchMixin:
         q.stop()
         self.assertIsNone(q.exception(), "No exception has to be propagated.")
 
+    def test_streaming_foreachBatch_spark_session(self):
+        table_name = "testTable_foreachBatch"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            spark = df.sparkSession
+            df1 = spark.createDataFrame([("structured",), ("streaming",)])
+            df1.union(df).write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = (
+            self.spark.read.format("text")
+            .load(path="python/test_support/sql/streaming/")
+            .union(self.spark.createDataFrame([("structured",), ("streaming",)]))
+        )
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+    def test_streaming_foreachBatch_path_access(self):
+        table_name = "testTable_foreachBatch_path"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            spark = df.sparkSession
+            df1 = spark.read.format("text").load("python/test_support/sql/streaming")
+            df1.union(df).write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/")
+        df = df.union(df)
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+        # write to delta table?
+
+    @staticmethod
+    def my_test_function_2():
+        return 2
+
+    def test_streaming_foreachBatch_fuction_calling(self):
+        def my_test_function_3():
+            return 3
+
+        table_name = "testTable_foreachBatch_function"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            spark = df.sparkSession
+            df1 = spark.createDataFrame(
+                [
+                    (my_test_function_1(),),
+                    (StreamingTestsForeachBatchMixin.my_test_function_2(),),
+                    (my_test_function_3(),),
+                ]
+            )
+            df1.write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("rate").load()
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = self.spark.createDataFrame(
+            [
+                (my_test_function_1(),),
+                (StreamingTestsForeachBatchMixin.my_test_function_2(),),
+                (my_test_function_3(),),
+            ]
+        )
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+    def test_streaming_foreachBatch_import(self):
+        import time  # not imported in foreachBatch_worker
+
+        table_name = "testTable_foreachBatch_import"
+
+        def func(df: DataFrame, batch_id: int):
+            if batch_id > 0:  # only process once
+                return
+            time.sleep(1)
+            spark = df.sparkSession
+            df1 = spark.read.format("text").load("python/test_support/sql/streaming")
+            df1.write.mode("append").saveAsTable(table_name)
+
+        df = self.spark.readStream.format("rate").load()
+        q = df.writeStream.foreachBatch(func).start()
+        q.processAllAvailable()
+        q.stop()
+
+        actual = self.spark.read.table(table_name)
+        df = self.spark.read.format("text").load("python/test_support/sql/streaming")
+        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
 
 class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase):
     pass


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
