This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e17d2735b44c [SPARK-55020][PYTHON] Disable gc when executing gRPC command
e17d2735b44c is described below

commit e17d2735b44c4ebb50a2ac5606cdc969c5658bcf
Author: Tian Gao <[email protected]>
AuthorDate: Wed Jan 28 09:01:25 2026 +0800

    [SPARK-55020][PYTHON] Disable gc when executing gRPC command
    
    ### What changes were proposed in this pull request?
    
    Disable gc while sending gRPC commands so that `__del__` of Java object references won't trigger.
    
    ### Why are the changes needed?
    
    We have a very flaky test, `test_distributed_lda`:
    [Failure1](https://github.com/gaogaotiantian/spark/actions/runs/20940523104/job/60173040026),
    [Failure2](https://github.com/apache/spark/actions/runs/20938424034/job/60166745846),
    [Failure3](https://github.com/apache/spark/actions/runs/20908277105/job/60066145491).
    It is a deadlock, with the following traceback:
    
    ```
        (Python) File "/__w/spark/spark/python/pyspark/ml/tests/connect/test_parity_clustering.py", line 30, in <module>
            main()
        (Python) File "/__w/spark/spark/python/pyspark/testing/unittestutils.py", line 43, in main
            unittest.main(module=module, testRunner=testRunner, verbosity=2)
        (Python) File "/usr/lib/python3.11/unittest/main.py", line 102, in __init__
            self.runTests()
        (Python) File "/usr/lib/python3.11/unittest/main.py", line 274, in runTests
            self.result = testRunner.run(self.test)
        (Python) File "/usr/local/lib/python3.11/dist-packages/xmlrunner/runner.py", line 67, in run
            test(result)
        (Python) File "/usr/lib/python3.11/unittest/suite.py", line 84, in __call__
            return self.run(*args, **kwds)
        (Python) File "/usr/lib/python3.11/unittest/suite.py", line 122, in run
            test(result)
        (Python) File "/usr/lib/python3.11/unittest/suite.py", line 84, in __call__
            return self.run(*args, **kwds)
        (Python) File "/usr/lib/python3.11/unittest/suite.py", line 122, in run
            test(result)
        (Python) File "/usr/lib/python3.11/unittest/case.py", line 678, in __call__
            return self.run(*args, **kwds)
        (Python) File "/usr/lib/python3.11/unittest/case.py", line 623, in run
            self._callTestMethod(testMethod)
        (Python) File "/usr/lib/python3.11/unittest/case.py", line 579, in _callTestMethod
            if method() is not None:
        (Python) File "/__w/spark/spark/python/pyspark/ml/tests/test_clustering.py", line 466, in test_distributed_lda
            self.assertEqual(str(model), str(model2))
        (Python) File "/__w/spark/spark/python/pyspark/ml/wrapper.py", line 474, in __repr__
            return self._call_java("toString")
        (Python) File "/__w/spark/spark/python/pyspark/ml/util.py", line 322, in wrapped
            return remote_call()
        (Python) File "/__w/spark/spark/python/pyspark/ml/util.py", line 308, in remote_call
            (_, properties, _) = session.client.execute_command(command)
        (Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 1162, in execute_command
            data, _, metrics, observed_metrics, properties = self._execute_and_fetch(
        (Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 1664, in _execute_and_fetch
            for response in self._execute_and_fetch_as_iterator(
        (Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 1621, in _execute_and_fetch_as_iterator
            generator = ExecutePlanResponseReattachableIterator(
        (Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/reattach.py", line 127, in __init__
            self._stub.ExecutePlan(self._initial_request, metadata=metadata)
        (Python) File "/usr/local/lib/python3.11/dist-packages/grpc/_channel.py", line 1396, in __call__
            call = self._managed_call(
        (Python) File "/usr/local/lib/python3.11/dist-packages/grpc/_channel.py", line 1785, in create
            call = state.channel.integrated_call(
        (Python) File "/usr/lib/python3.11/threading.py", line 905, in __init__
            self._started = Event()
        (Python) File "/usr/lib/python3.11/threading.py", line 563, in __init__
            self._cond = Condition(Lock())
        (Python) File "/usr/lib/python3.11/threading.py", line 254, in __init__
            self._release_save = lock._release_save
        (Python) File "/__w/spark/spark/python/pyspark/ml/util.py", line 379, in wrapped
            self._remote_model_obj.release_ref()
        (Python) File "/__w/spark/spark/python/pyspark/ml/util.py", line 162, in release_ref
            del_remote_cache(self.ref_id)
        (Python) File "/__w/spark/spark/python/pyspark/ml/util.py", line 358, in del_remote_cache
            session.client._delete_ml_cache([ref_id])
        (Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 2137, in _delete_ml_cache
            (_, properties, _) = self.execute_command(command)
        (Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 1162, in execute_command
            data, _, metrics, observed_metrics, properties = self._execute_and_fetch(
        (Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 1664, in _execute_and_fetch
            for response in self._execute_and_fetch_as_iterator(
        (Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/core.py", line 1621, in _execute_and_fetch_as_iterator
            generator = ExecutePlanResponseReattachableIterator(
        (Python) File "/__w/spark/spark/python/pyspark/sql/connect/client/reattach.py", line 127, in __init__
            self._stub.ExecutePlan(self._initial_request, metadata=metadata)
        (Python) File "/usr/local/lib/python3.11/dist-packages/grpc/_channel.py", line 1396, in __call__
            call = self._managed_call(
        (Python) File "/usr/local/lib/python3.11/dist-packages/grpc/_channel.py", line 1784, in create
            with state.lock:
    ```
    
    The deadlock happens in the `create` function: it acquires an exclusive lock to create a gRPC call, and then a second gRPC call is created in the middle of that process and tries to acquire the same exclusive lock (which is not re-entrant).
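    A `threading.Lock` is not re-entrant: as soon as the thread holding it tries to acquire it again, it blocks forever. A minimal sketch of just that behavior (plain Python, no Spark or gRPC involved):
    
    ```
    import threading
    
    lock = threading.Lock()  # non-reentrant, unlike threading.RLock
    
    lock.acquire()
    lock.acquire()  # blocks forever: the thread waits on a lock it already holds
    ```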
    
    This is triggered by garbage collection: the `__del__` method of `JavaWrapper` tries to release the cache by sending a gRPC command to the remote end.
    
    The design is fundamentally wrong and there is no easy fix. You simply can't send a gRPC command in `__del__`, because that method can be triggered at any arbitrary point (gc can happen at any time).
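    To make the failure mode concrete, here is a self-contained sketch of the same pattern. `CacheRef` and the module-level `lock` are hypothetical stand-ins for `JavaWrapper` and gRPC's internal channel lock, not real Spark or grpc names:
    
    ```
    import gc
    import threading
    
    lock = threading.Lock()  # stand-in for grpc's non-reentrant internal lock
    
    class CacheRef:
        # Stand-in for an object whose __del__ issues an RPC.
        def __del__(self):
            with lock:  # needs the lock the in-flight call already holds
                pass    # never reached
    
    # Put the object in a reference cycle so that it is freed by the cyclic
    # garbage collector rather than by reference counting.
    ref = CacheRef()
    ref.cycle = ref
    del ref
    
    with lock:
        # Any allocation inside the critical section may trigger a collection;
        # force one here to make the hang deterministic.
        gc.collect()  # runs CacheRef.__del__, which blocks forever on `lock`
    ```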
    
    I can think of three options:
    
    ~~1. Like this, just do not clear the cache. The simplest way to solve the issue; the Connect client will send a command to clear the cache when it is being deleted.~~
    2. Stop gc while we send the command (a sketch follows this list). Not the best solution, but it might just work. It plays with Python's internal mechanisms, but the code is clean and we get the desired behavior.
    ~~3. Do something like `execute_command_later` and queue the commands to run in the next synchronous call to `execute_command`. A bit of refactoring of the existing framework, and commands would be delayed, but code-wise it might be the most accurate.~~
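    For reference, option 2 boils down to the pattern below. This is a context-manager rendition of the `disable_gc` decorator added by this patch (the name `no_gc` is illustrative, not an existing API). Note that `gc.disable()` only pauses the cyclic collector; a `__del__` driven purely by reference counting can still run:
    
    ```
    import gc
    from contextlib import contextmanager
    
    @contextmanager
    def no_gc():
        # Defer cyclic collections (and hence gc-triggered __del__) while the
        # gRPC channel lock may be held; restore the previous state afterwards.
        enabled = gc.isenabled()
        if enabled:
            gc.disable()
        try:
            yield
        finally:
            if enabled:
                gc.enable()
    ```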
    
    Notice that this problem affects not only this specific test, or even our general test suite: it is a real deadlock that could hit our users. We should backport the final fix.
    
    For now, we need to make a quick decision about how to mitigate the flaky test, since it is interrupting our workflow. One flaky test makes the whole workflow fragile.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Not really; users should not notice that the cache is not being cleared eagerly.
    
    ### How was this patch tested?
    
    The failure is not easily reproducible, so we need to rely on CI.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #53783 from gaogaotiantian/fix-ml-util-racing.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/ml/tests/test_clustering.py |  2 --
 python/pyspark/sql/connect/client/core.py  |  4 +++-
 python/pyspark/util.py                     | 22 +++++++++++++++++++++-
 3 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index 8bd021903fba..e22e97a5e7f1 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -16,7 +16,6 @@
 #
 
 import tempfile
-import unittest
 
 import numpy as np
 
@@ -384,7 +383,6 @@ class ClusteringTestsMixin:
             model2 = LocalLDAModel.load(d)
             self.assertEqual(str(model), str(model2))
 
-    @unittest.skip("SPARK-55020: Test triggers frequent deadlock in CI")
     def test_distributed_lda(self):
         spark = self.spark
         df = (
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 58ae6a0eea98..69447503da6c 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -65,7 +65,7 @@ import grpc
 from google.protobuf import text_format, any_pb2
 from google.rpc import error_details_pb2
 
-from pyspark.util import is_remote_only
+from pyspark.util import is_remote_only, disable_gc
 from pyspark.accumulators import SpecialAccumulatorIds
 from pyspark.version import __version__
 from pyspark.traceback_utils import CallSite
@@ -1457,6 +1457,7 @@ class SparkConnectClient(object):
         except Exception as error:
             self._handle_error(error)
 
+    @disable_gc
     def _execute(self, req: pb2.ExecutePlanRequest) -> None:
         """
         Execute the passed request `req` and drop all results.
@@ -1494,6 +1495,7 @@ class SparkConnectClient(object):
         except Exception as error:
             self._handle_error(error)
 
+    @disable_gc
     def _execute_and_fetch_as_iterator(
         self,
         req: pb2.ExecutePlanRequest,
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 9f5fad57cb14..e181339e4896 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -19,6 +19,7 @@
 import copy
 import functools
 import faulthandler
+import gc
 import itertools
 import os
 import platform
@@ -31,7 +32,7 @@ import socket
 import warnings
 from contextlib import contextmanager
 from types import TracebackType
-from typing import Any, Callable, IO, Iterator, List, Optional, TextIO, Tuple, Union
+from typing import Any, Callable, IO, Iterator, List, Optional, TextIO, Tuple, TypeVar, Union, cast
 
 from pyspark.errors import PySparkRuntimeError
 from pyspark.serializers import (
@@ -96,6 +97,8 @@ JVM_INT_MAX: int = (1 << 31) - 1
 JVM_LONG_MIN: int = -(1 << 63)
 JVM_LONG_MAX: int = (1 << 63) - 1
 
+FuncT = TypeVar("FuncT", bound=Callable[..., Any])
+
 
 def print_exec(stream: TextIO) -> None:
     ei = sys.exc_info()
@@ -857,6 +860,23 @@ def _do_server_auth(conn: "io.IOBase", auth_secret: str) -> None:
         )
 
 
+def disable_gc(f: FuncT) -> FuncT:
+    """Mark the function that should disable gc during execution"""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+        gc_enabled_originally = gc.isenabled()
+        if gc_enabled_originally:
+            gc.disable()
+        try:
+            return f(*args, **kwargs)
+        finally:
+            if gc_enabled_originally:
+                gc.enable()
+
+    return cast(FuncT, wrapped)
+
+
 _is_remote_only = None
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
