This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d9157101e26 [SPARK-46061][PYTHON][TESTS] Add the test parity for
reattach test case
d9157101e26 is described below
commit d9157101e260e54937785f50b7b1271e1da018e2
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Thu Nov 23 12:40:05 2023 +0900
[SPARK-46061][PYTHON][TESTS] Add the test parity for reattach test case
### What changes were proposed in this pull request?
This PR adds the same test `ReleaseSession releases all queries and does
not allow more requests in the session` added in SPARK-45798, to PySpark side.
### Why are the changes needed?
To identify an issue such as SPARK-46042.
_To clarify, Python side does not have the behaviour of sending a request
early when you create a generator of the response_.
### Does this PR introduce _any_ user-facing change?
No, test-only. It includes a bit of bug fixes but they are pretty trivial
and minor.
### How was this patch tested?
Unittests added.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43965 from HyukjinKwon/SPARK-46042-followup.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
dev/sparktestsupport/modules.py | 1 +
python/pyspark/sql/connect/client/core.py | 15 ++-
.../sql/tests/connect/client/test_reattach.py | 124 +++++++++++++++++++++
3 files changed, 137 insertions(+), 3 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 89b2ff7976d..feb49062316 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -895,6 +895,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.test_utils",
"pyspark.sql.tests.connect.client.test_artifact",
"pyspark.sql.tests.connect.client.test_client",
+ "pyspark.sql.tests.connect.client.test_reattach",
"pyspark.sql.tests.connect.streaming.test_parity_streaming",
"pyspark.sql.tests.connect.streaming.test_parity_listener",
"pyspark.sql.tests.connect.streaming.test_parity_foreach",
diff --git a/python/pyspark/sql/connect/client/core.py
b/python/pyspark/sql/connect/client/core.py
index a2590dec960..58b48bd69ba 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1525,9 +1525,15 @@ class SparkConnectClient(object):
from pyspark.sql.connect.conf import RuntimeConf
conf = RuntimeConf(self)
- if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true":
+ try:
+ if conf.get("spark.sql.connect.serverStacktrace.enabled") ==
"true":
+ return True
+ return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") ==
"true"
+ except Exception as e: # noqa: F841
+ # Falls back to true if an exception occurs during reading the
config.
+ # Otherwise, it will recursively try to get the conf when it
consistently
+ # fails, ending up with `RecursionError`.
return True
- return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true"
def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
"""
@@ -1614,7 +1620,10 @@ class SparkConnectClient(object):
f"{response.session_id} != {self._session_id}"
)
if self._server_session_id is not None:
- if response.server_side_session_id != self._server_session_id:
+ if (
+ response.server_side_session_id
+ and response.server_side_session_id != self._server_session_id
+ ):
raise PySparkAssertionError(
"Received incorrect server side session identifier for
request. "
"Please create a new Spark Session to reconnect. ("
diff --git a/python/pyspark/sql/tests/connect/client/test_reattach.py
b/python/pyspark/sql/tests/connect/client/test_reattach.py
new file mode 100644
index 00000000000..5f2cb3c4937
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/client/test_reattach.py
@@ -0,0 +1,124 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+
+from pyspark.sql import SparkSession as PySparkSession
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.testing.utils import eventually
+
+
+class SparkConnectReattachTestCase(ReusedConnectTestCase, SQLTestUtils,
PandasOnSparkTestUtils):
+ @classmethod
+ def setUpClass(cls):
+ super(SparkConnectReattachTestCase, cls).setUpClass()
+        # Disable the shared namespace so pyspark.sql.functions, etc. point to
+        # the regular PySpark libraries.
+ os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
+
+ cls.connect = cls.spark # Switch Spark Connect session and regular
PySpark session.
+ cls.spark = PySparkSession._instantiatedSession
+ assert cls.spark is not None
+
+ @classmethod
+ def tearDownClass(cls):
+ try:
+ # Stopping Spark Connect closes the session in JVM at the server.
+ cls.spark = cls.connect
+ del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
+ finally:
+ super(SparkConnectReattachTestCase, cls).tearDownClass()
+
+ def test_release_sessions(self):
+ big_enough_query = "select * from range(1000000)"
+ query1 = self.connect.sql(big_enough_query).toLocalIterator()
+ query2 = self.connect.sql(big_enough_query).toLocalIterator()
+ query3 = self.connect.sql("select 1").toLocalIterator()
+
+ next(query1)
+ next(query2)
+
+ jvm = PySparkSession._instantiatedSession._jvm # type:
ignore[union-attr]
+ service = getattr(
+ getattr(
+ jvm.org.apache.spark.sql.connect.service, # type:
ignore[union-attr]
+ "SparkConnectService$",
+ ),
+ "MODULE$",
+ )
+
+ @eventually(catch_assertions=True)
+ def wait_for_requests():
+
self.assertEqual(service.executionManager().listExecuteHolders().length(), 2)
+
+ wait_for_requests()
+
+ # Close session
+ self.connect.client.release_session()
+ # Calling release session again should be a no-op.
+ self.connect.client.release_session()
+
+ @eventually(catch_assertions=True)
+ def wait_for_responses():
+
self.assertEqual(service.executionManager().listExecuteHolders().length(), 0)
+
+ wait_for_responses()
+
+ # query1 and query2 could get either an:
+ # OPERATION_CANCELED if it happens fast - when closing the session
interrupted the queries,
+ # and that error got pushed to the client buffers before the client
got disconnected.
+ # OPERATION_ABANDONED if it happens slow - when closing the session
interrupted the client
+ # RPCs before it pushed out the error above. The client would then get
an
+ # INVALID_CURSOR.DISCONNECTED, which it will retry with a
ReattachExecute, and then get an
+ # INVALID_HANDLE.OPERATION_ABANDONED.
+
+ def check_error(q):
+ try:
+ list(q) # Iterate all.
+ except Exception as e: # noqa: F841
+ return e
+
+ e = check_error(query1)
+ self.assertIsNotNone(e, "An exception has to be thrown")
+ self.assertTrue(
+ "OPERATION_CANCELED" in str(e) or
"INVALID_HANDLE.OPERATION_ABANDONED" in str(e)
+ )
+ e = check_error(query2)
+ self.assertIsNotNone(e, "An exception has to be thrown")
+ self.assertTrue(
+ "OPERATION_CANCELED" in str(e) or
"INVALID_HANDLE.OPERATION_ABANDONED" in str(e)
+ )
+
+ # query3 has not been submitted before, so it should now fail with
SESSION_CLOSED
+ e = check_error(query3)
+        self.assertIsNotNone(e, "An exception has to be thrown")
+ self.assertIn("INVALID_HANDLE.SESSION_CLOSED", str(e))
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.connect.client.test_reattach import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]