This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c0caee754116 [SPARK-46402][PYTHON] Add getMessageParameters and getQueryContext support
c0caee754116 is described below

commit c0caee75411650ddf3a929fb23c3e93aff39b55d
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Dec 14 21:41:18 2023 -0800

    [SPARK-46402][PYTHON] Add getMessageParameters and getQueryContext support
    
    ### What changes were proposed in this pull request?
    
    This PR adds the following APIs, available with and without Spark Connect:

    - `getMessageParameters`, fixed to work correctly
    - `getQueryContext`, newly added
    
    ### Why are the changes needed?
    
    For feature parity.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. It adds the new APIs `QueryContext`, `QueryContextType`, and
    `PySparkException.getQueryContext`, and makes `getMessageParameters`
    work correctly.
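
    For example, a minimal usage sketch based on the unit tests added in this PR (assuming a running `spark` session; the error class and indices in the comments come from those tests):

    ```python
    from pyspark.errors import PySparkException, QueryContextType

    try:
        spark.sql("SELECT a")  # unresolved column
    except PySparkException as e:
        print(e.getErrorClass())         # UNRESOLVED_COLUMN.WITHOUT_SUGGESTION
        print(e.getMessageParameters())  # {'objectName': '`a`'}
        for qc in e.getQueryContext():
            # Each QueryContext points at the part of the query that failed.
            assert qc.contextType() == QueryContextType.SQL
            print(qc.fragment(), qc.startIndex(), qc.stopIndex())  # a 7 7
    ```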
    
    ### How was this patch tested?
    
    Unittests were added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44349 from HyukjinKwon/error-fields.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/protobuf/spark/connect/base.proto     |   2 +-
 python/docs/source/reference/pyspark.errors.rst    |   3 +
 python/pyspark/errors/__init__.py                  |   4 +
 python/pyspark/errors/exceptions/base.py           | 113 +++++++++++++++++++--
 python/pyspark/errors/exceptions/captured.py       |  51 +++++++++-
 python/pyspark/errors/exceptions/connect.py        |  91 ++++++++++++++++-
 python/pyspark/sql/connect/proto/base_pb2.py       |  26 ++---
 python/pyspark/sql/connect/proto/base_pb2.pyi      |  10 +-
 python/pyspark/sql/tests/connect/test_utils.py     |   5 +-
 python/pyspark/sql/tests/test_utils.py             |  41 +++++---
 10 files changed, 301 insertions(+), 45 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index da089dcd7564..f24ca0a8fc3b 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -925,7 +925,7 @@ message FetchErrorDetailsResponse {
     string fragment = 5;
 
     // The user code (call site of the API) that caused throwing the exception.
-    string callSite = 6;
+    string call_site = 6;
 
     // Summary of the exception cause.
     string summary = 7;
diff --git a/python/docs/source/reference/pyspark.errors.rst b/python/docs/source/reference/pyspark.errors.rst
index a40f88ce2044..fd824ab8901c 100644
--- a/python/docs/source/reference/pyspark.errors.rst
+++ b/python/docs/source/reference/pyspark.errors.rst
@@ -47,6 +47,8 @@ Classes
     PySparkImportError
     PySparkIndexError
     PythonException
+    QueryContext
+    QueryContextType
     QueryExecutionException
     RetriesExceeded
     SessionNotSameException
@@ -70,4 +72,5 @@ Methods
     PySparkException.getErrorClass
     PySparkException.getMessage
     PySparkException.getMessageParameters
+    PySparkException.getQueryContext
     PySparkException.getSqlState
diff --git a/python/pyspark/errors/__init__.py b/python/pyspark/errors/__init__.py
index a4f64e85f875..85e9bb65f0a6 100644
--- a/python/pyspark/errors/__init__.py
+++ b/python/pyspark/errors/__init__.py
@@ -48,6 +48,8 @@ from pyspark.errors.exceptions.base import (  # noqa: F401
     PySparkPicklingError,
     RetriesExceeded,
     PySparkKeyError,
+    QueryContext,
+    QueryContextType,
 )
 
 
@@ -81,4 +83,6 @@ __all__ = [
     "PySparkPicklingError",
     "RetriesExceeded",
     "PySparkKeyError",
+    "QueryContext",
+    "QueryContextType",
 ]
diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py
index e40e1b2e93cb..dcfc6df77a77 100644
--- a/python/pyspark/errors/exceptions/base.py
+++ b/python/pyspark/errors/exceptions/base.py
@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-from typing import Dict, Optional, cast, Iterable, TYPE_CHECKING
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Dict, Optional, cast, Iterable, TYPE_CHECKING, List
 
 from pyspark.errors.utils import ErrorClassesReader
 from pickle import PicklingError
@@ -34,12 +35,10 @@ class PySparkException(Exception):
         message: Optional[str] = None,
         error_class: Optional[str] = None,
         message_parameters: Optional[Dict[str, str]] = None,
+        query_contexts: Optional[List["QueryContext"]] = None,
     ):
-        # `message` vs `error_class` & `message_parameters` are mutually exclusive.
-        assert (message is not None and (error_class is None and message_parameters is None)) or (
-            message is None and (error_class is not None and message_parameters is not None)
-        )
-
+        if query_contexts is None:
+            query_contexts = []
         self._error_reader = ErrorClassesReader()
 
         if message is None:
@@ -51,6 +50,7 @@ class PySparkException(Exception):
 
         self._error_class = error_class
         self._message_parameters = message_parameters
+        self._query_contexts = query_contexts
 
     def getErrorClass(self) -> Optional[str]:
         """
@@ -62,6 +62,7 @@ class PySparkException(Exception):
         --------
         :meth:`PySparkException.getMessage`
         :meth:`PySparkException.getMessageParameters`
+        :meth:`PySparkException.getQueryContext`
         :meth:`PySparkException.getSqlState`
         """
         return self._error_class
@@ -76,6 +77,7 @@ class PySparkException(Exception):
         --------
         :meth:`PySparkException.getErrorClass`
         :meth:`PySparkException.getMessage`
+        :meth:`PySparkException.getQueryContext`
         :meth:`PySparkException.getSqlState`
         """
         return self._message_parameters
@@ -93,6 +95,7 @@ class PySparkException(Exception):
         :meth:`PySparkException.getErrorClass`
         :meth:`PySparkException.getMessage`
         :meth:`PySparkException.getMessageParameters`
+        :meth:`PySparkException.getQueryContext`
         """
         return None
 
@@ -106,10 +109,26 @@ class PySparkException(Exception):
         --------
         :meth:`PySparkException.getErrorClass`
         :meth:`PySparkException.getMessageParameters`
+        :meth:`PySparkException.getQueryContext`
         :meth:`PySparkException.getSqlState`
         """
         return f"[{self.getErrorClass()}] {self._message}"
 
+    def getQueryContext(self) -> List["QueryContext"]:
+        """
+        Returns a list of :class:`QueryContext`.
+
+        .. versionadded:: 4.0.0
+
+        See Also
+        --------
+        :meth:`PySparkException.getErrorClass`
+        :meth:`PySparkException.getMessageParameters`
+        :meth:`PySparkException.getMessage`
+        :meth:`PySparkException.getSqlState`
+        """
+        return self._query_contexts
+
     def __str__(self) -> str:
         if self.getErrorClass() is not None:
             return self.getMessage()
@@ -294,3 +313,83 @@ class PySparkImportError(PySparkException, ImportError):
     """
     Wrapper class for ImportError to support error classes.
     """
+
+
+class QueryContextType(Enum):
+    """
+    The type of :class:`QueryContext`.
+
+    .. versionadded:: 4.0.0
+    """
+
+    SQL = 0
+    DataFrame = 1
+
+
+class QueryContext(ABC):
+    """
+    Query context of a :class:`PySparkException`. It helps users understand
+    where errors occur while executing queries.
+
+    .. versionadded:: 4.0.0
+    """
+
+    @abstractmethod
+    def contextType(self) -> QueryContextType:
+        """
+        The type of this query context.
+        """
+        ...
+
+    @abstractmethod
+    def objectType(self) -> str:
+        """
+        The object type of the query which throws the exception.
+        If the exception is directly from the main query, it should be an empty string.
+        Otherwise, it should be the exact object type in upper case. For example, a "VIEW".
+        """
+        ...
+
+    @abstractmethod
+    def objectName(self) -> str:
+        """
+        The object name of the query which throws the exception.
+        If the exception is directly from the main query, it should be an empty string.
+        Otherwise, it should be the object name. For example, a view name "V1".
+        """
+        ...
+
+    @abstractmethod
+    def startIndex(self) -> int:
+        """
+        The starting index in the query text which throws the exception. The index starts from 0.
+        """
+        ...
+
+    @abstractmethod
+    def stopIndex(self) -> int:
+        """
+        The stopping index in the query which throws the exception. The index starts from 0.
+        """
+        ...
+
+    @abstractmethod
+    def fragment(self) -> str:
+        """
+        The corresponding fragment of the query which throws the exception.
+        """
+        ...
+
+    @abstractmethod
+    def callSite(self) -> str:
+        """
+        The user code (call site of the API) that caused throwing the exception.
+        """
+        ...
+
+    @abstractmethod
+    def summary(self) -> str:
+        """
+        Summary of the exception cause.
+        """
+        ...
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index 4164bb7b428d..687bdec14154 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -15,11 +15,11 @@
 # limitations under the License.
 #
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Iterator, Optional, cast
+from typing import Any, Callable, Dict, Iterator, Optional, cast, List
 
 import py4j
 from py4j.protocol import Py4JJavaError
-from py4j.java_gateway import is_instance_of
+from py4j.java_gateway import is_instance_of, JavaObject
 
 from pyspark import SparkContext
 from pyspark.errors.exceptions.base import (
@@ -39,6 +39,8 @@ from pyspark.errors.exceptions.base import (
     SparkNoSuchElementException as BaseNoSuchElementException,
     StreamingQueryException as BaseStreamingQueryException,
     UnknownException as BaseUnknownException,
+    QueryContext as BaseQueryContext,
+    QueryContextType,
 )
 
 
@@ -136,6 +138,17 @@ class CapturedException(PySparkException):
         else:
             return ""
 
+    def getQueryContext(self) -> List[BaseQueryContext]:
+        assert SparkContext._gateway is not None
+
+        gw = SparkContext._gateway
+        if self._origin is not None and is_instance_of(
+            gw, self._origin, "org.apache.spark.SparkThrowable"
+        ):
+            return [QueryContext(q) for q in self._origin.getQueryContext()]
+        else:
+            return []
+
 
 def convert_exception(e: Py4JJavaError) -> CapturedException:
     assert e is not None
@@ -332,3 +345,37 @@ class UnknownException(CapturedException, BaseUnknownException):
     """
     None of the other exceptions.
     """
+
+
+class QueryContext(BaseQueryContext):
+    def __init__(self, q: JavaObject):
+        self._q = q
+
+    def contextType(self) -> QueryContextType:
+        context_type = self._q.contextType().toString()
+        assert context_type in ("SQL", "DataFrame")
+        if context_type == "DataFrame":
+            return QueryContextType.DataFrame
+        else:
+            return QueryContextType.SQL
+
+    def objectType(self) -> str:
+        return str(self._q.objectType())
+
+    def objectName(self) -> str:
+        return str(self._q.objectName())
+
+    def startIndex(self) -> int:
+        return int(self._q.startIndex())
+
+    def stopIndex(self) -> int:
+        return int(self._q.stopIndex())
+
+    def fragment(self) -> str:
+        return str(self._q.fragment())
+
+    def callSite(self) -> str:
+        return str(self._q.callSite())
+
+    def summary(self) -> str:
+        return str(self._q.summary())
diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py
index aaa52f9b20a5..ba172135cb64 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -34,6 +34,8 @@ from pyspark.errors.exceptions.base import (
     SparkRuntimeException as BaseSparkRuntimeException,
     SparkNoSuchElementException as BaseNoSuchElementException,
     SparkUpgradeException as BaseSparkUpgradeException,
+    QueryContext as BaseQueryContext,
+    QueryContextType,
 )
 
 if TYPE_CHECKING:
@@ -55,8 +57,8 @@ def convert_exception(
     classes = []
     sql_state = None
     error_class = None
-
-    stacktrace: Optional[str] = None
+    message_parameters = None
+    query_contexts: Optional[List[BaseQueryContext]] = None
 
     if "classes" in info.metadata:
         classes = json.loads(info.metadata["classes"])
@@ -67,6 +69,7 @@ def convert_exception(
     if "errorClass" in info.metadata:
         error_class = info.metadata["errorClass"]
 
+    stacktrace: Optional[str] = None
     if resp is not None and resp.HasField("root_error_idx"):
         message = resp.errors[resp.root_error_idx].message
         stacktrace = _extract_jvm_stacktrace(resp)
@@ -75,79 +78,109 @@ def convert_exception(
         stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else None
         display_server_stacktrace = display_server_stacktrace if stacktrace is not None else False
 
+    if (
+        resp is not None
+        and resp.errors
+        and hasattr(resp.errors[resp.root_error_idx], "spark_throwable")
+    ):
+        message_parameters = dict(
+            resp.errors[resp.root_error_idx].spark_throwable.message_parameters
+        )
+        query_contexts = []
+        for query_context in resp.errors[resp.root_error_idx].spark_throwable.query_contexts:
+            query_contexts.append(QueryContext(query_context))
+
     if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
         return ParseException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     # Order matters. ParseException inherits AnalysisException.
     elif "org.apache.spark.sql.AnalysisException" in classes:
         return AnalysisException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes:
         return StreamingQueryException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     elif "org.apache.spark.sql.execution.QueryExecutionException" in classes:
         return QueryExecutionException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     # Order matters. NumberFormatException inherits IllegalArgumentException.
     elif "java.lang.NumberFormatException" in classes:
         return NumberFormatException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     elif "java.lang.IllegalArgumentException" in classes:
         return IllegalArgumentException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     elif "java.lang.ArithmeticException" in classes:
         return ArithmeticException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     elif "java.lang.UnsupportedOperationException" in classes:
         return UnsupportedOperationException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     elif "java.lang.ArrayIndexOutOfBoundsException" in classes:
         return ArrayIndexOutOfBoundsException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     elif "java.time.DateTimeException" in classes:
         return DateTimeException(
@@ -156,22 +189,27 @@ def convert_exception(
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     elif "org.apache.spark.SparkRuntimeException" in classes:
         return SparkRuntimeException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     elif "org.apache.spark.SparkUpgradeException" in classes:
         return SparkUpgradeException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     elif "org.apache.spark.api.python.PythonException" in classes:
         return PythonException(
@@ -182,27 +220,33 @@ def convert_exception(
         return SparkNoSuchElementException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     # Make sure that the generic SparkException is handled last.
     elif "org.apache.spark.SparkException" in classes:
         return SparkException(
             message,
             error_class=error_class,
+            message_parameters=message_parameters,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
     else:
         return SparkConnectGrpcException(
             message,
             reason=info.reason,
+            message_parameters=message_parameters,
             error_class=error_class,
             sql_state=sql_state,
             server_stacktrace=stacktrace,
             display_server_stacktrace=display_server_stacktrace,
+            query_contexts=query_contexts,
         )
 
 
@@ -247,7 +291,10 @@ class SparkConnectGrpcException(SparkConnectException):
         sql_state: Optional[str] = None,
         server_stacktrace: Optional[str] = None,
         display_server_stacktrace: bool = False,
+        query_contexts: Optional[List[BaseQueryContext]] = None,
     ) -> None:
+        if query_contexts is None:
+            query_contexts = []
         self._message = message  # type: ignore[assignment]
         if reason is not None:
             self._message = f"({reason}) {self._message}"
@@ -271,6 +318,7 @@ class SparkConnectGrpcException(SparkConnectException):
         self._sql_state: Optional[str] = sql_state
         self._stacktrace: Optional[str] = server_stacktrace
         self._display_stacktrace: bool = display_server_stacktrace
+        self._query_contexts: List[BaseQueryContext] = query_contexts
 
     def getSqlState(self) -> Optional[str]:
         if self._sql_state is not None:
@@ -281,12 +329,15 @@ class SparkConnectGrpcException(SparkConnectException):
     def getStackTrace(self) -> Optional[str]:
         return self._stacktrace
 
-    def __str__(self) -> str:
+    def getMessage(self) -> str:
         desc = self._message
         if self._display_stacktrace:
             desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace
         return desc
 
+    def __str__(self) -> str:
+        return self.getMessage()
+
 
 class AnalysisException(SparkConnectGrpcException, BaseAnalysisException):
     """
@@ -374,3 +425,37 @@ class SparkNoSuchElementException(SparkConnectGrpcException, BaseNoSuchElementEx
     """
     No such element exception.
     """
+
+
+class QueryContext(BaseQueryContext):
+    def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
+        self._q = q
+
+    def contextType(self) -> QueryContextType:
+        context_type = self._q.context_type
+
+        if int(context_type) == QueryContextType.DataFrame.value:
+            return QueryContextType.DataFrame
+        else:
+            return QueryContextType.SQL
+
+    def objectType(self) -> str:
+        return str(self._q.object_type)
+
+    def objectName(self) -> str:
+        return str(self._q.object_name)
+
+    def startIndex(self) -> int:
+        return int(self._q.start_index)
+
+    def stopIndex(self) -> int:
+        return int(self._q.stop_index)
+
+    def fragment(self) -> str:
+        return str(self._q.fragment)
+
+    def callSite(self) -> str:
+        return str(self._q.call_site)
+
+    def summary(self) -> str:
+        return str(self._q.summary)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index e23f3bdaaaa4..8326ce511d56 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...]
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -206,19 +206,19 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _FETCHERRORDETAILSREQUEST._serialized_start = 12098
     _FETCHERRORDETAILSREQUEST._serialized_end = 12299
     _FETCHERRORDETAILSRESPONSE._serialized_start = 12302
-    _FETCHERRORDETAILSRESPONSE._serialized_end = 13856
+    _FETCHERRORDETAILSRESPONSE._serialized_end = 13857
     _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 12531
     _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 12705
     _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 12708
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 13075
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 13038
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 13075
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 13078
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 13487
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 13389
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 13457
-    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 13490
-    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13837
-    _SPARKCONNECTSERVICE._serialized_start = 13859
-    _SPARKCONNECTSERVICE._serialized_end = 14805
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 13076
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 13039
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 13076
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 13079
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 13488
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 13390
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 13458
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 13491
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13838
+    _SPARKCONNECTSERVICE._serialized_start = 13860
+    _SPARKCONNECTSERVICE._serialized_end = 14806
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index cdf7e2b0bce0..e4ed03dc6945 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -3097,7 +3097,7 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message):
         START_INDEX_FIELD_NUMBER: builtins.int
         STOP_INDEX_FIELD_NUMBER: builtins.int
         FRAGMENT_FIELD_NUMBER: builtins.int
-        CALLSITE_FIELD_NUMBER: builtins.int
+        CALL_SITE_FIELD_NUMBER: builtins.int
         SUMMARY_FIELD_NUMBER: builtins.int
        context_type: global___FetchErrorDetailsResponse.QueryContext.ContextType.ValueType
         object_type: builtins.str
@@ -3116,7 +3116,7 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message):
         """The stopping index in the query which throws the exception. The 
index starts from 0."""
         fragment: builtins.str
         """The corresponding fragment of the query which throws the 
exception."""
-        callSite: builtins.str
+        call_site: builtins.str
         """The user code (call site of the API) that caused throwing the 
exception."""
         summary: builtins.str
         """Summary of the exception cause."""
@@ -3129,14 +3129,14 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message):
             start_index: builtins.int = ...,
             stop_index: builtins.int = ...,
             fragment: builtins.str = ...,
-            callSite: builtins.str = ...,
+            call_site: builtins.str = ...,
             summary: builtins.str = ...,
         ) -> None: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "callSite",
-                b"callSite",
+                "call_site",
+                b"call_site",
                 "context_type",
                 b"context_type",
                 "fragment",
diff --git a/python/pyspark/sql/tests/connect/test_utils.py b/python/pyspark/sql/tests/connect/test_utils.py
index 917cb58057f7..5f5f401cc626 100644
--- a/python/pyspark/sql/tests/connect/test_utils.py
+++ b/python/pyspark/sql/tests/connect/test_utils.py
@@ -14,13 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import unittest
 
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.sql.tests.test_utils import UtilsTestsMixin
 
 
 class ConnectUtilsTests(ReusedConnectTestCase, UtilsTestsMixin):
-    pass
+    @unittest.skip("SPARK-46397: Different exception thrown")
+    def test_capture_illegalargument_exception(self):
+        super().test_capture_illegalargument_exception()
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index e13b933c46ba..d54db78d4b65 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -19,6 +19,7 @@ import unittest
 import difflib
 from itertools import zip_longest
 
+from pyspark.errors import QueryContextType
 from pyspark.sql.functions import sha2, to_timestamp
 from pyspark.errors import (
     AnalysisException,
@@ -1701,11 +1702,9 @@ class UtilsTestsMixin:
 
         self.assertTrue("apple" in error_message and "banana" not in 
error_message)
 
-
-class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
     def test_capture_analysis_exception(self):
         self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
-        self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
+        self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b").collect())
 
     def test_capture_user_friendly_exception(self):
         try:
@@ -1744,19 +1743,31 @@ class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
 
     def test_get_error_class_state(self):
         # SPARK-36953: test CapturedException.getErrorClass and getSqlState (from SparkThrowable)
+        exception = None
         try:
             self.spark.sql("""SELECT a""")
         except AnalysisException as e:
-            self.assertEqual(e.getErrorClass(), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION")
-            self.assertEqual(e.getSqlState(), "42703")
-            self.assertEqual(e.getMessageParameters(), {"objectName": "`a`"})
-            self.assertEqual(
-                e.getMessage(),
-                (
-                    "[UNRESOLVED_COLUMN.WITHOUT_SUGGESTION] A column, variable, or function "
-                    "parameter with name `a` cannot be resolved.  SQLSTATE: 42703"
-                ),
-            )
+            exception = e
+
+        self.assertIsNotNone(exception)
+        self.assertEqual(exception.getErrorClass(), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION")
+        self.assertEqual(exception.getSqlState(), "42703")
+        self.assertEqual(exception.getMessageParameters(), {"objectName": "`a`"})
+        self.assertIn(
+            (
+                "[UNRESOLVED_COLUMN.WITHOUT_SUGGESTION] A column, variable, or 
function "
+                "parameter with name `a` cannot be resolved.  SQLSTATE: 42703"
+            ),
+            exception.getMessage(),
+        )
+        self.assertEqual(len(exception.getQueryContext()), 1)
+        qc = exception.getQueryContext()[0]
+        self.assertEqual(qc.fragment(), "a")
+        self.assertEqual(qc.stopIndex(), 7)
+        self.assertEqual(qc.startIndex(), 7)
+        self.assertEqual(qc.contextType(), QueryContextType.SQL)
+        self.assertEqual(qc.objectName(), "")
+        self.assertEqual(qc.objectType(), "")
 
         try:
             self.spark.sql("""SELECT assert_true(FALSE)""")
@@ -1767,6 +1778,10 @@ class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
             self.assertEqual(e.getMessage(), "")
 
 
+class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
+    pass
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.sql.tests.test_utils import *  # noqa: F401

