This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new c0caee754116 [SPARK-46402][PYTHON] Add getMessageParameters and getQueryContext support
c0caee754116 is described below

commit c0caee75411650ddf3a929fb23c3e93aff39b55d
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Dec 14 21:41:18 2023 -0800

    [SPARK-46402][PYTHON] Add getMessageParameters and getQueryContext support

    ### What changes were proposed in this pull request?

    This PR adds the following APIs, available both with and without Spark Connect:
    - `getMessageParameters`, which now works correctly
    - `getQueryContext`

    ### Why are the changes needed?

    For feature parity.

    ### Does this PR introduce _any_ user-facing change?

    Yes. It adds the new APIs `QueryContext`, `QueryContextType`, and
    `PySparkException.getQueryContext`, and makes `getMessageParameters` work correctly.

    ### How was this patch tested?

    Unit tests were added.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #44349 from HyukjinKwon/error-fields.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/protobuf/spark/connect/base.proto     |   2 +-
 python/docs/source/reference/pyspark.errors.rst    |   3 +
 python/pyspark/errors/__init__.py                  |   4 +
 python/pyspark/errors/exceptions/base.py           | 113 +++++++++++++++++++--
 python/pyspark/errors/exceptions/captured.py       |  51 +++++++++-
 python/pyspark/errors/exceptions/connect.py        |  91 ++++++++++++++++-
 python/pyspark/sql/connect/proto/base_pb2.py       |  26 ++---
 python/pyspark/sql/connect/proto/base_pb2.pyi      |  10 +-
 python/pyspark/sql/tests/connect/test_utils.py     |   5 +-
 python/pyspark/sql/tests/test_utils.py             |  41 +++++---
 10 files changed, 301 insertions(+), 45 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index da089dcd7564..f24ca0a8fc3b 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -925,7 +925,7 @@ message FetchErrorDetailsResponse {
       string fragment = 5;
 
       // The user code (call site of the API) that caused throwing the exception.
-      string callSite = 6;
+      string call_site = 6;
 
       // Summary of the exception cause.
      string summary = 7;

diff --git a/python/docs/source/reference/pyspark.errors.rst b/python/docs/source/reference/pyspark.errors.rst
index a40f88ce2044..fd824ab8901c 100644
--- a/python/docs/source/reference/pyspark.errors.rst
+++ b/python/docs/source/reference/pyspark.errors.rst
@@ -47,6 +47,8 @@ Classes
     PySparkImportError
     PySparkIndexError
     PythonException
+    QueryContext
+    QueryContextType
     QueryExecutionException
     RetriesExceeded
     SessionNotSameException
@@ -70,4 +72,5 @@ Methods
     PySparkException.getErrorClass
     PySparkException.getMessage
     PySparkException.getMessageParameters
+    PySparkException.getQueryContext
     PySparkException.getSqlState

diff --git a/python/pyspark/errors/__init__.py b/python/pyspark/errors/__init__.py
index a4f64e85f875..85e9bb65f0a6 100644
--- a/python/pyspark/errors/__init__.py
+++ b/python/pyspark/errors/__init__.py
@@ -48,6 +48,8 @@ from pyspark.errors.exceptions.base import (  # noqa: F401
     PySparkPicklingError,
     RetriesExceeded,
     PySparkKeyError,
+    QueryContext,
+    QueryContextType,
 )
 
 
@@ -81,4 +83,6 @@ __all__ = [
     "PySparkPicklingError",
     "RetriesExceeded",
     "PySparkKeyError",
+    "QueryContext",
+    "QueryContextType",
 ]

diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py
index e40e1b2e93cb..dcfc6df77a77 100644
--- a/python/pyspark/errors/exceptions/base.py
+++ b/python/pyspark/errors/exceptions/base.py
@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-from typing import Dict, Optional, cast, Iterable, TYPE_CHECKING
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Dict, Optional, cast, Iterable, TYPE_CHECKING, List
 
 from pyspark.errors.utils import ErrorClassesReader
 from pickle import PicklingError
@@ -34,12 +35,10 @@ class PySparkException(Exception):
         message: Optional[str] = None,
         error_class: Optional[str] = None,
         message_parameters: Optional[Dict[str, str]] = None,
+        query_contexts: Optional[List["QueryContext"]] = None,
     ):
