This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4192135c1234 [SPARK-55020][PYTHON][FOLLOW-UP] Disable gc only when we communicate through gRPC for ExecutePlan
4192135c1234 is described below

commit 4192135c1234e7f27f5b1d8ffca0a0fd729a5624
Author: Tian Gao <[email protected]>
AuthorDate: Wed Feb 11 10:06:11 2026 +0900

    [SPARK-55020][PYTHON][FOLLOW-UP] Disable gc only when we communicate through gRPC for ExecutePlan
    
    ### What changes were proposed in this pull request?
    
    Instead of disabling gc for the whole function (which was done incorrectly for generators), we now disable it precisely around the gRPC communication with `ExecutePlan`.
    
    ### Why are the changes needed?
    
    The previous implementation [SPARK-55020](https://issues.apache.org/jira/browse/SPARK-55020) (https://github.com/apache/spark/pull/53783) was incorrect: the generator was only protected while it was being built, but it also triggers communication when it is drained.
    
    The `disable_gc` context manager provides a more precise way to disable gc around a specific operation. It also ensures that the generator does not keep gc disabled while it is not being drained.
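    
    For illustration, a minimal standalone sketch of the failure mode: a gc-disabling decorator (like the one this patch removes) only covers the call that constructs the generator object; the body that actually fetches responses runs later, when the generator is drained, with gc enabled again.
    
    ```python
    import functools
    import gc
    
    def disable_gc_decorator(f):
        """Previous approach: disable gc only for the duration of the call."""
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            enabled = gc.isenabled()
            if enabled:
                gc.disable()
            try:
                return f(*args, **kwargs)
            finally:
                if enabled:
                    gc.enable()
        return wrapped
    
    @disable_gc_decorator
    def responses():
        # A generator body runs when the generator is drained, not when the
        # function is called, so gc has already been re-enabled by now.
        yield gc.isenabled()
    
    gen = responses()   # gc is disabled only while `gen` is constructed
    print(next(gen))    # True: the response is produced with gc enabled
    ```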
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    CI.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #54248 from gaogaotiantian/fix-disable-gc.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/client/core.py          | 18 ++++++----
 python/pyspark/sql/connect/client/reattach.py      | 13 +++++---
 .../sql/tests/connect/client/test_client.py        |  2 +-
 python/pyspark/tests/test_util.py                  |  9 ++++-
 python/pyspark/util.py                             | 38 +++++++++++++---------
 5 files changed, 52 insertions(+), 28 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 8db7126e9a4f..58cbd22a36b7 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1462,7 +1462,6 @@ class SparkConnectClient(object):
         except Exception as error:
             self._handle_error(error)
 
-    @disable_gc
     def _execute(self, req: pb2.ExecutePlanRequest) -> None:
         """
         Execute the passed request `req` and drop all results.
@@ -1496,12 +1495,12 @@ class SparkConnectClient(object):
             else:
                 for attempt in self._retrying():
                     with attempt:
-                        for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
-                            handle_response(b)
+                        with disable_gc():
+                            for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
+                                handle_response(b)
         except Exception as error:
             self._handle_error(error)
 
-    @disable_gc
     def _execute_and_fetch_as_iterator(
         self,
         req: pb2.ExecutePlanRequest,
@@ -1697,8 +1696,15 @@ class SparkConnectClient(object):
             else:
                 for attempt in self._retrying():
                     with attempt:
-                        for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
-                            yield from handle_response(b)
+                        with disable_gc():
+                            gen = self._stub.ExecutePlan(req, metadata=self._builder.metadata())
+                        while True:
+                            try:
+                                with disable_gc():
+                                    b = next(gen)
+                                yield from handle_response(b)
+                            except StopIteration:
+                                break
         except KeyboardInterrupt as kb:
             logger.debug(f"Interrupt request received for 
operation={req.operation_id}")
             if progress is not None:
diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py
index 70ebe0f667ef..dd16b80b82b6 100644
--- a/python/pyspark/sql/connect/client/reattach.py
+++ b/python/pyspark/sql/connect/client/reattach.py
@@ -34,6 +34,7 @@ from pyspark.sql.connect.logging import logger
 import pyspark.sql.connect.proto as pb2
 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
 from pyspark.errors import PySparkRuntimeError
+from pyspark.util import disable_gc
 
 
 class ExecutePlanResponseReattachableIterator(Generator):
@@ -108,9 +109,10 @@ class ExecutePlanResponseReattachableIterator(Generator):
         # Note: This is not retried, because no error would ever be thrown here, and GRPC will only
         # throw error on first self._has_next().
         self._metadata = metadata
-        self._iterator: Optional[Iterator[pb2.ExecutePlanResponse]] = iter(
-            self._stub.ExecutePlan(self._initial_request, metadata=metadata)
-        )
+        with disable_gc():
+            self._iterator: Optional[Iterator[pb2.ExecutePlanResponse]] = iter(
+                self._stub.ExecutePlan(self._initial_request, metadata=metadata)
+            )
 
         # Current item from this iterator.
         self._current: Optional[pb2.ExecutePlanResponse] = None
@@ -142,8 +144,9 @@ class ExecutePlanResponseReattachableIterator(Generator):
 
     def send(self, value: Any) -> pb2.ExecutePlanResponse:
         # will trigger reattach in case the stream completed without result_complete
-        if not self._has_next():
-            raise StopIteration()
+        with disable_gc():
+            if not self._has_next():
+                raise StopIteration()
 
         ret = self._current
         assert ret is not None
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 8daf82ad3f37..6c58a8a98d87 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -158,7 +158,7 @@ if should_test_connect:
             buf = sink.getvalue()
             resp.arrow_batch.data = buf.to_pybytes()
             resp.arrow_batch.row_count = 2
-            return [resp]
+            return iter([resp])
 
         def Interrupt(self, req: proto.InterruptRequest, metadata):
             self.req = req
diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py
index ec703e52b9c2..9fd0135eff4e 100644
--- a/python/pyspark/tests/test_util.py
+++ b/python/pyspark/tests/test_util.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import gc
 import os
 import time
 import unittest
@@ -22,7 +23,7 @@ from unittest.mock import patch
 from py4j.protocol import Py4JJavaError
 
 from pyspark import keyword_only
-from pyspark.util import _parse_memory
+from pyspark.util import _parse_memory, disable_gc
 from pyspark.loose_version import LooseVersion
 from pyspark.testing.utils import PySparkTestCase, eventually, timeout
 from pyspark.find_spark_home import _find_spark_home
@@ -148,6 +149,12 @@ class UtilTests(PySparkTestCase):
         with self.assertRaisesRegex(ValueError, "invalid format"):
             _parse_memory("2gs")
 
+    def test_disable_gc(self):
+        self.assertTrue(gc.isenabled())
+        with disable_gc():
+            self.assertFalse(gc.isenabled())
+        self.assertTrue(gc.isenabled())
+
     @eventually(timeout=180, catch_timeout=True)
     @timeout(timeout=1)
     def test_retry_timeout_test(self):
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 1ba72b198581..fb672e3d1222 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 #
 
+import contextlib
 import copy
 import functools
 import faulthandler
@@ -32,7 +33,19 @@ import socket
 import warnings
 from contextlib import contextmanager
 from types import TracebackType
-from typing import Any, Callable, IO, Iterator, List, Optional, TextIO, Tuple, TypeVar, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Generator,
+    IO,
+    Iterator,
+    List,
+    Optional,
+    TextIO,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 from pyspark.errors import PySparkRuntimeError
 from pyspark.serializers import (
@@ -860,21 +873,16 @@ def _do_server_auth(conn: "io.IOBase", auth_secret: str) -> None:
         )
 
 
-def disable_gc(f: FuncT) -> FuncT:
-    """Mark the function that should disable gc during execution"""
-
-    @functools.wraps(f)
-    def wrapped(*args: Any, **kwargs: Any) -> Any:
-        gc_enabled_originally = gc.isenabled()
[email protected]
+def disable_gc() -> Generator[None, None, None]:
+    gc_enabled_originally = gc.isenabled()
+    if gc_enabled_originally:
+        gc.disable()
+    try:
+        yield
+    finally:
         if gc_enabled_originally:
-            gc.disable()
-        try:
-            return f(*args, **kwargs)
-        finally:
-            if gc_enabled_originally:
-                gc.enable()
-
-    return cast(FuncT, wrapped)
+            gc.enable()
 
 
 _is_remote_only = None


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
