This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new 311a497224d [SPARK-44435][SS][CONNECT] Tests for foreachBatch and Listener
311a497224d is described below
commit 311a497224db4a00b5cdf928cba8ef30545ee911
Author: Wei Liu <[email protected]>
AuthorDate: Thu Aug 24 19:18:43 2023 +0900
[SPARK-44435][SS][CONNECT] Tests for foreachBatch and Listener
### What changes were proposed in this pull request?
Add several new test cases for streaming foreachBatch and streaming query
listener events to test various scenarios.
### Why are the changes needed?
More tests are better.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Test only change
Closes #42521 from WweiL/SPARK-44435-tests-foreachBatch-listener.
Authored-by: Wei Liu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 2d44848f12cf818a0fe54fb03075cd9cca485ecb)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../planner/StreamingForeachBatchHelper.scala | 1 -
.../streaming/worker/foreachBatch_worker.py | 1 -
.../connect/streaming/worker/listener_worker.py | 1 -
.../connect/streaming/test_parity_listener.py | 57 ++++++-----
.../tests/streaming/test_streaming_foreachBatch.py | 111 ++++++++++++++++++++-
5 files changed, 144 insertions(+), 27 deletions(-)
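For context, the two user-facing APIs exercised by the new tests are foreachBatch and StreamingQueryListener. A minimal sketch of how they are used (the listener class, callback function, and table name below are illustrative, not taken from the diff):

    from pyspark.sql import SparkSession
    from pyspark.sql.streaming.listener import StreamingQueryListener

    spark = SparkSession.builder.getOrCreate()

    class PrintingListener(StreamingQueryListener):
        # Called once when a query starts.
        def onQueryStarted(self, event):
            print("started:", event.id)

        # Called after each progress update (one per micro-batch).
        def onQueryProgress(self, event):
            print("batch:", event.progress.batchId)

        def onQueryIdle(self, event):
            pass

        def onQueryTerminated(self, event):
            print("terminated:", event.id)

    listener = PrintingListener()
    spark.streams.addListener(listener)

    # foreachBatch hands each micro-batch DataFrame to arbitrary Python code.
    def handle_batch(df, batch_id):
        df.write.mode("append").saveAsTable("sketch_sink")  # illustrative table name

    q = spark.readStream.format("rate").load().writeStream.foreachBatch(handle_batch).start()
    q.processAllAvailable()
    q.stop()
    spark.streams.removeListener(listener)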
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 21e4adb9896..ef7195439f9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -113,7 +113,6 @@ object StreamingForeachBatchHelper extends Logging {
val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
- // TODO(SPARK-44460): Support Auth credentials
// TODO(SPARK-44462): A new session id pointing to args.df.sparkSession needs to be created.
//     This is because MicroBatch execution clones the session during start.
//     The session attached to the foreachBatch dataframe is different from the one the one
diff --git a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
index cf61463cd68..72037f1263d 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
@@ -51,7 +51,6 @@ def main(infile: IO, outfile: IO) -> None:
spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
- # TODO(SPARK-44460): Pass credentials.
# TODO(SPARK-44461): Enable Process Isolation
func = worker.read_command(pickle_ser, infile)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index e1f4678e42f..c026945767d 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -59,7 +59,6 @@ def main(infile: IO, outfile: IO) -> None:
spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
- # TODO(SPARK-44460): Pass credentials.
# TODO(SPARK-44461): Enable Process Isolation
listener = worker.read_command(pickle_ser, infile)
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 4bf58bf7807..5069a76cfdb 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -18,39 +18,31 @@
import unittest
import time
+import pyspark.cloudpickle
from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
-from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent
-from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.sql.streaming.listener import StreamingQueryListener
+from pyspark.sql.functions import count, lit
from pyspark.testing.connectutils import ReusedConnectTestCase
-def get_start_event_schema():
- return StructType(
- [
- StructField("id", StringType(), True),
- StructField("runId", StringType(), True),
- StructField("name", StringType(), True),
- StructField("timestamp", StringType(), True),
- ]
- )
-
-
class TestListener(StreamingQueryListener):
def onQueryStarted(self, event):
- df = self.spark.createDataFrame(
- data=[(str(event.id), str(event.runId), event.name, event.timestamp)],
- schema=get_start_event_schema(),
- )
- df.write.saveAsTable("listener_start_events")
+ e = pyspark.cloudpickle.dumps(event)
+ df = self.spark.createDataFrame(data=[(e,)])
+ df.write.mode("append").saveAsTable("listener_start_events")
def onQueryProgress(self, event):
- pass
+ e = pyspark.cloudpickle.dumps(event)
+ df = self.spark.createDataFrame(data=[(e,)])
+ df.write.mode("append").saveAsTable("listener_progress_events")
def onQueryIdle(self, event):
pass
def onQueryTerminated(self, event):
- pass
+ e = pyspark.cloudpickle.dumps(event)
+ df = self.spark.createDataFrame(data=[(e,)])
+ df.write.mode("append").saveAsTable("listener_terminated_events")
class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
@@ -65,17 +57,36 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
time.sleep(30)
df = self.spark.readStream.format("rate").option("rowsPerSecond",
10).load()
- q = df.writeStream.format("noop").queryName("test").start()
+ df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+ df_stateful = df_observe.groupBy().count() # make query stateful
+ q = (
+ df_stateful.writeStream.format("noop")
+ .queryName("test")
+ .outputMode("complete")
+ .start()
+ )
self.assertTrue(q.isActive)
time.sleep(10)
+ self.assertTrue(q.lastProgress["batchId"] > 0)  # ensure at least one batch is ran
q.stop()
+ self.assertFalse(q.isActive)
+
+ start_event = pyspark.cloudpickle.loads(
+ self.spark.read.table("listener_start_events").collect()[0][0]
+ )
+
+ progress_event = pyspark.cloudpickle.loads(
+     self.spark.read.table("listener_progress_events").collect()[0][0]
+ )
- start_event = QueryStartedEvent.fromJson(
-     self.spark.read.table("listener_start_events").collect()[0].asDict()
+ terminated_event = pyspark.cloudpickle.loads(
+     self.spark.read.table("listener_terminated_events").collect()[0][0]
)
self.check_start_event(start_event)
+ self.check_progress_event(progress_event)
+ self.check_terminated_event(terminated_event)
finally:
self.spark.streams.removeListener(test_listener)
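The rewritten parity test round-trips listener events through tables as cloudpickled bytes, since under Spark Connect the listener callbacks run in a separate Python worker process rather than in the test process. A condensed sketch of that pattern (the helper names and table name below are illustrative, not from the diff):

    import pyspark.cloudpickle

    # Listener side: serialize the event object and append it to a table as bytes.
    def record_event(spark, event, table="listener_events_sketch"):
        payload = pyspark.cloudpickle.dumps(event)
        spark.createDataFrame([(payload,)]).write.mode("append").saveAsTable(table)

    # Test side: read the bytes back and reconstruct the original event object.
    def read_event(spark, table="listener_events_sketch"):
        raw = spark.read.table(table).collect()[0][0]
        return pyspark.cloudpickle.loads(raw)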
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
index d4e185c3d85..65a0f6279fb 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
@@ -16,10 +16,14 @@
#
import time
-
+from pyspark.sql.dataframe import DataFrame
from pyspark.testing.sqlutils import ReusedSQLTestCase
+def my_test_function_1():
+ return 1
+
+
class StreamingTestsForeachBatchMixin:
def test_streaming_foreachBatch(self):
q = None
@@ -88,6 +92,111 @@ class StreamingTestsForeachBatchMixin:
q.stop()
self.assertIsNone(q.exception(), "No exception has to be propagated.")
+ def test_streaming_foreachBatch_spark_session(self):
+ table_name = "testTable_foreachBatch"
+
+ def func(df: DataFrame, batch_id: int):
+ if batch_id > 0: # only process once
+ return
+ spark = df.sparkSession
+ df1 = spark.createDataFrame([("structured",), ("streaming",)])
+ df1.union(df).write.mode("append").saveAsTable(table_name)
+
+ df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+ q = df.writeStream.foreachBatch(func).start()
+ q.processAllAvailable()
+ q.stop()
+
+ actual = self.spark.read.table(table_name)
+ df = (
+ self.spark.read.format("text")
+ .load(path="python/test_support/sql/streaming/")
+ .union(self.spark.createDataFrame([("structured",), ("streaming",)]))
+ )
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ def test_streaming_foreachBatch_path_access(self):
+ table_name = "testTable_foreachBatch_path"
+
+ def func(df: DataFrame, batch_id: int):
+ if batch_id > 0: # only process once
+ return
+ spark = df.sparkSession
+ df1 = spark.read.format("text").load("python/test_support/sql/streaming")
+ df1.union(df).write.mode("append").saveAsTable(table_name)
+
+ df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+ q = df.writeStream.foreachBatch(func).start()
+ q.processAllAvailable()
+ q.stop()
+
+ actual = self.spark.read.table(table_name)
+ df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/")
+ df = df.union(df)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ # write to delta table?
+
+ @staticmethod
+ def my_test_function_2():
+ return 2
+
+ def test_streaming_foreachBatch_fuction_calling(self):
+ def my_test_function_3():
+ return 3
+
+ table_name = "testTable_foreachBatch_function"
+
+ def func(df: DataFrame, batch_id: int):
+ if batch_id > 0: # only process once
+ return
+ spark = df.sparkSession
+ df1 = spark.createDataFrame(
+ [
+ (my_test_function_1(),),
+ (StreamingTestsForeachBatchMixin.my_test_function_2(),),
+ (my_test_function_3(),),
+ ]
+ )
+ df1.write.mode("append").saveAsTable(table_name)
+
+ df = self.spark.readStream.format("rate").load()
+ q = df.writeStream.foreachBatch(func).start()
+ q.processAllAvailable()
+ q.stop()
+
+ actual = self.spark.read.table(table_name)
+ df = self.spark.createDataFrame(
+ [
+ (my_test_function_1(),),
+ (StreamingTestsForeachBatchMixin.my_test_function_2(),),
+ (my_test_function_3(),),
+ ]
+ )
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ def test_streaming_foreachBatch_import(self):
+ import time # not imported in foreachBatch_worker
+
+ table_name = "testTable_foreachBatch_import"
+
+ def func(df: DataFrame, batch_id: int):
+ if batch_id > 0: # only process once
+ return
+ time.sleep(1)
+ spark = df.sparkSession
+ df1 = spark.read.format("text").load("python/test_support/sql/streaming")
+ df1.write.mode("append").saveAsTable(table_name)
+
+ df = self.spark.readStream.format("rate").load()
+ q = df.writeStream.foreachBatch(func).start()
+ q.processAllAvailable()
+ q.stop()
+
+ actual = self.spark.read.table(table_name)
+ df = self.spark.read.format("text").load("python/test_support/sql/streaming")
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase):
pass
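The new foreachBatch tests all follow the same shape: handle only the first micro-batch, write derived data to a table from inside the callback via df.sparkSession, then verify the table contents from the driver. A condensed sketch of that shape (the table name is illustrative):

    from pyspark.sql import SparkSession
    from pyspark.sql.dataframe import DataFrame

    spark = SparkSession.builder.getOrCreate()
    table_name = "foreachBatch_sketch"  # illustrative

    def func(df: DataFrame, batch_id: int):
        if batch_id > 0:  # only process the first micro-batch
            return
        # df.sparkSession is usable inside the callback; under Spark Connect it is
        # a remote session created by the foreachBatch worker.
        rows = df.sparkSession.createDataFrame([("structured",), ("streaming",)])
        rows.write.mode("append").saveAsTable(table_name)

    q = spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
    q.processAllAvailable()
    q.stop()

    expected = spark.createDataFrame([("structured",), ("streaming",)])
    assert sorted(spark.read.table(table_name).collect()) == sorted(expected.collect())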
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]