-        # `message` vs `error_class` & `message_parameters` are mutually exclusive.
-        assert (message is not None and (error_class is None and message_parameters is None)) or (
-            message is None and (error_class is not None and message_parameters is not None)
-        )
-
+        if query_contexts is None:
+            query_contexts = []
         self._error_reader = ErrorClassesReader()
 
         if message is None:
@@ -51,6 +50,7 @@ class PySparkException(Exception):
 
         self._error_class = error_class
         self._message_parameters = message_parameters
+        self._query_contexts = query_contexts
 
     def getErrorClass(self) -> Optional[str]:
         """
@@ -62,6 +62,7 @@ class PySparkException(Exception):
         --------
         :meth:`PySparkException.getMessage`
         :meth:`PySparkException.getMessageParameters`
+        :meth:`PySparkException.getQueryContext`
         :meth:`PySparkException.getSqlState`
         """
         return self._error_class
@@ -76,6 +77,7 @@ class PySparkException(Exception):
         --------
         :meth:`PySparkException.getErrorClass`
         :meth:`PySparkException.getMessage`
+        :meth:`PySparkException.getQueryContext`
         :meth:`PySparkException.getSqlState`
         """
         return self._message_parameters
@@ -93,6 +95,7 @@ class PySparkException(Exception):
         :meth:`PySparkException.getErrorClass`
         :meth:`PySparkException.getMessage`
         :meth:`PySparkException.getMessageParameters`
+        :meth:`PySparkException.getQueryContext`
         """
         return None
@@ -106,10 +109,26 @@ class PySparkException(Exception):
         --------
         :meth:`PySparkException.getErrorClass`
         :meth:`PySparkException.getMessageParameters`
+        :meth:`PySparkException.getQueryContext`
         :meth:`PySparkException.getSqlState`
         """
         return f"[{self.getErrorClass()}] {self._message}"
 
+    def getQueryContext(self) -> List["QueryContext"]:
+        """
+        Returns :class:`QueryContext`.
+
+        .. versionadded:: 4.0.0
+
+        See Also
+        --------
+        :meth:`PySparkException.getErrorClass`
+        :meth:`PySparkException.getMessageParameters`
+        :meth:`PySparkException.getMessage`
+        :meth:`PySparkException.getSqlState`
+        """
+        return self._query_contexts
+
     def __str__(self) -> str:
         if self.getErrorClass() is not None:
             return self.getMessage()
@@ -294,3 +313,83 @@ class PySparkImportError(PySparkException, ImportError):
     """
     Wrapper class for ImportError to support error classes.
     """
+
+
+class QueryContextType(Enum):
+    """
+    The type of :class:`QueryContext`.
+
+    .. versionadded:: 4.0.0
+    """
+
+    SQL = 0
+    DataFrame = 1
+
+
+class QueryContext(ABC):
+    """
+    Query context of a :class:`PySparkException`. It helps users understand
+    where errors occur while executing queries.
+
+    .. versionadded:: 4.0.0
+    """
+
+    @abstractmethod
+    def contextType(self) -> QueryContextType:
+        """
+        The type of this query context.
+        """
+        ...
+
+    @abstractmethod
+    def objectType(self) -> str:
+        """
+        The object type of the query which throws the exception.
+        If the exception is directly from the main query, it should be an empty string.
+        Otherwise, it should be the exact object type in upper case. For example, a "VIEW".
+        """
+        ...
+
+    @abstractmethod
+    def objectName(self) -> str:
+        """
+        The object name of the query which throws the exception.
+        If the exception is directly from the main query, it should be an empty string.
+        Otherwise, it should be the object name. For example, a view name "V1".
+        """
+        ...
+
+    @abstractmethod
+    def startIndex(self) -> int:
+        """
+        The starting index in the query text which throws the exception. The index starts from 0.
+        """
+        ...
+
+    @abstractmethod
+    def stopIndex(self) -> int:
+        """
+        The stopping index in the query which throws the exception. The index starts from 0.
+        """
+        ...
+
+    @abstractmethod
+    def fragment(self) -> str:
+        """
+        The corresponding fragment of the query which throws the exception.
+        """
+        ...
+
+    @abstractmethod
+    def callSite(self) -> str:
+        """
+        The user code (call site of the API) that caused throwing the exception.
+        """
+        ...
+
+    @abstractmethod
+    def summary(self) -> str:
+        """
+        Summary of the exception cause.
+        """
+        ...

diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index 4164bb7b428d..687bdec14154 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -15,11 +15,11 @@
 # limitations under the License.
 #
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Iterator, Optional, cast
+from typing import Any, Callable, Dict, Iterator, Optional, cast, List
 
 import py4j
 from py4j.protocol import Py4JJavaError
-from py4j.java_gateway import is_instance_of
+from py4j.java_gateway import is_instance_of, JavaObject
 
 from pyspark import SparkContext
 from pyspark.errors.exceptions.base import (
@@ -39,6 +39,8 @@ from pyspark.errors.exceptions.base import (
     SparkNoSuchElementException as BaseNoSuchElementException,
     StreamingQueryException as BaseStreamingQueryException,
     UnknownException as BaseUnknownException,
+    QueryContext as BaseQueryContext,
+    QueryContextType,
 )
 
 
@@ -136,6 +138,17 @@ class CapturedException(PySparkException):
         else:
             return ""
 
+    def getQueryContext(self) -> List[BaseQueryContext]:
+        assert SparkContext._gateway is not None
+
+        gw = SparkContext._gateway
+        if self._origin is not None and is_instance_of(
+            gw, self._origin, "org.apache.spark.SparkThrowable"
+        ):
+            return [QueryContext(q) for q in self._origin.getQueryContext()]
+        else:
+            return []
+
 
 def convert_exception(e: Py4JJavaError) -> CapturedException:
     assert e is not None
@@ -332,3 +345,37 @@ class UnknownException(CapturedException, BaseUnknownException):
     """
     None of the other exceptions.
