This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 6c2da61b386 Revert "[SPARK-44435][SS][CONNECT] Tests for foreachBatch and Listener"
6c2da61b386 is described below

commit 6c2da61b386d905d05437e68a4b945b5ee9a3e90
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Thu Aug 24 10:35:20 2023 -0700

    Revert "[SPARK-44435][SS][CONNECT] Tests for foreachBatch and Listener"
    
    This reverts commit 311a497224db4a00b5cdf928cba8ef30545ee911.
---
 .../planner/StreamingForeachBatchHelper.scala      |   1 +
 .../streaming/worker/foreachBatch_worker.py        |   1 +
 .../connect/streaming/worker/listener_worker.py    |   1 +
 .../connect/streaming/test_parity_listener.py      |  57 +++++------
 .../tests/streaming/test_streaming_foreachBatch.py | 111 +--------------------
 5 files changed, 27 insertions(+), 144 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index ef7195439f9..21e4adb9896 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -113,6 +113,7 @@ object StreamingForeachBatchHelper extends Logging {
 
     val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
 
+      // TODO(SPARK-44460): Support Auth credentials
       // TODO(SPARK-44462): A new session id pointing to args.df.sparkSession needs to be created.
       //     This is because MicroBatch execution clones the session during start.
       //     The session attached to the foreachBatch dataframe is different from the one the one
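
The cloned-session caveat in the TODO above is also the reason foreachBatch user code should reach the session through the batch DataFrame itself. A minimal sketch of that pattern on the Python side, assuming only the public PySpark API (the sink table name "batch_sink" is illustrative):

    from pyspark.sql import DataFrame

    def process_batch(df: DataFrame, batch_id: int) -> None:
        # The batch DataFrame is bound to a clone of the session that
        # started the query; use it rather than an outer `spark` handle.
        spark = df.sparkSession
        extra = spark.createDataFrame([("structured",), ("streaming",)])
        extra.union(df).write.mode("append").saveAsTable("batch_sink")

    # Attached to a streaming query with:
    # stream_df.writeStream.foreachBatch(process_batch).start()
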
diff --git a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
index 72037f1263d..cf61463cd68 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
@@ -51,6 +51,7 @@ def main(infile: IO, outfile: IO) -> None:
    spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
    spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
 
+    # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
 
     func = worker.read_command(pickle_ser, infile)
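
For reference, the worker's session setup above follows the usual Spark Connect client pattern: build a remote session against the server's connect URL, then reattach to the originating session. A standalone sketch with placeholder values (note that _client._session_id is an internal attribute, mirrored from the worker code):

    from pyspark.sql.connect.session import SparkSession

    connect_url = "sc://localhost:15002"  # placeholder connect URL
    session_id = "00000000-0000-0000-0000-000000000000"  # placeholder id

    spark = SparkSession.builder.remote(connect_url).getOrCreate()
    # Point the client at the existing server-side session instead of a
    # fresh one (internal API, as in the worker above).
    spark._client._session_id = session_id  # type: ignore[attr-defined]
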
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index c026945767d..e1f4678e42f 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -59,6 +59,7 @@ def main(infile: IO, outfile: IO) -> None:
    spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
    spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
 
+    # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
 
     listener = worker.read_command(pickle_ser, infile)
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 5069a76cfdb..4bf58bf7807 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -18,31 +18,39 @@
 import unittest
 import time
 
-import pyspark.cloudpickle
 from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
-from pyspark.sql.streaming.listener import StreamingQueryListener
-from pyspark.sql.functions import count, lit
+from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent
+from pyspark.sql.types import StructType, StructField, StringType
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
+def get_start_event_schema():
+    return StructType(
+        [
+            StructField("id", StringType(), True),
+            StructField("runId", StringType(), True),
+            StructField("name", StringType(), True),
+            StructField("timestamp", StringType(), True),
+        ]
+    )
+
+
 class TestListener(StreamingQueryListener):
     def onQueryStarted(self, event):
-        e = pyspark.cloudpickle.dumps(event)
-        df = self.spark.createDataFrame(data=[(e,)])
-        df.write.mode("append").saveAsTable("listener_start_events")
+        df = self.spark.createDataFrame(
+            data=[(str(event.id), str(event.runId), event.name, event.timestamp)],
+            schema=get_start_event_schema(),
+        )
+        df.write.saveAsTable("listener_start_events")
 
     def onQueryProgress(self, event):
-        e = pyspark.cloudpickle.dumps(event)
-        df = self.spark.createDataFrame(data=[(e,)])
-        df.write.mode("append").saveAsTable("listener_progress_events")
+        pass
 
     def onQueryIdle(self, event):
         pass
 
     def onQueryTerminated(self, event):
-        e = pyspark.cloudpickle.dumps(event)
-        df = self.spark.createDataFrame(data=[(e,)])
-        df.write.mode("append").saveAsTable("listener_terminated_events")
+        pass
 
 
 class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
@@ -57,36 +65,17 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
             time.sleep(30)
 
             df = self.spark.readStream.format("rate").option("rowsPerSecond", 
10).load()
-            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
-            df_stateful = df_observe.groupBy().count()  # make query stateful
-            q = (
-                df_stateful.writeStream.format("noop")
-                .queryName("test")
-                .outputMode("complete")
-                .start()
-            )
+            q = df.writeStream.format("noop").queryName("test").start()
 
             self.assertTrue(q.isActive)
             time.sleep(10)
-            self.assertTrue(q.lastProgress["batchId"] > 0)  # ensure at least one batch is ran
             q.stop()
-            self.assertFalse(q.isActive)
-
-            start_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_start_events").collect()[0][0]
-            )
-
-            progress_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_progress_events").collect()[0][0]
-            )
 
-            terminated_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_terminated_events").collect()[0][0]
+            start_event = QueryStartedEvent.fromJson(
+                self.spark.read.table("listener_start_events").collect()[0].asDict()
             )
 
             self.check_start_event(start_event)
-            self.check_progress_event(progress_event)
-            self.check_terminated_event(terminated_event)
 
         finally:
             self.spark.streams.removeListener(test_listener)
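
The reworked test above reconstructs the start event from plain strings via QueryStartedEvent.fromJson, which expects a dict whose "id" and "runId" values parse as UUIDs. A minimal sketch of that round trip, with placeholder values:

    from pyspark.sql.streaming.listener import QueryStartedEvent

    row = {
        "id": "00000000-0000-0000-0000-000000000001",     # placeholder UUID
        "runId": "00000000-0000-0000-0000-000000000002",  # placeholder UUID
        "name": "test",
        "timestamp": "2023-08-24T17:35:20.000Z",
    }
    event = QueryStartedEvent.fromJson(row)
    assert event.name == "test"
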
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
index 65a0f6279fb..d4e185c3d85 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
@@ -16,12 +16,8 @@
 #
 
 import time
-from pyspark.sql.dataframe import DataFrame
-from pyspark.testing.sqlutils import ReusedSQLTestCase
-
 
-def my_test_function_1():
-    return 1
+from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
 class StreamingTestsForeachBatchMixin:
@@ -92,111 +88,6 @@ class StreamingTestsForeachBatchMixin:
         q.stop()
         self.assertIsNone(q.exception(), "No exception has to be propagated.")
 
-    def test_streaming_foreachBatch_spark_session(self):
-        table_name = "testTable_foreachBatch"
-
-        def func(df: DataFrame, batch_id: int):
-            if batch_id > 0:  # only process once
-                return
-            spark = df.sparkSession
-            df1 = spark.createDataFrame([("structured",), ("streaming",)])
-            df1.union(df).write.mode("append").saveAsTable(table_name)
-
-        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
-        q = df.writeStream.foreachBatch(func).start()
-        q.processAllAvailable()
-        q.stop()
-
-        actual = self.spark.read.table(table_name)
-        df = (
-            self.spark.read.format("text")
-            .load(path="python/test_support/sql/streaming/")
-            .union(self.spark.createDataFrame([("structured",), ("streaming",)]))
-        )
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-    def test_streaming_foreachBatch_path_access(self):
-        table_name = "testTable_foreachBatch_path"
-
-        def func(df: DataFrame, batch_id: int):
-            if batch_id > 0:  # only process once
-                return
-            spark = df.sparkSession
-            df1 = spark.read.format("text").load("python/test_support/sql/streaming")
-            df1.union(df).write.mode("append").saveAsTable(table_name)
-
-        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
-        q = df.writeStream.foreachBatch(func).start()
-        q.processAllAvailable()
-        q.stop()
-
-        actual = self.spark.read.table(table_name)
-        df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/")
-        df = df.union(df)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        # write to delta table?
-
-    @staticmethod
-    def my_test_function_2():
-        return 2
-
-    def test_streaming_foreachBatch_fuction_calling(self):
-        def my_test_function_3():
-            return 3
-
-        table_name = "testTable_foreachBatch_function"
-
-        def func(df: DataFrame, batch_id: int):
-            if batch_id > 0:  # only process once
-                return
-            spark = df.sparkSession
-            df1 = spark.createDataFrame(
-                [
-                    (my_test_function_1(),),
-                    (StreamingTestsForeachBatchMixin.my_test_function_2(),),
-                    (my_test_function_3(),),
-                ]
-            )
-            df1.write.mode("append").saveAsTable(table_name)
-
-        df = self.spark.readStream.format("rate").load()
-        q = df.writeStream.foreachBatch(func).start()
-        q.processAllAvailable()
-        q.stop()
-
-        actual = self.spark.read.table(table_name)
-        df = self.spark.createDataFrame(
-            [
-                (my_test_function_1(),),
-                (StreamingTestsForeachBatchMixin.my_test_function_2(),),
-                (my_test_function_3(),),
-            ]
-        )
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-    def test_streaming_foreachBatch_import(self):
-        import time  # not imported in foreachBatch_worker
-
-        table_name = "testTable_foreachBatch_import"
-
-        def func(df: DataFrame, batch_id: int):
-            if batch_id > 0:  # only process once
-                return
-            time.sleep(1)
-            spark = df.sparkSession
-            df1 = spark.read.format("text").load("python/test_support/sql/streaming")
-            df1.write.mode("append").saveAsTable(table_name)
-
-        df = self.spark.readStream.format("rate").load()
-        q = df.writeStream.foreachBatch(func).start()
-        q.processAllAvailable()
-        q.stop()
-
-        actual = self.spark.read.table(table_name)
-        df = self.spark.read.format("text").load("python/test_support/sql/streaming")
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
 
 class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase):
     pass


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
