This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 50f94120e9c7 [SPARK-55799][PYTHON][TESTS] Add datasource tests with 
simple worker
50f94120e9c7 is described below

commit 50f94120e9c7530003917270f2f30ab742187f46
Author: Tian Gao <[email protected]>
AuthorDate: Wed Mar 4 07:32:58 2026 +0900

    [SPARK-55799][PYTHON][TESTS] Add datasource tests with simple worker
    
    ### What changes were proposed in this pull request?
    
    * Fixed `is_pyspark_module` in worker logging so it can recognize simple 
workers
    * Added a simple worker test suite that runs all data source tests with 
the simple worker.
    * Fixed the test that always expected 2 abort messages, which is not 
guaranteed.
    
    ### Why are the changes needed?
    
    We support the simple worker for data sources, but we have no tests for it. 
Some issues were already fixed in previous PRs, and this is the one that 
enables the tests.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    `test_python_datasource` passed locally.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #54580 from gaogaotiantian/add-simple-worker-datasource-test.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/logger/worker_io.py                 |  9 +++--
 python/pyspark/sql/tests/test_python_datasource.py | 39 +++++++++++++++++-----
 2 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/logger/worker_io.py 
b/python/pyspark/logger/worker_io.py
index 79684b7aca62..7843b43e10ee 100644
--- a/python/pyspark/logger/worker_io.py
+++ b/python/pyspark/logger/worker_io.py
@@ -223,7 +223,11 @@ def context_provider() -> dict[str, str]:
             - class_name: Name of the class that initiated the logging if 
available
     """
 
-    def is_pyspark_module(module_name: str) -> bool:
+    def is_pyspark_module(frame: FrameType) -> bool:
+        module_name = frame.f_globals.get("__name__", "")
+        if module_name == "__main__":
+            if (mod := sys.modules.get("__main__", None)) and mod.__spec__:
+                module_name = mod.__spec__.name
         return module_name.startswith("pyspark.") and ".tests." not in 
module_name
 
     bottom: Optional[FrameType] = None
@@ -236,9 +240,8 @@ def context_provider() -> dict[str, str]:
         if frame:
             while frame.f_back:
                 f_back = frame.f_back
-                module_name = f_back.f_globals.get("__name__", "")
 
-                if is_pyspark_module(module_name):
+                if is_pyspark_module(f_back):
                     if not is_in_pyspark_module:
                         bottom = frame
                         is_in_pyspark_module = True
diff --git a/python/pyspark/sql/tests/test_python_datasource.py 
b/python/pyspark/sql/tests/test_python_datasource.py
index 1bdb7a5395e1..9d90082c654d 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -1237,8 +1237,20 @@ class BasePythonDataSourceTestsMixin:
 
                 logs = self.spark.tvf.python_worker_logs()
 
+                # We could get either 1 or 2 "TestJsonWriter.write: abort 
test" logs because
+                # the operation is time sensitive. When the first partition 
gets aborted,
+                # the executor will cancel the rest of the tasks. Whether we 
are able to get
+                # the second log depends on whether the second partition 
starts before the
+                # cancellation. When we use simple worker, the second log is 
often missing
+                # because the spawn overhead is large.
+                non_abort_logs = logs.select("level", "msg", "context", 
"logger").filter(
+                    "msg != 'TestJsonWriter.write: abort test'"
+                )
+                abort_logs = logs.select("level", "msg", "context", 
"logger").filter(
+                    "msg == 'TestJsonWriter.write: abort test'"
+                )
                 assertDataFrameEqual(
-                    logs.select("level", "msg", "context", "logger"),
+                    non_abort_logs,
                     [
                         Row(
                             level="WARNING",
@@ -1283,14 +1295,6 @@ class BasePythonDataSourceTestsMixin:
                                 "TestJsonWriter.__init__: ['abort', 'path']",
                                 {"class_name": "TestJsonDataSource", 
"func_name": "writer"},
                             ),
-                            (
-                                "TestJsonWriter.write: abort test",
-                                {"class_name": "TestJsonWriter", "func_name": 
"write"},
-                            ),
-                            (
-                                "TestJsonWriter.write: abort test",
-                                {"class_name": "TestJsonWriter", "func_name": 
"write"},
-                            ),
                             (
                                 "TestJsonWriter.abort",
                                 {"class_name": "TestJsonWriter", "func_name": 
"abort"},
@@ -1298,6 +1302,17 @@ class BasePythonDataSourceTestsMixin:
                         ]
                     ],
                 )
+                assertDataFrameEqual(
+                    abort_logs.dropDuplicates(["msg"]),
+                    [
+                        Row(
+                            level="WARNING",
+                            msg="TestJsonWriter.write: abort test",
+                            context={"class_name": "TestJsonWriter", 
"func_name": "write"},
+                            logger="test_datasource_writer",
+                        )
+                    ],
+                )
 
     def test_data_source_perf_profiler(self):
         with self.sql_conf({"spark.sql.pyspark.dataSource.profiler": "perf"}):
@@ -1345,6 +1360,12 @@ class 
PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
     ...
 
 
+class PythonDataSourceTestsWithSimpleWorker(PythonDataSourceTests):
+    @classmethod
+    def conf(self):
+        return super().conf().set("spark.python.use.daemon", "false")
+
+
 if __name__ == "__main__":
     from pyspark.testing import main
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to