This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new 6c2da61b386 Revert "[SPARK-44435][SS][CONNECT] Tests for foreachBatch and Listener"
6c2da61b386 is described below
commit 6c2da61b386d905d05437e68a4b945b5ee9a3e90
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Thu Aug 24 10:35:20 2023 -0700
Revert "[SPARK-44435][SS][CONNECT] Tests for foreachBatch and Listener"
This reverts commit 311a497224db4a00b5cdf928cba8ef30545ee911.
---
.../planner/StreamingForeachBatchHelper.scala | 1 +
.../streaming/worker/foreachBatch_worker.py | 1 +
.../connect/streaming/worker/listener_worker.py | 1 +
.../connect/streaming/test_parity_listener.py | 57 +++++------
.../tests/streaming/test_streaming_foreachBatch.py | 111 +--------------------
5 files changed, 27 insertions(+), 144 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index ef7195439f9..21e4adb9896 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -113,6 +113,7 @@ object StreamingForeachBatchHelper extends Logging {
val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
+ // TODO(SPARK-44460): Support Auth credentials
// TODO(SPARK-44462): A new session id pointing to args.df.sparkSession needs to be created.
// This is because MicroBatch execution clones the session during start.
// The session attached to the foreachBatch dataframe is different from the one the one
diff --git a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
index 72037f1263d..cf61463cd68 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
@@ -51,6 +51,7 @@ def main(infile: IO, outfile: IO) -> None:
spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
+ # TODO(SPARK-44460): Pass credentials.
# TODO(SPARK-44461): Enable Process Isolation
func = worker.read_command(pickle_ser, infile)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index c026945767d..e1f4678e42f 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -59,6 +59,7 @@ def main(infile: IO, outfile: IO) -> None:
spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
+ # TODO(SPARK-44460): Pass credentials.
# TODO(SPARK-44461): Enable Process Isolation
listener = worker.read_command(pickle_ser, infile)
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 5069a76cfdb..4bf58bf7807 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -18,31 +18,39 @@
import unittest
import time
-import pyspark.cloudpickle
from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
-from pyspark.sql.streaming.listener import StreamingQueryListener
-from pyspark.sql.functions import count, lit
+from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent
+from pyspark.sql.types import StructType, StructField, StringType
from pyspark.testing.connectutils import ReusedConnectTestCase
+def get_start_event_schema():
+ return StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("runId", StringType(), True),
+ StructField("name", StringType(), True),
+ StructField("timestamp", StringType(), True),
+ ]
+ )
+
+
class TestListener(StreamingQueryListener):
def onQueryStarted(self, event):
- e = pyspark.cloudpickle.dumps(event)
- df = self.spark.createDataFrame(data=[(e,)])
- df.write.mode("append").saveAsTable("listener_start_events")
+ df = self.spark.createDataFrame(
+ data=[(str(event.id), str(event.runId), event.name, event.timestamp)],
+ schema=get_start_event_schema(),
+ )
+ df.write.saveAsTable("listener_start_events")
def onQueryProgress(self, event):
- e = pyspark.cloudpickle.dumps(event)
- df = self.spark.createDataFrame(data=[(e,)])
- df.write.mode("append").saveAsTable("listener_progress_events")
+ pass
def onQueryIdle(self, event):
pass
def onQueryTerminated(self, event):
- e = pyspark.cloudpickle.dumps(event)
- df = self.spark.createDataFrame(data=[(e,)])
- df.write.mode("append").saveAsTable("listener_terminated_events")
+ pass
class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
@@ -57,36 +65,17 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
time.sleep(30)
df = self.spark.readStream.format("rate").option("rowsPerSecond",
10).load()
- df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
- df_stateful = df_observe.groupBy().count() # make query stateful
- q = (
- df_stateful.writeStream.format("noop")
- .queryName("test")
- .outputMode("complete")
- .start()
- )
+ q = df.writeStream.format("noop").queryName("test").start()
self.assertTrue(q.isActive)
time.sleep(10)
- self.assertTrue(q.lastProgress["batchId"] > 0) # ensure at least
one batch is ran
q.stop()
- self.assertFalse(q.isActive)
-
- start_event = pyspark.cloudpickle.loads(
- self.spark.read.table("listener_start_events").collect()[0][0]
- )
-
- progress_event = pyspark.cloudpickle.loads(
- self.spark.read.table("listener_progress_events").collect()[0][0]
- )
- terminated_event = pyspark.cloudpickle.loads(
- self.spark.read.table("listener_terminated_events").collect()[0][0]
+ start_event = QueryStartedEvent.fromJson(
+ self.spark.read.table("listener_start_events").collect()[0].asDict()
)
self.check_start_event(start_event)
- self.check_progress_event(progress_event)
- self.check_terminated_event(terminated_event)
finally:
self.spark.streams.removeListener(test_listener)
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
index 65a0f6279fb..d4e185c3d85 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
@@ -16,12 +16,8 @@
#
import time
-from pyspark.sql.dataframe import DataFrame
-from pyspark.testing.sqlutils import ReusedSQLTestCase
-
-def my_test_function_1():
- return 1
+from pyspark.testing.sqlutils import ReusedSQLTestCase
class StreamingTestsForeachBatchMixin:
@@ -92,111 +88,6 @@ class StreamingTestsForeachBatchMixin:
q.stop()
self.assertIsNone(q.exception(), "No exception has to be propagated.")
- def test_streaming_foreachBatch_spark_session(self):
- table_name = "testTable_foreachBatch"
-
- def func(df: DataFrame, batch_id: int):
- if batch_id > 0: # only process once
- return
- spark = df.sparkSession
- df1 = spark.createDataFrame([("structured",), ("streaming",)])
- df1.union(df).write.mode("append").saveAsTable(table_name)
-
- df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
- q = df.writeStream.foreachBatch(func).start()
- q.processAllAvailable()
- q.stop()
-
- actual = self.spark.read.table(table_name)
- df = (
- self.spark.read.format("text")
- .load(path="python/test_support/sql/streaming/")
- .union(self.spark.createDataFrame([("structured",),
("streaming",)]))
- )
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- def test_streaming_foreachBatch_path_access(self):
- table_name = "testTable_foreachBatch_path"
-
- def func(df: DataFrame, batch_id: int):
- if batch_id > 0: # only process once
- return
- spark = df.sparkSession
- df1 = spark.read.format("text").load("python/test_support/sql/streaming")
- df1.union(df).write.mode("append").saveAsTable(table_name)
-
- df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
- q = df.writeStream.foreachBatch(func).start()
- q.processAllAvailable()
- q.stop()
-
- actual = self.spark.read.table(table_name)
- df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/")
- df = df.union(df)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- # write to delta table?
-
- @staticmethod
- def my_test_function_2():
- return 2
-
- def test_streaming_foreachBatch_fuction_calling(self):
- def my_test_function_3():
- return 3
-
- table_name = "testTable_foreachBatch_function"
-
- def func(df: DataFrame, batch_id: int):
- if batch_id > 0: # only process once
- return
- spark = df.sparkSession
- df1 = spark.createDataFrame(
- [
- (my_test_function_1(),),
- (StreamingTestsForeachBatchMixin.my_test_function_2(),),
- (my_test_function_3(),),
- ]
- )
- df1.write.mode("append").saveAsTable(table_name)
-
- df = self.spark.readStream.format("rate").load()
- q = df.writeStream.foreachBatch(func).start()
- q.processAllAvailable()
- q.stop()
-
- actual = self.spark.read.table(table_name)
- df = self.spark.createDataFrame(
- [
- (my_test_function_1(),),
- (StreamingTestsForeachBatchMixin.my_test_function_2(),),
- (my_test_function_3(),),
- ]
- )
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- def test_streaming_foreachBatch_import(self):
- import time # not imported in foreachBatch_worker
-
- table_name = "testTable_foreachBatch_import"
-
- def func(df: DataFrame, batch_id: int):
- if batch_id > 0: # only process once
- return
- time.sleep(1)
- spark = df.sparkSession
- df1 = spark.read.format("text").load("python/test_support/sql/streaming")
- df1.write.mode("append").saveAsTable(table_name)
-
- df = self.spark.readStream.format("rate").load()
- q = df.writeStream.foreachBatch(func).start()
- q.processAllAvailable()
- q.stop()
-
- actual = self.spark.read.table(table_name)
- df = self.spark.read.format("text").load("python/test_support/sql/streaming")
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase):
pass
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]