""" + + +class QueryContext(BaseQueryContext): + def __init__(self, q: JavaObject): + self._q = q + + def contextType(self) -> QueryContextType: + context_type = self._q.contextType().toString() + assert context_type in ("SQL", "DataFrame") + if context_type == "DataFrame": + return QueryContextType.DataFrame + else: + return QueryContextType.SQL + + def objectType(self) -> str: + return str(self._q.objectType()) + + def objectName(self) -> str: + return str(self._q.objectName()) + + def startIndex(self) -> int: + return int(self._q.startIndex()) + + def stopIndex(self) -> int: + return int(self._q.stopIndex()) + + def fragment(self) -> str: + return str(self._q.fragment()) + + def callSite(self) -> str: + return str(self._q.callSite()) + + def summary(self) -> str: + return str(self._q.summary()) diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py index aaa52f9b20a5..ba172135cb64 100644 --- a/python/pyspark/errors/exceptions/connect.py +++ b/python/pyspark/errors/exceptions/connect.py @@ -34,6 +34,8 @@ from pyspark.errors.exceptions.base import ( SparkRuntimeException as BaseSparkRuntimeException, SparkNoSuchElementException as BaseNoSuchElementException, SparkUpgradeException as BaseSparkUpgradeException, + QueryContext as BaseQueryContext, + QueryContextType, ) if TYPE_CHECKING: @@ -55,8 +57,8 @@ def convert_exception( classes = [] sql_state = None error_class = None - - stacktrace: Optional[str] = None + message_parameters = None + query_contexts: Optional[List[BaseQueryContext]] = None if "classes" in info.metadata: classes = json.loads(info.metadata["classes"]) @@ -67,6 +69,7 @@ def convert_exception( if "errorClass" in info.metadata: error_class = info.metadata["errorClass"] + stacktrace: Optional[str] = None if resp is not None and resp.HasField("root_error_idx"): message = resp.errors[resp.root_error_idx].message stacktrace = _extract_jvm_stacktrace(resp) @@ -75,79 +78,109 @@ def convert_exception( stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else None display_server_stacktrace = display_server_stacktrace if stacktrace is not None else False + if ( + resp is not None + and resp.errors + and hasattr(resp.errors[resp.root_error_idx], "spark_throwable") + ): + message_parameters = dict( + resp.errors[resp.root_error_idx].spark_throwable.message_parameters + ) + query_contexts = [] + for query_context in resp.errors[resp.root_error_idx].spark_throwable.query_contexts: + query_contexts.append(QueryContext(query_context)) + if "org.apache.spark.sql.catalyst.parser.ParseException" in classes: return ParseException( message, error_class=error_class, + message_parameters=message_parameters, sql_state=sql_state, server_stacktrace=stacktrace, display_server_stacktrace=display_server_stacktrace, + query_contexts=query_contexts, ) # Order matters. ParseException inherits AnalysisException. 
elif "org.apache.spark.sql.AnalysisException" in classes: return AnalysisException( message, error_class=error_class, + message_parameters=message_parameters, sql_state=sql_state, server_stacktrace=stacktrace, display_server_stacktrace=display_server_stacktrace, + query_contexts=query_contexts, ) elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes: return StreamingQueryException( message, error_class=error_class, + message_parameters=message_parameters, sql_state=sql_state, server_stacktrace=stacktrace, display_server_stacktrace=display_server_stacktrace, + query_contexts=query_contexts, ) elif "org.apache.spark.sql.execution.QueryExecutionException" in classes: return QueryExecutionException( message, error_class=error_class, + message_parameters=message_parameters, sql_state=sql_state, server_stacktrace=stacktrace, display_server_stacktrace=display_server_stacktrace, + query_contexts=query_contexts, ) # Order matters. NumberFormatException inherits IllegalArgumentException. elif "java.lang.NumberFormatException" in classes: return NumberFormatException( message, error_class=error_class, + message_parameters=message_parameters, sql_state=sql_state, server_stacktrace=stacktrace, display_server_stacktrace=display_server_stacktrace, + query_contexts=query_contexts, ) elif "java.lang.IllegalArgumentException" in classes: return IllegalArgumentException( message, error_class=error_class, + message_parameters=message_parameters, sql_state=sql_state, server_stacktrace=stacktrace, display_server_stacktrace=display_server_stacktrace, + query_contexts=query_contexts, ) elif "java.lang.ArithmeticException" in classes: return ArithmeticException( message, error_class=error_class, + message_parameters=message_parameters, sql_state=sql_state, server_stacktrace=stacktrace, display_server_stacktrace=display_server_stacktrace, + query_contexts=query_contexts, ) elif "java.lang.UnsupportedOperationException" in classes: return UnsupportedOperationException( message, error_class=error_class, + message_parameters=message_parameters, sql_state=sql_state, server_stacktrace=stacktrace, display_server_stacktrace=display_server_stacktrace, + query_contexts=query_contexts, ) elif "java.lang.ArrayIndexOutOfBoundsException" in classes: return ArrayIndexOutOfBoundsException( message, error_class=error_class, + message_parameters=message_parameters, sql_state=sql_state, server_stacktrace=stacktrace, display_server_stacktrace=display_server_stacktrace, + query_contexts=query_contexts, ) elif "java.time.DateTimeException" in classes: return DateTimeException( @@ -156,22 +189,27 @@ def convert_exception( sql_state=sql_state, server_stacktrace=stacktrace, display_server_stacktrace=display_server_stacktrace, + query_contexts=query_contexts, ) elif "org.apache.spark.SparkRuntimeException" in classes: return SparkRuntimeException( message, error_class=error_class, + message_parameters=message_parameters, sql_state=sql_state, server_stacktrace=stacktrace, display_server_stacktrace=display_server_stacktrace, + query_contexts=query_contexts, ) elif "org.apache.spark.SparkUpgradeException" in classes: return SparkUpgradeException( message, error_class=error_class, + message_parameters=message_parameters, sql_state=sql_state, server_stacktrace=stacktrace, display_server_stacktrace=display_server_stacktrace, + query_contexts=query_contexts, ) elif "org.apache.spark.api.python.PythonException" in classes: return PythonException( @@ -182,27 +220,33 @@ def convert_exception( return SparkNoSuchElementException( 
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     # Make sure that the generic SparkException is handled last.
     elif "org.apache.spark.SparkException" in classes:
         return SparkException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     else:
         return SparkConnectGrpcException(
             message,
             reason=info.reason,
+            message_parameters=message_parameters,
             error_class=error_class,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
 
 
@@ -247,7 +291,10 @@ class SparkConnectGrpcException(SparkConnectException):
         sql_state: Optional[str] = None,
         server_stacktrace: Optional[str] = None,
         display_server_stacktrace: bool = False,
+        query_contexts: Optional[List[BaseQueryContext]] = None,
     ) -> None:
+        if query_contexts is None:
+            query_contexts = []
         self._message = message  # type: ignore[assignment]
         if reason is not None:
             self._message = f"({reason}) {self._message}"
@@ -271,6 +318,7 @@ class SparkConnectGrpcException(SparkConnectException):
         self._sql_state: Optional[str] = sql_state
         self._stacktrace: Optional[str] = server_stacktrace
         self._display_stacktrace: bool = display_server_stacktrace
+        self._query_contexts: List[BaseQueryContext] = query_contexts
 
     def getSqlState(self) -> Optional[str]:
         if self._sql_state is not None:
@@ -281,12 +329,15 @@ class SparkConnectGrpcException(SparkConnectException):
     def getStackTrace(self) -> Optional[str]:
         return self._stacktrace
 
-    def __str__(self) -> str:
+    def getMessage(self) -> str:
         desc = self._message
         if self._display_stacktrace:
             desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace
         return desc
 
+    def __str__(self) -> str:
+        return self.getMessage()
+
 
 class AnalysisException(SparkConnectGrpcException, BaseAnalysisException):
     """
