This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9881e0a3fec1 [SPARK-47777] fix python streaming data source connect 
test
9881e0a3fec1 is described below

commit 9881e0a3fec1fa931e8f61fec0c0ce679d6d9842
Author: Chaoqin Li <[email protected]>
AuthorDate: Mon Jun 17 09:28:10 2024 +0900

    [SPARK-47777] fix python streaming data source connect test
    
    ### What changes were proposed in this pull request?
    Add back the Python streaming data source Connect test. Previously we 
removed this test because it failed in the pure Python Spark Connect library 
due to missing py4j.
    This is because in python_streaming_source_runner, we imported from 
pyspark.java_gateway instead of pyspark.util, and pyspark.java_gateway 
requires py4j:
    `from pyspark.java_gateway import local_connect_and_auth`
    This was introduced in 
https://github.com/apache/spark/commit/c8c2492041782b9be7f10647191dcd0d5f6a5a8a 
accidentally.
    
    ### Why are the changes needed?
    Add spark connect test for python streaming data source.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    
    Test change.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46906 from chaoqin-li1123/fix_include.
    
    Authored-by: Chaoqin Li <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 dev/sparktestsupport/modules.py                    |  1 +
 .../streaming/python_streaming_source_runner.py    |  3 +-
 .../test_parity_python_streaming_datasource.py     | 39 ++++++++++++++++++++++
 .../sql/tests/test_python_streaming_datasource.py  |  6 +---
 4 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index b97ec34b5382..8c17af559c25 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1048,6 +1048,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_arrow_grouped_map",
         "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map",
         "pyspark.sql.tests.connect.test_parity_python_datasource",
+        "pyspark.sql.tests.connect.test_parity_python_streaming_datasource",
         "pyspark.sql.tests.connect.test_utils",
         "pyspark.sql.tests.connect.client.test_artifact",
         "pyspark.sql.tests.connect.client.test_client",
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py 
b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index bd779df4837b..5292e2f92784 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -21,7 +21,6 @@ import json
 from typing import IO, Iterator, Tuple
 
 from pyspark.accumulators import _accumulatorRegistry
-from pyspark.java_gateway import local_connect_and_auth
 from pyspark.errors import IllegalArgumentException, PySparkAssertionError, 
PySparkRuntimeError
 from pyspark.serializers import (
     read_int,
@@ -37,7 +36,7 @@ from pyspark.sql.types import (
     StructType,
 )
 from pyspark.sql.worker.plan_data_source_read import records_to_arrow_batches
-from pyspark.util import handle_worker_exception
+from pyspark.util import handle_worker_exception, local_connect_and_auth
 from pyspark.worker_util import (
     check_python_version,
     read_command,
diff --git 
a/python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py 
b/python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py
new file mode 100644
index 000000000000..65bb4c021f4d
--- /dev/null
+++ 
b/python/pyspark/sql/tests/connect/test_parity_python_streaming_datasource.py
@@ -0,0 +1,39 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.sql.tests.test_python_streaming_datasource import (
+    BasePythonStreamingDataSourceTestsMixin,
+)
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class PythonStreamingDataSourceParityTests(
+    BasePythonStreamingDataSourceTestsMixin, ReusedConnectTestCase
+):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_parity_python_streaming_datasource 
import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_python_streaming_datasource.py 
b/python/pyspark/sql/tests/test_python_streaming_datasource.py
index e5622e28f15b..183b0ad80d9d 100644
--- a/python/pyspark/sql/tests/test_python_streaming_datasource.py
+++ b/python/pyspark/sql/tests/test_python_streaming_datasource.py
@@ -142,15 +142,11 @@ class BasePythonStreamingDataSourceTestsMixin:
         self.spark.dataSource.register(self._get_test_data_source())
         df = self.spark.readStream.format("TestDataSource").load()
 
-        current_batch_id = -1
-
         def check_batch(df, batch_id):
-            nonlocal current_batch_id
-            current_batch_id = batch_id
             assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 
1)])
 
         q = df.writeStream.foreachBatch(check_batch).start()
-        while current_batch_id < 10:
+        while len(q.recentProgress) < 10:
             time.sleep(0.2)
         q.stop()
         q.awaitTermination()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to