This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bc93c255bd8b [SPARK-53568][CONNECT][PYTHON] Fix several small bugs in Spark Connect Python client error handling logic
bc93c255bd8b is described below

commit bc93c255bd8bc73ffc29689445321f06b053eb33
Author: Alex Khakhlyuk <alex.khakhl...@gmail.com>
AuthorDate: Mon Sep 15 07:49:11 2025 +0900

    [SPARK-53568][CONNECT][PYTHON] Fix several small bugs in Spark Connect Python client error handling logic
    
    ### What changes were proposed in this pull request?
    
    1. [This PR](https://github.com/apache/spark/commit/5102370dcf37ebf64d19b536656576d6b068e59a#diff-67c4c88c462539a60764612d6ac0048523f0cdc44a8997a2228803565054c6a9) introduced a bug where `SparkConnectGrpcException.grpc_status_code: grpc.StatusCode` is set from `status.code`, which is an int. This is fixed by setting it from `rpc_error.code()`, which has the correct type (see the first sketch after this list).
    2. Errors without a status did not carry a gRPC status code, even though the code is always available on gRPC exceptions. This is now fixed.
    3. There is an old bug where `SparkConnectGrpcException._errorClass` is set to "" instead of None: the check `info.metadata["errorClass"] == "INVALID_HANDLE.SESSION_CHANGED"` populates the map with "", probably because `info.metadata` behaves like a `defaultdict(str)`. This is fixed by using `info.metadata.get("errorClass")` instead (see the second sketch below).
    4. Added type hints for `rpc_call`, `status_code` and `status` to avoid type mismatch bugs in the future and make the code more readable.
    5. Added proper tests for grpc status code, error class and sql states.
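
    Two minimal sketches of the behaviors behind items 1 and 3 (illustrative only, not part of the patch; they assume `grpcio` and `googleapis-common-protos` are installed). First, the int vs. `grpc.StatusCode` mismatch:

    ```python
    import grpc

    # grpc.StatusCode is an enum whose .value is an (int, name) pair, while
    # google.rpc.status_pb2.Status.code carries only the bare int, so the two
    # are not interchangeable.
    code = grpc.StatusCode.INTERNAL
    print(code.value)     # (13, 'internal')
    print(code.value[0])  # 13 -- the int a Status proto's `code` field holds
    ```

    Second, the defaultdict-like map access behind item 3:

    ```python
    from google.rpc.error_details_pb2 import ErrorInfo

    # Bracket access on an unset key of a protobuf string map returns the
    # default value "" (and, per the observation above, can populate the key
    # like defaultdict(str)), so the old equality check never saw a missing key.
    info = ErrorInfo()
    print(repr(info.metadata["errorClass"]))       # ''
    # .get() reports absence instead of returning a default value:
    print(ErrorInfo().metadata.get("errorClass"))  # None
    ```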
    
    ### Why are the changes needed?
    
    Bug fixes
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Closes #52325 from khakhlyuk/python-connect-error-handling-fixes.
    
    Authored-by: Alex Khakhlyuk <alex.khakhl...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/client/core.py          | 17 +++--
 .../sql/tests/connect/client/test_client.py        | 86 ++++++++++++++++++++++
 2 files changed, 98 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index cf730d8c796e..741d612f53f4 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -107,6 +107,7 @@ from pyspark.sql.connect.shell.progress import Progress, ProgressHandler, from_p
 
 if TYPE_CHECKING:
     from google.rpc.error_details_pb2 import ErrorInfo
+    from google.rpc.status_pb2 import Status
     from pyspark.sql.connect._typing import DataTypeOrString
     from pyspark.sql.connect.session import SparkSession
     from pyspark.sql.datasource import DataSource
@@ -1949,7 +1950,9 @@ class SparkConnectClient(object):
         logger.exception("GRPC Error received")
         # We have to cast the value here because, a RpcError is a Call as well.
         # https://grpc.github.io/grpc/python/grpc.html#grpc.UnaryUnaryMultiCallable.__call__
-        status = rpc_status.from_call(cast(grpc.Call, rpc_error))
+        error: grpc.Call = cast(grpc.Call, rpc_error)
+        status_code: grpc.StatusCode = error.code()
+        status: Optional[Status] = rpc_status.from_call(error)
         if status:
             for d in status.details:
                 if d.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
@@ -1957,7 +1960,7 @@ class SparkConnectClient(object):
                     d.Unpack(info)
                     logger.debug(f"Received ErrorInfo: {info}")
 
-                    if info.metadata["errorClass"] == "INVALID_HANDLE.SESSION_CHANGED":
+                    if info.metadata.get("errorClass") == "INVALID_HANDLE.SESSION_CHANGED":
                         self._closed = True
 
                     raise convert_exception(
@@ -1965,14 +1968,18 @@ class SparkConnectClient(object):
                         status.message,
                         self._fetch_enriched_error(info),
                         self._display_server_stack_trace(),
-                        status.code,
+                        status_code,
                     ) from None
 
             raise SparkConnectGrpcException(
-                message=status.message, grpc_status_code=status.code
+                message=status.message,
+                grpc_status_code=status_code,
             ) from None
         else:
-            raise SparkConnectGrpcException(str(rpc_error)) from None
+            raise SparkConnectGrpcException(
+                message=str(error),
+                grpc_status_code=status_code,
+            ) from None
 
     def add_artifacts(self, *paths: str, pyfile: bool, archive: bool, file: bool) -> None:
         try:
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index c3954827bae5..21995f235839 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -20,12 +20,15 @@ import uuid
 from collections.abc import Generator
 from typing import Optional, Any, Union
 
+from pyspark.errors.exceptions.connect import SparkConnectGrpcException
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import eventually
 
 if should_test_connect:
     import grpc
+    import google.protobuf.any_pb2 as any_pb2
     from google.rpc import status_pb2
+    from google.rpc.error_details_pb2 import ErrorInfo
     import pandas as pd
     import pyarrow as pa
     from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder
@@ -450,6 +453,89 @@ class SparkConnectClientReattachTestCase(unittest.TestCase):
         reattach = ite._create_reattach_execute_request()
         self.assertEqual(reattach.client_observed_server_side_session_id, session_id)
 
+    def test_server_unreachable(self):
+        # DNS resolution should fail for "foo". This error is a retriable UNAVAILABLE error.
+        client = SparkConnectClient(
+            "sc://foo", use_reattachable_execute=False, 
retry_policy=dict(max_retries=0)
+        )
+        with self.assertRaises(SparkConnectGrpcException) as cm:
+            command = proto.Command()
+            client.execute_command(command)
+        err = cm.exception
+        self.assertEqual(err.getGrpcStatusCode(), grpc.StatusCode.UNAVAILABLE)
+        self.assertEqual(err.getErrorClass(), None)
+        self.assertEqual(err.getSqlState(), None)
+
+    def test_error_codes(self):
+        msg = "Something went wrong on the server"
+
+        def raise_without_status():
+            raise TestException(msg=msg, trailing_status=None)
+
+        def raise_without_status_unauthenticated():
+            raise TestException(msg=msg, code=grpc.StatusCode.UNAUTHENTICATED)
+
+        def raise_without_status_permission_denied():
+            raise TestException(msg=msg, code=grpc.StatusCode.PERMISSION_DENIED)
+
+        def raise_without_details():
+            status = status_pb2.Status(
+                code=grpc.StatusCode.INTERNAL.value[0], message=msg, details=[]
+            )
+            raise TestException(msg=msg, trailing_status=status)
+
+        def raise_without_metadata():
+            any = any_pb2.Any()
+            any.Pack(ErrorInfo())
+            status = status_pb2.Status(
+                code=grpc.StatusCode.INTERNAL.value[0], message=msg, details=[any]
+            )
+            raise TestException(msg=msg, trailing_status=status)
+
+        def raise_with_error_class():
+            any = any_pb2.Any()
+            any.Pack(ErrorInfo(metadata=dict(errorClass="TEST_ERROR_CLASS")))
+            status = status_pb2.Status(
+                code=grpc.StatusCode.INTERNAL.value[0], message=msg, details=[any]
+            )
+            raise TestException(msg=msg, trailing_status=status)
+
+        def raise_with_sql_state():
+            any = any_pb2.Any()
+            any.Pack(ErrorInfo(metadata=dict(sqlState="TEST_SQL_STATE")))
+            status = status_pb2.Status(
+                code=grpc.StatusCode.INTERNAL.value[0], message=msg, details=[any]
+            )
+            raise TestException(msg=msg, trailing_status=status)
+
+        test_cases = [
+            (raise_without_status, grpc.StatusCode.INTERNAL, None, None),
+            (raise_without_status_unauthenticated, grpc.StatusCode.UNAUTHENTICATED, None, None),
+            (raise_without_status_permission_denied, grpc.StatusCode.PERMISSION_DENIED, None, None),
+            (raise_without_details, grpc.StatusCode.INTERNAL, None, None),
+            (raise_without_metadata, grpc.StatusCode.INTERNAL, None, None),
+            (raise_with_error_class, grpc.StatusCode.INTERNAL, "TEST_ERROR_CLASS", None),
+            (raise_with_sql_state, grpc.StatusCode.INTERNAL, None, "TEST_SQL_STATE"),
+        ]
+
+        for (
+            response_function,
+            expected_status_code,
+            expected_error_class,
+            expected_sql_state,
+        ) in test_cases:
+            client = SparkConnectClient(
+                "sc://foo", use_reattachable_execute=False, 
retry_policy=dict(max_retries=0)
+            )
+            client._stub = self._stub_with([response_function])
+            with self.assertRaises(SparkConnectGrpcException) as cm:
+                command = proto.Command()
+                client.execute_command(command)
+            err = cm.exception
+            self.assertEqual(err.getGrpcStatusCode(), expected_status_code)
+            self.assertEqual(err.getErrorClass(), expected_error_class)
+            self.assertEqual(err.getSqlState(), expected_sql_state)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.client.test_client import *  # noqa: F401