@@ -374,3 +425,37 @@ class SparkNoSuchElementException(SparkConnectGrpcException, BaseNoSuchElementEx
     """
     No such element exception.
""" + + +class QueryContext(BaseQueryContext): + def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext): + self._q = q + + def contextType(self) -> QueryContextType: + context_type = self._q.context_type + + if int(context_type) == QueryContextType.DataFrame.value: + return QueryContextType.DataFrame + else: + return QueryContextType.SQL + + def objectType(self) -> str: + return str(self._q.object_type) + + def objectName(self) -> str: + return str(self._q.object_name) + + def startIndex(self) -> int: + return int(self._q.start_index) + + def stopIndex(self) -> int: + return int(self._q.stop_index) + + def fragment(self) -> str: + return str(self._q.fragment) + + def callSite(self) -> str: + return str(self._q.call_site) + + def summary(self) -> str: + return str(self._q.summary) diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index e23f3bdaaaa4..8326ce511d56 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...] + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...] 
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -206,19 +206,19 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _FETCHERRORDETAILSREQUEST._serialized_start = 12098
     _FETCHERRORDETAILSREQUEST._serialized_end = 12299
     _FETCHERRORDETAILSRESPONSE._serialized_start = 12302
-    _FETCHERRORDETAILSRESPONSE._serialized_end = 13856
+    _FETCHERRORDETAILSRESPONSE._serialized_end = 13857
     _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 12531
     _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 12705
     _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 12708
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 13075
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 13038
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 13075
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 13078
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 13487
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 13389
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 13457
-    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 13490
-    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13837
-    _SPARKCONNECTSERVICE._serialized_start = 13859
-    _SPARKCONNECTSERVICE._serialized_end = 14805
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 13076
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 13039
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 13076
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 13079
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 13488
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 13390
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 13458
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 13491
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13838
+    _SPARKCONNECTSERVICE._serialized_start = 13860
+    _SPARKCONNECTSERVICE._serialized_end = 14806
 # @@protoc_insertion_point(module_scope)

diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index cdf7e2b0bce0..e4ed03dc6945 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -3097,7 +3097,7 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message):
         START_INDEX_FIELD_NUMBER: builtins.int
         STOP_INDEX_FIELD_NUMBER: builtins.int
         FRAGMENT_FIELD_NUMBER: builtins.int
-        CALLSITE_FIELD_NUMBER: builtins.int
+        CALL_SITE_FIELD_NUMBER: builtins.int
         SUMMARY_FIELD_NUMBER: builtins.int
         context_type: global___FetchErrorDetailsResponse.QueryContext.ContextType.ValueType
         object_type: builtins.str
@@ -3116,7 +3116,7 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message):
         """The stopping index in the query which throws the exception.
         The index starts from 0."""
         fragment: builtins.str
         """The corresponding fragment of the query which throws the exception."""
-        callSite: builtins.str
+        call_site: builtins.str
         """The user code (call site of the API) that caused throwing the exception."""
         summary: builtins.str
         """Summary of the exception cause."""
@@ -3129,14 +3129,14 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message):
             start_index: builtins.int = ...,
             stop_index: builtins.int = ...,
             fragment: builtins.str = ...,
-            callSite: builtins.str = ...,
+            call_site: builtins.str = ...,
             summary: builtins.str = ...,
         ) -> None: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "callSite",
