This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d9157101e26 [SPARK-46061][PYTHON][TESTS] Add the test parity for 
reattach test case
d9157101e26 is described below

commit d9157101e260e54937785f50b7b1271e1da018e2
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Thu Nov 23 12:40:05 2023 +0900

    [SPARK-46061][PYTHON][TESTS] Add the test parity for reattach test case
    
    ### What changes were proposed in this pull request?
    
    This PR adds the same test, `ReleaseSession releases all queries and does 
not allow more requests in the session` (added in SPARK-45798), to the PySpark side.
    
    ### Why are the changes needed?
    
    To identify an issue such as SPARK-46042.
    
    _To clarify, the Python side does not have the behaviour of sending a request 
early when you create a generator of the response_.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, test-only. It includes a few bug fixes, but they are pretty trivial 
and minor.
    
    ### How was this patch tested?
    
    Unittests added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43965 from HyukjinKwon/SPARK-46042-followup.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   1 +
 python/pyspark/sql/connect/client/core.py          |  15 ++-
 .../sql/tests/connect/client/test_reattach.py      | 124 +++++++++++++++++++++
 3 files changed, 137 insertions(+), 3 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 89b2ff7976d..feb49062316 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -895,6 +895,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_utils",
         "pyspark.sql.tests.connect.client.test_artifact",
         "pyspark.sql.tests.connect.client.test_client",
+        "pyspark.sql.tests.connect.client.test_reattach",
         "pyspark.sql.tests.connect.streaming.test_parity_streaming",
         "pyspark.sql.tests.connect.streaming.test_parity_listener",
         "pyspark.sql.tests.connect.streaming.test_parity_foreach",
diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index a2590dec960..58b48bd69ba 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1525,9 +1525,15 @@ class SparkConnectClient(object):
         from pyspark.sql.connect.conf import RuntimeConf
 
         conf = RuntimeConf(self)
-        if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true":
+        try:
+            if conf.get("spark.sql.connect.serverStacktrace.enabled") == 
"true":
+                return True
+            return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == 
"true"
+        except Exception as e:  # noqa: F841
+            # Falls back to true if an exception occurs during reading the 
config.
+            # Otherwise, it will recursively try to get the conf when it 
consistently
+            # fails, ending up with `RecursionError`.
             return True
-        return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true"
 
     def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
         """
@@ -1614,7 +1620,10 @@ class SparkConnectClient(object):
                 f"{response.session_id} != {self._session_id}"
             )
         if self._server_session_id is not None:
-            if response.server_side_session_id != self._server_session_id:
+            if (
+                response.server_side_session_id
+                and response.server_side_session_id != self._server_session_id
+            ):
                 raise PySparkAssertionError(
                     "Received incorrect server side session identifier for 
request. "
                     "Please create a new Spark Session to reconnect. ("
diff --git a/python/pyspark/sql/tests/connect/client/test_reattach.py 
b/python/pyspark/sql/tests/connect/client/test_reattach.py
new file mode 100644
index 00000000000..5f2cb3c4937
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/client/test_reattach.py
@@ -0,0 +1,124 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+
+from pyspark.sql import SparkSession as PySparkSession
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.testing.utils import eventually
+
+
+class SparkConnectReattachTestCase(ReusedConnectTestCase, SQLTestUtils, 
PandasOnSparkTestUtils):
+    @classmethod
+    def setUpClass(cls):
+        super(SparkConnectReattachTestCase, cls).setUpClass()
+        # Disable the shared namespace so pyspark.sql.functions, etc point the 
regular
+        # PySpark libraries.
+        os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
+
+        cls.connect = cls.spark  # Switch Spark Connect session and regular 
PySpark session.
+        cls.spark = PySparkSession._instantiatedSession
+        assert cls.spark is not None
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            # Stopping Spark Connect closes the session in JVM at the server.
+            cls.spark = cls.connect
+            del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
+        finally:
+            super(SparkConnectReattachTestCase, cls).tearDownClass()
+
+    def test_release_sessions(self):
+        big_enough_query = "select * from range(1000000)"
+        query1 = self.connect.sql(big_enough_query).toLocalIterator()
+        query2 = self.connect.sql(big_enough_query).toLocalIterator()
+        query3 = self.connect.sql("select 1").toLocalIterator()
+
+        next(query1)
+        next(query2)
+
+        jvm = PySparkSession._instantiatedSession._jvm  # type: 
ignore[union-attr]
+        service = getattr(
+            getattr(
+                jvm.org.apache.spark.sql.connect.service,  # type: 
ignore[union-attr]
+                "SparkConnectService$",
+            ),
+            "MODULE$",
+        )
+
+        @eventually(catch_assertions=True)
+        def wait_for_requests():
+            
self.assertEqual(service.executionManager().listExecuteHolders().length(), 2)
+
+        wait_for_requests()
+
+        # Close session
+        self.connect.client.release_session()
+        # Calling release session again should be a no-op.
+        self.connect.client.release_session()
+
+        @eventually(catch_assertions=True)
+        def wait_for_responses():
+            
self.assertEqual(service.executionManager().listExecuteHolders().length(), 0)
+
+        wait_for_responses()
+
+        # query1 and query2 could get either an:
+        # OPERATION_CANCELED if it happens fast - when closing the session 
interrupted the queries,
+        # and that error got pushed to the client buffers before the client 
got disconnected.
+        # OPERATION_ABANDONED if it happens slow - when closing the session 
interrupted the client
+        # RPCs before it pushed out the error above. The client would then get 
an
+        # INVALID_CURSOR.DISCONNECTED, which it will retry with a 
ReattachExecute, and then get an
+        # INVALID_HANDLE.OPERATION_ABANDONED.
+
+        def check_error(q):
+            try:
+                list(q)  # Iterate all.
+            except Exception as e:  # noqa: F841
+                return e
+
+        e = check_error(query1)
+        self.assertIsNotNone(e, "An exception has to be thrown")
+        self.assertTrue(
+            "OPERATION_CANCELED" in str(e) or 
"INVALID_HANDLE.OPERATION_ABANDONED" in str(e)
+        )
+        e = check_error(query2)
+        self.assertIsNotNone(e, "An exception has to be thrown")
+        self.assertTrue(
+            "OPERATION_CANCELED" in str(e) or 
"INVALID_HANDLE.OPERATION_ABANDONED" in str(e)
+        )
+
+        # query3 has not been submitted before, so it should now fail with 
SESSION_CLOSED
+        e = check_error(query3)
+        self.assertIsNotNone(3, "An exception has to be thrown")
+        self.assertIn("INVALID_HANDLE.SESSION_CLOSED", str(e))
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.client.test_reattach import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to