This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 96666d49feb3 [SPARK-49859][CONNECT] Replace multiprocessing.ThreadPool with ThreadPoolExecutor
96666d49feb3 is described below

commit 96666d49feb3d4a6b5a76d05e48e898c0962653c
Author: Nemanja Boric <[email protected]>
AuthorDate: Fri Oct 4 09:18:28 2024 +0900

    [SPARK-49859][CONNECT] Replace multiprocessing.ThreadPool with ThreadPoolExecutor
    
    ### What changes were proposed in this pull request?
    
    We change the reattach executor module to use concurrent.futures.ThreadPoolExecutor instead of multiprocessing.ThreadPool.
    
    ### Why are the changes needed?
    
    multiprocessing.ThreadPool doesn't work in environments where /dev/shm is not writable by the python process.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #48327 from nemanja-boric-databricks/sparkly.
    
    Authored-by: Nemanja Boric <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/client/reattach.py          | 18 ++++++++----------
 python/pyspark/sql/tests/connect/client/test_client.py |  4 ++--
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py
index ea6788e85831..e0c7cc448933 100644
--- a/python/pyspark/sql/connect/client/reattach.py
+++ b/python/pyspark/sql/connect/client/reattach.py
@@ -24,7 +24,7 @@ import warnings
 import uuid
 from collections.abc import Generator
 from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar
-from multiprocessing.pool import ThreadPool
+from concurrent.futures import ThreadPoolExecutor
 import os
 
 import grpc
@@ -58,19 +58,18 @@ class ExecutePlanResponseReattachableIterator(Generator):
 
     # Lock to manage the pool
     _lock: ClassVar[RLock] = RLock()
-    _release_thread_pool_instance: Optional[ThreadPool] = None
+    _release_thread_pool_instance: Optional[ThreadPoolExecutor] = None
 
     @classmethod  # type: ignore[misc]
     @property
-    def _release_thread_pool(cls) -> ThreadPool:
+    def _release_thread_pool(cls) -> ThreadPoolExecutor:
         # Perform a first check outside the critical path.
         if cls._release_thread_pool_instance is not None:
             return cls._release_thread_pool_instance
         with cls._lock:
             if cls._release_thread_pool_instance is None:
-                cls._release_thread_pool_instance = ThreadPool(
-                    os.cpu_count() if os.cpu_count() else 8
-                )
+                max_workers = os.cpu_count() or 8
+                cls._release_thread_pool_instance = ThreadPoolExecutor(max_workers=max_workers)
             return cls._release_thread_pool_instance
 
     @classmethod
@@ -81,8 +80,7 @@ class ExecutePlanResponseReattachableIterator(Generator):
         """
         with cls._lock:
             if cls._release_thread_pool_instance is not None:
-                cls._release_thread_pool.close()  # type: ignore[attr-defined]
-                cls._release_thread_pool.join()  # type: ignore[attr-defined]
+                cls._release_thread_pool.shutdown()  # type: ignore[attr-defined]
                 cls._release_thread_pool_instance = None
 
     def __init__(
@@ -212,7 +210,7 @@ class ExecutePlanResponseReattachableIterator(Generator):
 
         with self._lock:
             if self._release_thread_pool_instance is not None:
-                self._release_thread_pool.apply_async(target)
+                self._release_thread_pool.submit(target)
 
     def _release_all(self) -> None:
         """
@@ -237,7 +235,7 @@ class ExecutePlanResponseReattachableIterator(Generator):
 
         with self._lock:
             if self._release_thread_pool_instance is not None:
-                self._release_thread_pool.apply_async(target)
+                self._release_thread_pool.submit(target)
         self._result_complete = True
 
     def _call_iter(self, iter_fun: Callable) -> Any:
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 5deb73a0ccf9..741d6b9c1104 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -408,8 +408,8 @@ class SparkConnectClientReattachTestCase(unittest.TestCase):
             def checks():
                 self.assertEqual(1, stub.execute_calls)
                 self.assertEqual(1, stub.attach_calls)
-                self.assertEqual(0, stub.release_calls)
-                self.assertEqual(0, stub.release_until_calls)
+                self.assertEqual(1, stub.release_calls)
+                self.assertEqual(1, stub.release_until_calls)
 
             eventually(timeout=1, catch_assertions=True)(checks)()
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to