-                b"callSite",
+                "call_site",
+                b"call_site",
                 "context_type",
                 b"context_type",
                 "fragment",

diff --git a/python/pyspark/sql/tests/connect/test_utils.py b/python/pyspark/sql/tests/connect/test_utils.py
index 917cb58057f7..5f5f401cc626 100644
--- a/python/pyspark/sql/tests/connect/test_utils.py
+++ b/python/pyspark/sql/tests/connect/test_utils.py
@@ -14,13 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import unittest
 
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.sql.tests.test_utils import UtilsTestsMixin
 
 
 class ConnectUtilsTests(ReusedConnectTestCase, UtilsTestsMixin):
-    pass
+    @unittest.skip("SPARK-46397: Different exception thrown")
+    def test_capture_illegalargument_exception(self):
+        super().test_capture_illegalargument_exception()
 
 
 if __name__ == "__main__":

diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index e13b933c46ba..d54db78d4b65 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -19,6 +19,7 @@ import unittest
 import difflib
 from itertools import zip_longest
 
+from pyspark.errors import QueryContextType
 from pyspark.sql.functions import sha2, to_timestamp
 from pyspark.errors import (
     AnalysisException,
@@ -1701,11 +1702,9 @@ class UtilsTestsMixin:
 
         self.assertTrue("apple" in error_message and "banana" not in error_message)
 
-
-class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
     def test_capture_analysis_exception(self):
         self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
-        self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
+        self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b").collect())
 
     def test_capture_user_friendly_exception(self):
         try:
@@ -1744,19 +1743,31 @@ class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
 
     def test_get_error_class_state(self):
         # SPARK-36953: test CapturedException.getErrorClass and getSqlState (from SparkThrowable)
+        exception = None
         try:
             self.spark.sql("""SELECT a""")
         except AnalysisException as e:
-            self.assertEqual(e.getErrorClass(), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION")
-            self.assertEqual(e.getSqlState(), "42703")
-            self.assertEqual(e.getMessageParameters(), {"objectName": "`a`"})
-            self.assertEqual(
-                e.getMessage(),
-                (
-                    "[UNRESOLVED_COLUMN.WITHOUT_SUGGESTION] A column, variable, or function "
-                    "parameter with name `a` cannot be resolved. SQLSTATE: 42703"
-                ),
-            )
+            exception = e
+
+        self.assertIsNotNone(exception)
+        self.assertEqual(exception.getErrorClass(), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION")
+        self.assertEqual(exception.getSqlState(), "42703")
+        self.assertEqual(exception.getMessageParameters(), {"objectName": "`a`"})
+        self.assertIn(
+            (
+                "[UNRESOLVED_COLUMN.WITHOUT_SUGGESTION] A column, variable, or function "
+                "parameter with name `a` cannot be resolved. SQLSTATE: 42703"
+            ),
+            exception.getMessage(),
+        )
+        self.assertEqual(len(exception.getQueryContext()), 1)
+        qc = exception.getQueryContext()[0]
+        self.assertEqual(qc.fragment(), "a")
+        self.assertEqual(qc.stopIndex(), 7)
+        self.assertEqual(qc.startIndex(), 7)
+        self.assertEqual(qc.contextType(), QueryContextType.SQL)
+        self.assertEqual(qc.objectName(), "")
+        self.assertEqual(qc.objectType(), "")
 
         try:
             self.spark.sql("""SELECT assert_true(FALSE)""")
@@ -1767,6 +1778,10 @@ class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
         self.assertEqual(e.getMessage(), "")
 
 
+class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
+    pass
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.sql.tests.test_utils import *  # noqa: F401