This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 96666d49feb3 [SPARK-49859][CONNECT] Replace multiprocessing.ThreadPool
with ThreadPoolExecutor
96666d49feb3 is described below
commit 96666d49feb3d4a6b5a76d05e48e898c0962653c
Author: Nemanja Boric <[email protected]>
AuthorDate: Fri Oct 4 09:18:28 2024 +0900
[SPARK-49859][CONNECT] Replace multiprocessing.ThreadPool with
ThreadPoolExecutor
### What changes were proposed in this pull request?
We change the reattachexecutor module to use
concurrent.futures.ThreadPoolExecutor instead of multiprocessing.ThreadPool.
### Why are the changes needed?
multiprocessing.ThreadPool doesn't work in environments where /dev/shm is
not writable by the python process.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48327 from nemanja-boric-databricks/sparkly.
Authored-by: Nemanja Boric <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/client/reattach.py | 18 ++++++++----------
python/pyspark/sql/tests/connect/client/test_client.py | 4 ++--
2 files changed, 10 insertions(+), 12 deletions(-)
diff --git a/python/pyspark/sql/connect/client/reattach.py
b/python/pyspark/sql/connect/client/reattach.py
index ea6788e85831..e0c7cc448933 100644
--- a/python/pyspark/sql/connect/client/reattach.py
+++ b/python/pyspark/sql/connect/client/reattach.py
@@ -24,7 +24,7 @@ import warnings
import uuid
from collections.abc import Generator
from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast,
Type, ClassVar
-from multiprocessing.pool import ThreadPool
+from concurrent.futures import ThreadPoolExecutor
import os
import grpc
@@ -58,19 +58,18 @@ class ExecutePlanResponseReattachableIterator(Generator):
# Lock to manage the pool
_lock: ClassVar[RLock] = RLock()
- _release_thread_pool_instance: Optional[ThreadPool] = None
+ _release_thread_pool_instance: Optional[ThreadPoolExecutor] = None
@classmethod # type: ignore[misc]
@property
- def _release_thread_pool(cls) -> ThreadPool:
+ def _release_thread_pool(cls) -> ThreadPoolExecutor:
# Perform a first check outside the critical path.
if cls._release_thread_pool_instance is not None:
return cls._release_thread_pool_instance
with cls._lock:
if cls._release_thread_pool_instance is None:
- cls._release_thread_pool_instance = ThreadPool(
- os.cpu_count() if os.cpu_count() else 8
- )
+ max_workers = os.cpu_count() or 8
+                cls._release_thread_pool_instance = ThreadPoolExecutor(max_workers=max_workers)
return cls._release_thread_pool_instance
@classmethod
@@ -81,8 +80,7 @@ class ExecutePlanResponseReattachableIterator(Generator):
"""
with cls._lock:
if cls._release_thread_pool_instance is not None:
- cls._release_thread_pool.close() # type: ignore[attr-defined]
- cls._release_thread_pool.join() # type: ignore[attr-defined]
+                cls._release_thread_pool.shutdown()  # type: ignore[attr-defined]
cls._release_thread_pool_instance = None
def __init__(
@@ -212,7 +210,7 @@ class ExecutePlanResponseReattachableIterator(Generator):
with self._lock:
if self._release_thread_pool_instance is not None:
- self._release_thread_pool.apply_async(target)
+ self._release_thread_pool.submit(target)
def _release_all(self) -> None:
"""
@@ -237,7 +235,7 @@ class ExecutePlanResponseReattachableIterator(Generator):
with self._lock:
if self._release_thread_pool_instance is not None:
- self._release_thread_pool.apply_async(target)
+ self._release_thread_pool.submit(target)
self._result_complete = True
def _call_iter(self, iter_fun: Callable) -> Any:
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py
b/python/pyspark/sql/tests/connect/client/test_client.py
index 5deb73a0ccf9..741d6b9c1104 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -408,8 +408,8 @@ class SparkConnectClientReattachTestCase(unittest.TestCase):
def checks():
self.assertEqual(1, stub.execute_calls)
self.assertEqual(1, stub.attach_calls)
- self.assertEqual(0, stub.release_calls)
- self.assertEqual(0, stub.release_until_calls)
+ self.assertEqual(1, stub.release_calls)
+ self.assertEqual(1, stub.release_until_calls)
eventually(timeout=1, catch_assertions=True)(checks)()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]