This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new bc93c255bd8b [SPARK-53568][CONNECT][PYTHON] Fix several small bugs in Spark Connect Python client error handling logic
bc93c255bd8b is described below

commit bc93c255bd8bc73ffc29689445321f06b053eb33
Author: Alex Khakhlyuk <alex.khakhl...@gmail.com>
AuthorDate: Mon Sep 15 07:49:11 2025 +0900

    [SPARK-53568][CONNECT][PYTHON] Fix several small bugs in Spark Connect Python client error handling logic

    ### What changes were proposed in this pull request?

    1. [This PR](https://github.com/apache/spark/commit/5102370dcf37ebf64d19b536656576d6b068e59a#diff-67c4c88c462539a60764612d6ac0048523f0cdc44a8997a2228803565054c6a9) introduced a bug where `SparkConnectGrpcException.grpc_status_code: grpc.StatusCode` is set from `status.code`, which is an `int`. I am fixing this by setting it from `rpc_error.code()`, which has the correct type.
    2. Errors raised without a `Status` payload did not carry a gRPC status code, even though the code is always available on gRPC exceptions. This is now fixed.
    3. There is an old bug where `SparkConnectGrpcException._errorClass` is set to `""` instead of `None`: the lookup in `info.metadata["errorClass"] == "INVALID_HANDLE.SESSION_CHANGED"` populates the missing key with `""`, because protobuf map fields behave like a `defaultdict(str)` and return (and insert) the default value for absent keys. This is fixed by using `info.metadata.get("errorClass")` instead.
    4. Added type hints for `error`, `status_code` and `status` to avoid type mismatch bugs in the future and to make the code more readable.
    5. Added proper tests for the gRPC status code, error class and SQL state.

    ### Why are the changes needed?

    Bug fixes.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    New tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    Closes #52325 from khakhlyuk/python-connect-error-handling-fixes.

    Authored-by: Alex Khakhlyuk <alex.khakhl...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/client/core.py          | 17 +++--
 .../sql/tests/connect/client/test_client.py        | 86 ++++++++++++++++++++++
 2 files changed, 98 insertions(+), 5 deletions(-)
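The two type pitfalls behind fixes 1 and 3 above are easy to reproduce in isolation. A minimal sketch (not part of the patch), assuming only `grpcio` and `googleapis-common-protos` are installed:

```python
import grpc
from google.rpc.error_details_pb2 import ErrorInfo

# Fix 1: a google.rpc.Status carries its code as a plain int, while the client
# exposes grpc_status_code as a grpc.StatusCode enum member. grpc.StatusCode
# values are (int, str) tuples, so the int and the enum never compare equal.
assert grpc.StatusCode.INTERNAL.value[0] == 13  # the integer wire value
assert grpc.StatusCode.INTERNAL != 13           # the enum member is not an int

# Fix 3: protobuf string-map fields behave like defaultdict(str): indexing a
# missing key yields "" (and inserts it), while .get() signals absence with None.
info = ErrorInfo()
assert info.metadata["errorClass"] == ""
assert info.metadata.get("errorClass") is None
```

This mirrors why assigning `status.code` produced an `int` where a `grpc.StatusCode` was expected, and why the `==` lookup left `_errorClass` set to `""`.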
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index cf730d8c796e..741d612f53f4 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -107,6 +107,7 @@ from pyspark.sql.connect.shell.progress import Progress, ProgressHandler, from_p
 
 if TYPE_CHECKING:
     from google.rpc.error_details_pb2 import ErrorInfo
+    from google.rpc.status_pb2 import Status
     from pyspark.sql.connect._typing import DataTypeOrString
     from pyspark.sql.connect.session import SparkSession
     from pyspark.sql.datasource import DataSource
@@ -1949,7 +1950,9 @@ class SparkConnectClient(object):
         logger.exception("GRPC Error received")
         # We have to cast the value here because, a RpcError is a Call as well.
         # https://grpc.github.io/grpc/python/grpc.html#grpc.UnaryUnaryMultiCallable.__call__
-        status = rpc_status.from_call(cast(grpc.Call, rpc_error))
+        error: grpc.Call = cast(grpc.Call, rpc_error)
+        status_code: grpc.StatusCode = error.code()
+        status: Optional[Status] = rpc_status.from_call(error)
         if status:
             for d in status.details:
                 if d.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
@@ -1957,7 +1960,7 @@
                     d.Unpack(info)
                     logger.debug(f"Received ErrorInfo: {info}")
 
-                    if info.metadata["errorClass"] == "INVALID_HANDLE.SESSION_CHANGED":
+                    if info.metadata.get("errorClass") == "INVALID_HANDLE.SESSION_CHANGED":
                         self._closed = True
 
                     raise convert_exception(
@@ -1965,14 +1968,18 @@
                         status.message,
                         self._fetch_enriched_error(info),
                         self._display_server_stack_trace(),
-                        status.code,
+                        status_code,
                     ) from None
 
             raise SparkConnectGrpcException(
-                message=status.message, grpc_status_code=status.code
+                message=status.message,
+                grpc_status_code=status_code,
             ) from None
         else:
-            raise SparkConnectGrpcException(str(rpc_error)) from None
+            raise SparkConnectGrpcException(
+                message=str(error),
+                grpc_status_code=status_code,
+            ) from None
 
     def add_artifacts(self, *paths: str, pyfile: bool, archive: bool, file: bool) -> None:
         try:
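For illustration, a hedged caller-side sketch of what the corrected behavior provides; `client` and `command` are placeholders for an existing `SparkConnectClient` and `proto.Command`, and the getters are the ones exercised by the new tests below:

```python
import grpc
from pyspark.errors.exceptions.connect import SparkConnectGrpcException

def execute_with_diagnostics(client, command):
    # Hypothetical helper: with this fix, getGrpcStatusCode() returns a
    # grpc.StatusCode enum member even when the server attached no rich
    # google.rpc.Status payload to the error.
    try:
        client.execute_command(command)
    except SparkConnectGrpcException as e:
        if e.getGrpcStatusCode() == grpc.StatusCode.UNAVAILABLE:
            print("Spark Connect server unreachable:", e)
        else:
            # Both getters may be None when the server sent no ErrorInfo details.
            print("server error:", e.getErrorClass(), e.getSqlState())
        raise
```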
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index c3954827bae5..21995f235839 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -20,12 +20,15 @@ import uuid
 from collections.abc import Generator
 from typing import Optional, Any, Union
 
+from pyspark.errors.exceptions.connect import SparkConnectGrpcException
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import eventually
 
 if should_test_connect:
     import grpc
+    import google.protobuf.any_pb2 as any_pb2
     from google.rpc import status_pb2
+    from google.rpc.error_details_pb2 import ErrorInfo
     import pandas as pd
     import pyarrow as pa
     from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder
@@ -450,6 +453,89 @@ class SparkConnectClientReattachTestCase(unittest.TestCase):
         reattach = ite._create_reattach_execute_request()
         self.assertEqual(reattach.client_observed_server_side_session_id, session_id)
 
+    def test_server_unreachable(self):
+        # DNS resolution should fail for "foo". This error is a retriable UNAVAILABLE error.
+        client = SparkConnectClient(
+            "sc://foo", use_reattachable_execute=False, retry_policy=dict(max_retries=0)
+        )
+        with self.assertRaises(SparkConnectGrpcException) as cm:
+            command = proto.Command()
+            client.execute_command(command)
+        err = cm.exception
+        self.assertEqual(err.getGrpcStatusCode(), grpc.StatusCode.UNAVAILABLE)
+        self.assertEqual(err.getErrorClass(), None)
+        self.assertEqual(err.getSqlState(), None)
+
+    def test_error_codes(self):
+        msg = "Something went wrong on the server"
+
+        def raise_without_status():
+            raise TestException(msg=msg, trailing_status=None)
+
+        def raise_without_status_unauthenticated():
+            raise TestException(msg=msg, code=grpc.StatusCode.UNAUTHENTICATED)
+
+        def raise_without_status_permission_denied():
+            raise TestException(msg=msg, code=grpc.StatusCode.PERMISSION_DENIED)
+
+        def raise_without_details():
+            status = status_pb2.Status(
+                code=grpc.StatusCode.INTERNAL.value[0], message=msg, details=[]
+            )
+            raise TestException(msg=msg, trailing_status=status)
+
+        def raise_without_metadata():
+            any = any_pb2.Any()
+            any.Pack(ErrorInfo())
+            status = status_pb2.Status(
+                code=grpc.StatusCode.INTERNAL.value[0], message=msg, details=[any]
+            )
+            raise TestException(msg=msg, trailing_status=status)
+
+        def raise_with_error_class():
+            any = any_pb2.Any()
+            any.Pack(ErrorInfo(metadata=dict(errorClass="TEST_ERROR_CLASS")))
+            status = status_pb2.Status(
+                code=grpc.StatusCode.INTERNAL.value[0], message=msg, details=[any]
+            )
+            raise TestException(msg=msg, trailing_status=status)
+
+        def raise_with_sql_state():
+            any = any_pb2.Any()
+            any.Pack(ErrorInfo(metadata=dict(sqlState="TEST_SQL_STATE")))
+            status = status_pb2.Status(
+                code=grpc.StatusCode.INTERNAL.value[0], message=msg, details=[any]
+            )
+            raise TestException(msg=msg, trailing_status=status)
+
+        test_cases = [
+            (raise_without_status, grpc.StatusCode.INTERNAL, None, None),
+            (raise_without_status_unauthenticated, grpc.StatusCode.UNAUTHENTICATED, None, None),
+            (raise_without_status_permission_denied, grpc.StatusCode.PERMISSION_DENIED, None, None),
+            (raise_without_details, grpc.StatusCode.INTERNAL, None, None),
+            (raise_without_metadata, grpc.StatusCode.INTERNAL, None, None),
+            (raise_with_error_class, grpc.StatusCode.INTERNAL, "TEST_ERROR_CLASS", None),
+            (raise_with_sql_state, grpc.StatusCode.INTERNAL, None, "TEST_SQL_STATE"),
+        ]
+
+        for (
+            response_function,
+            expected_status_code,
+            expected_error_class,
+            expected_sql_state,
+        ) in test_cases:
+            client = SparkConnectClient(
+                "sc://foo", use_reattachable_execute=False, retry_policy=dict(max_retries=0)
+            )
+            client._stub = self._stub_with([response_function])
+            with self.assertRaises(SparkConnectGrpcException) as cm:
+                command = proto.Command()
+                client.execute_command(command)
+            err = cm.exception
+            self.assertEqual(err.getGrpcStatusCode(), expected_status_code)
+            self.assertEqual(err.getErrorClass(), expected_error_class)
+            self.assertEqual(err.getSqlState(), expected_sql_state)
+
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.client.test_client import *  # noqa: F401

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org