This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 854628f84565 [SPARK-55326][PYTHON][CONNECT] Release remote session
when SPARK_CONNECT_RELEASE_SESSION_ON_EXIT is set
854628f84565 is described below
commit 854628f84565ec0b7016ee4db1eb077765f58972
Author: Bobby Wang <[email protected]>
AuthorDate: Tue Mar 3 09:27:30 2026 +0900
[SPARK-55326][PYTHON][CONNECT] Release remote session when
SPARK_CONNECT_RELEASE_SESSION_ON_EXIT is set
### What changes were proposed in this pull request?
This PR adds an _on_exit handler to SparkConnectClient that is registered
with Python's atexit module. When enabled via the
SPARK_CONNECT_RELEASE_SESSION_ON_EXIT environment variable, the client will
automatically release its remote session and close itself at process exit.
### Why are the changes needed?
Currently, when a PySpark Connect client process exits without explicitly
calling `spark.stop()`, the session may remain active on the server side,
consuming resources unnecessarily. This change provides an opt-in mechanism to
automatically release the session during process exit.
### Does this PR introduce _any_ user-facing change?
Yes. Users can now set the environment variable
`SPARK_CONNECT_RELEASE_SESSION_ON_EXIT=true` to enable automatic session
release when the Python process exits.
### How was this patch tested?
Pass the CIs.
### Was this patch authored or co-authored using generative AI tooling?
Yes, co-authored with claude-4.5-opus-high
Closes #54106 from wbo4958/release-on-exit.
Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/client/core.py | 19 ++-
.../sql/tests/connect/client/test_client.py | 138 +++++++++++++++++++++
2 files changed, 155 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/connect/client/core.py
b/python/pyspark/sql/connect/client/core.py
index aa060df24e41..ab7979a28326 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -755,8 +755,11 @@ class SparkConnectClient(object):
self._release_futures: weakref.WeakSet[concurrent.futures.Future] =
weakref.WeakSet()
- # cleanup ml cache if possible
- atexit.register(self._cleanup_ml_cache)
+ self._release_session_on_exit = os.getenv(
+ "SPARK_CONNECT_RELEASE_SESSION_ON_EXIT", "false"
+ ).lower() in ("true", "1")
+ # cleanup if possible
+ atexit.register(self._on_exit)
self.global_user_context_extensions: List[Tuple[str, any_pb2.Any]] = []
self.global_user_context_extensions_lock = threading.Lock()
@@ -2281,6 +2284,18 @@ class SparkConnectClient(object):
except Exception:
return []
+ def _on_exit(self) -> None:
+ self._cleanup_ml_cache()
+ if self._release_session_on_exit and not self._closed:
+ try:
+ self.release_session()
+ except Exception:
+ pass
+ try:
+ self.close()
+ except Exception:
+ pass
+
def _cleanup_ml_cache(self) -> None:
try:
command = pb2.Command()
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py
b/python/pyspark/sql/tests/connect/client/test_client.py
index 55faff5e9ed3..85fbafe22728 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -450,6 +450,144 @@ class SparkConnectClientTestCase(unittest.TestCase):
for resp in client._stub.ExecutePlan(req, metadata=None):
assert resp.operation_id == "10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
+ def test_on_exit_calls_release_and_close_when_enabled(self):
+ client = SparkConnectClient("sc://foo/",
use_reattachable_execute=False)
+ client._release_session_on_exit = True
+ client._closed = False
+
+ call_tracker = {"release_session": 0, "close": 0}
+
+ def mock_release_session():
+ call_tracker["release_session"] += 1
+
+ def mock_close():
+ call_tracker["close"] += 1
+
+ client.release_session = mock_release_session
+ client.close = mock_close
+
+ client._on_exit()
+
+ self.assertEqual(call_tracker["release_session"], 1)
+ self.assertEqual(call_tracker["close"], 1)
+
+ def test_on_exit_does_not_call_when_release_disabled(self):
+ """Test _on_exit does nothing when _release_session_on_exit is
False."""
+ client = SparkConnectClient("sc://foo/",
use_reattachable_execute=False)
+ client._release_session_on_exit = False
+ client._closed = False
+
+ call_tracker = {"release_session": 0, "close": 0}
+
+ def mock_release_session():
+ call_tracker["release_session"] += 1
+
+ def mock_close():
+ call_tracker["close"] += 1
+
+ client.release_session = mock_release_session
+ client.close = mock_close
+
+ client._on_exit()
+
+ self.assertEqual(call_tracker["release_session"], 0)
+ self.assertEqual(call_tracker["close"], 0)
+
+ def test_on_exit_does_not_call_when_already_closed(self):
+ """Test _on_exit does nothing when client is already closed."""
+ client = SparkConnectClient("sc://foo/",
use_reattachable_execute=False)
+ client._release_session_on_exit = True
+ client._closed = True
+
+ call_tracker = {"release_session": 0, "close": 0}
+
+ def mock_release_session():
+ call_tracker["release_session"] += 1
+
+ def mock_close():
+ call_tracker["close"] += 1
+
+ client.release_session = mock_release_session
+ client.close = mock_close
+
+ client._on_exit()
+
+ self.assertEqual(call_tracker["release_session"], 0)
+ self.assertEqual(call_tracker["close"], 0)
+
+ def test_on_exit_catches_release_session_exception(self):
+ """Test _on_exit continues to call close even if release_session
raises."""
+ client = SparkConnectClient("sc://foo/",
use_reattachable_execute=False)
+ client._release_session_on_exit = True
+ client._closed = False
+
+ call_tracker = {"release_session": 0, "close": 0}
+
+ def mock_release_session():
+ call_tracker["release_session"] += 1
+ raise Exception("release error")
+
+ def mock_close():
+ call_tracker["close"] += 1
+
+ client.release_session = mock_release_session
+ client.close = mock_close
+
+ # Should not raise
+ client._on_exit()
+
+ self.assertEqual(call_tracker["release_session"], 1)
+ self.assertEqual(call_tracker["close"], 1)
+
+ def test_on_exit_catches_close_exception(self):
+ """Test _on_exit silently catches exception from close."""
+ client = SparkConnectClient("sc://foo/",
use_reattachable_execute=False)
+ client._release_session_on_exit = True
+ client._closed = False
+
+ call_tracker = {"release_session": 0, "close": 0}
+
+ def mock_release_session():
+ call_tracker["release_session"] += 1
+
+ def mock_close():
+ call_tracker["close"] += 1
+ raise Exception("close error")
+
+ client.release_session = mock_release_session
+ client.close = mock_close
+
+ # Should not raise
+ client._on_exit()
+
+ self.assertEqual(call_tracker["release_session"], 1)
+ self.assertEqual(call_tracker["close"], 1)
+
+ def test_on_exit_catches_both_exceptions(self):
+ """Test _on_exit handles both release_session and close raising
exceptions."""
+ client = SparkConnectClient("sc://foo/",
use_reattachable_execute=False)
+ client._release_session_on_exit = True
+ client._closed = False
+
+ call_tracker = {"release_session": 0, "close": 0}
+
+ def mock_release_session():
+ call_tracker["release_session"] += 1
+ raise Exception("release error")
+
+ def mock_close():
+ call_tracker["close"] += 1
+ raise Exception("close error")
+
+ client.release_session = mock_release_session
+ client.close = mock_close
+
+ # Should not raise
+ client._on_exit()
+
+ self.assertEqual(call_tracker["release_session"], 1)
+ self.assertEqual(call_tracker["close"], 1)
+
def test_get_operations_statuses_all(self):
"""Test get_operations_statuses returns all operation statuses when no
IDs specified."""
OperationStatus = proto.GetStatusResponse.OperationStatus
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]