This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e17d2735b44c [SPARK-55020][PYTHON] Disable gc when executing gRPC command
e17d2735b44c is described below
commit e17d2735b44c4ebb50a2ac5606cdc969c5658bcf
Author: Tian Gao <[email protected]>
AuthorDate: Wed Jan 28 09:01:25 2026 +0800
[SPARK-55020][PYTHON] Disable gc when executing gRPC command
### What changes were proposed in this pull request?
Disable gc while sending a gRPC command so that `__del__` of Java object
references won't be triggered.
### Why are the changes needed?
We have a very flaky test, `test_distributed_lda`:
[Failure1](https://github.com/gaogaotiantian/spark/actions/runs/20940523104/job/60173040026),
[Failure2](https://github.com/apache/spark/actions/runs/20938424034/job/60166745846),
[Failure3](https://github.com/apache/spark/actions/runs/20908277105/job/60066145491).
It is a deadlock, with the following traceback:
```
(Python) File "/__w/spark/spark/python/pyspark/ml/tests/connect/test_parity_clustering.py", line 30, in <module>
    main()
(Python) File "/__w/spark/spark/python/pyspark/testing/unittestutils.py", line 43, in main
    unittest.main(module=module, testRunner=testRunner, verbosity=2)
(Python) File "/usr/lib/python3.11/unittest/main.py", line 102, in __init__
    self.runTests()
(Python) File "/usr/lib/python3.11/unittest/main.py", line 274, in runTests
    self.result = testRunner.run(self.test)
(Python) File "/usr/local/lib/python3.11/dist-packages/xmlrunner/runner.py", line 67, in run
    test(result)
(Python) File "/usr/lib/python3.11/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
(Python) File "/usr/lib/python3.11/unittest/suite.py", line 122, in run
    test(result)
(Python) File "/usr/lib/python3.11/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
(Python) File "/usr/lib/python3.11/unittest/suite.py", line 122, in run
    test(result)
(Python) File "/usr/lib/python3.11/unittest/case.py", line 678, in __call__
    return self.run(*args, **kwds)
(Python) File "/usr/lib/python3.11/unittest/case.py", line 623, in run
    self._callTestMethod(testMethod)
(Python) File "/usr/lib/python3.11/unittest/case.py", line 579, in _callTestMethod
    if method() is not None:
(Python) File "/__w/spark/spark/python/pyspark/ml/tests/test_clustering.py", line 466, in test_distributed_lda
    self.assertEqual(str(model), str(model2))
(Python) File "/__w/spark/spark/python/pyspark/ml/wrapper.py", line 474, in __repr__
    return self._call_java("toString")
(Python) File "/__w/spark/spark/python/pyspark/ml/util.py", line 322, in wrapped
    return remote_call()
(Python) File "/__w/spark/spark/python/pyspark/ml/util.py", line 308, in remote_call
    (_, properties, _) = session.client.execute_command(command)
(Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 1162, in execute_command
    data, _, metrics, observed_metrics, properties = self._execute_and_fetch(
(Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 1664, in _execute_and_fetch
    for response in self._execute_and_fetch_as_iterator(
(Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 1621, in _execute_and_fetch_as_iterator
    generator = ExecutePlanResponseReattachableIterator(
(Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/reattach.py", line 127, in __init__
    self._stub.ExecutePlan(self._initial_request, metadata=metadata)
(Python) File "/usr/local/lib/python3.11/dist-packages/grpc/_channel.py", line 1396, in __call__
    call = self._managed_call(
(Python) File "/usr/local/lib/python3.11/dist-packages/grpc/_channel.py", line 1785, in create
    call = state.channel.integrated_call(
(Python) File "/usr/lib/python3.11/threading.py", line 905, in __init__
    self._started = Event()
(Python) File "/usr/lib/python3.11/threading.py", line 563, in __init__
    self._cond = Condition(Lock())
(Python) File "/usr/lib/python3.11/threading.py", line 254, in __init__
    self._release_save = lock._release_save
(Python) File "/__w/spark/spark/python/pyspark/ml/util.py", line 379, in wrapped
    self._remote_model_obj.release_ref()
(Python) File "/__w/spark/spark/python/pyspark/ml/util.py", line 162, in release_ref
    del_remote_cache(self.ref_id)
(Python) File "/__w/spark/spark/python/pyspark/ml/util.py", line 358, in del_remote_cache
    session.client._delete_ml_cache([ref_id])
(Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 2137, in _delete_ml_cache
    (_, properties, _) = self.execute_command(command)
(Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 1162, in execute_command
    data, _, metrics, observed_metrics, properties = self._execute_and_fetch(
(Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 1664, in _execute_and_fetch
    for response in self._execute_and_fetch_as_iterator(
(Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 1621, in _execute_and_fetch_as_iterator
    generator = ExecutePlanResponseReattachableIterator(
(Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/reattach.py", line 127, in __init__
    self._stub.ExecutePlan(self._initial_request, metadata=metadata)
(Python) File "/usr/local/lib/python3.11/dist-packages/grpc/_channel.py", line 1396, in __call__
    call = self._managed_call(
(Python) File "/usr/local/lib/python3.11/dist-packages/grpc/_channel.py", line 1784, in create
    with state.lock:
```
The deadlock happens in the `create` function: it takes an exclusive lock to
create a gRPC call, and while that lock is held, another gRPC call is created
and tries to take the same lock, which is not reentrant.
The second call is triggered by garbage collection: the `__del__` method of
`JavaWrapper` tries to release the remote cache by sending a gRPC command to
the server.
That design is fundamentally wrong and there is no easy fix. You simply cannot
send a gRPC command from `__del__`, because that method can run at any
arbitrary point (gc can happen at any time).
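To make the hazard concrete, here is a minimal, self-contained sketch (all
names are hypothetical, not Spark or grpc code): a `__del__` that needs a
non-reentrant lock can fire, via gc, on a thread that already holds that lock.
```
import gc
import threading

lock = threading.Lock()  # non-reentrant, like the lock in grpc's channel state

class CacheRef:
    def __del__(self):
        # gc may run this at any allocation point; if the current thread
        # already holds `lock`, a blocking acquire here would never return.
        if lock.acquire(blocking=False):
            lock.release()
        else:
            print("would deadlock: __del__ ran while the lock was held")

# Put the object in a reference cycle so only the cyclic gc (not plain
# reference counting) frees it, mimicking arbitrary collection timing.
obj = CacheRef()
obj.self_ref = obj
del obj

with lock:        # stand-in for grpc's `create`, which holds an exclusive lock
    gc.collect()  # stand-in for a gc cycle happening at exactly this point
```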
I can think of three options:
~~1. Like this PR originally did: just do not clear the cache. This is the
simplest way to solve the issue; the Connect client will send a command to
clear the cache when it is deleted.~~
2. Disable gc while we issue the command. Not the best solution, but it might
just work. It plays with Python's internal mechanisms, but the code stays clean
and we get the desired behavior.
~~3. Do something like `execute_command_later` and queue the commands to run in
the next synchronous call to `execute_command` (see the sketch after this
list). This needs a bit of refactoring of the existing framework and the
command would be delayed, but code-wise it might be the most accurate.~~
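For context, a rough, hypothetical sketch of what option 3 could have looked
like; `DeferredCommandQueue`, `execute_command_later`, and `flush_pending` are
invented names, and the patch does not implement any of this:
```
from collections import deque
from typing import Any, Callable, Deque

class DeferredCommandQueue:
    """Hypothetical helper: commands issued from __del__ are queued and
    flushed from the next regular synchronous call instead."""

    def __init__(self) -> None:
        # deque.append/popleft are atomic, so __del__ can enqueue safely
        # without taking any lock.
        self._pending: Deque[Callable[[], Any]] = deque()

    def execute_command_later(self, command: Callable[[], Any]) -> None:
        # Safe to call from __del__: only records the command, sends nothing.
        self._pending.append(command)

    def flush_pending(self) -> None:
        # Called at the top of the next synchronous execute_command, on a
        # thread that holds no gRPC-internal locks.
        while True:
            try:
                command = self._pending.popleft()
            except IndexError:
                break
            command()
```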
Note that this problem does not only affect this specific test, or even our
test suite in general: it is a real deadlock that could hit our users, so we
should backport the final fix.
For now, we need a quick decision on how to mitigate the flaky test, because
it is interrupting our workflow; a single flaky test makes the whole workflow
fragile.
### Does this PR introduce _any_ user-facing change?
Not really; users should not notice that the cache is no longer cleared eagerly.
### How was this patch tested?
The deadlock is not easily reproducible, so we need to rely on CI.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53783 from gaogaotiantian/fix-ml-util-racing.
Authored-by: Tian Gao <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/ml/tests/test_clustering.py | 2 --
python/pyspark/sql/connect/client/core.py | 4 +++-
python/pyspark/util.py | 22 +++++++++++++++++++++-
3 files changed, 24 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index 8bd021903fba..e22e97a5e7f1 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -16,7 +16,6 @@
#
import tempfile
-import unittest
import numpy as np
@@ -384,7 +383,6 @@ class ClusteringTestsMixin:
model2 = LocalLDAModel.load(d)
self.assertEqual(str(model), str(model2))
- @unittest.skip("SPARK-55020: Test triggers frequent deadlock in CI")
def test_distributed_lda(self):
spark = self.spark
df = (
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 58ae6a0eea98..69447503da6c 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -65,7 +65,7 @@ import grpc
from google.protobuf import text_format, any_pb2
from google.rpc import error_details_pb2
-from pyspark.util import is_remote_only
+from pyspark.util import is_remote_only, disable_gc
from pyspark.accumulators import SpecialAccumulatorIds
from pyspark.version import __version__
from pyspark.traceback_utils import CallSite
@@ -1457,6 +1457,7 @@ class SparkConnectClient(object):
except Exception as error:
self._handle_error(error)
+ @disable_gc
def _execute(self, req: pb2.ExecutePlanRequest) -> None:
"""
Execute the passed request `req` and drop all results.
@@ -1494,6 +1495,7 @@ class SparkConnectClient(object):
except Exception as error:
self._handle_error(error)
+ @disable_gc
def _execute_and_fetch_as_iterator(
self,
req: pb2.ExecutePlanRequest,
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 9f5fad57cb14..e181339e4896 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -19,6 +19,7 @@
import copy
import functools
import faulthandler
+import gc
import itertools
import os
import platform
@@ -31,7 +32,7 @@ import socket
import warnings
from contextlib import contextmanager
from types import TracebackType
-from typing import Any, Callable, IO, Iterator, List, Optional, TextIO, Tuple, Union
+from typing import Any, Callable, IO, Iterator, List, Optional, TextIO, Tuple, TypeVar, Union, cast
from pyspark.errors import PySparkRuntimeError
from pyspark.serializers import (
@@ -96,6 +97,8 @@ JVM_INT_MAX: int = (1 << 31) - 1
JVM_LONG_MIN: int = -(1 << 63)
JVM_LONG_MAX: int = (1 << 63) - 1
+FuncT = TypeVar("FuncT", bound=Callable[..., Any])
+
def print_exec(stream: TextIO) -> None:
ei = sys.exc_info()
@@ -857,6 +860,23 @@ def _do_server_auth(conn: "io.IOBase", auth_secret: str) -> None:
)
+def disable_gc(f: FuncT) -> FuncT:
+ """Mark the function that should disable gc during execution"""
+
+ @functools.wraps(f)
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
+ gc_enabled_originally = gc.isenabled()
+ if gc_enabled_originally:
+ gc.disable()
+ try:
+ return f(*args, **kwargs)
+ finally:
+ if gc_enabled_originally:
+ gc.enable()
+
+ return cast(FuncT, wrapped)
+
+
_is_remote_only = None
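As an illustrative sanity check of the new decorator (not part of the patch),
the following shows the intended behavior of `disable_gc` from
`python/pyspark/util.py`, assuming gc is enabled to begin with:
```
import gc
from pyspark.util import disable_gc

@disable_gc
def critical_section() -> bool:
    # Cyclic gc cannot run here, so no __del__ fires from a collection while
    # we are inside (refcount-driven __del__ can still fire, though).
    return gc.isenabled()

assert gc.isenabled()               # gc is on before the call,
assert critical_section() is False  # off inside the decorated function,
assert gc.isenabled()               # and restored afterwards.
```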
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]