This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c5230e485d7 [SPARK-40453][SPARK-41715][CONNECT] Take super class into account when throwing an exception
c5230e485d7 is described below

commit c5230e485d781ecfa996a674443709b0ce261f36
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri Feb 10 10:50:01 2023 +0900

    [SPARK-40453][SPARK-41715][CONNECT] Take super class into account when throwing an exception
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to take the super classes into account when throwing an exception from the server to the Python side, by adding more metadata about the classes, causes and traceback in the JVM.
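    
    For illustration, the server now attaches the throwable's class name together with all of its super class names as a JSON list under the `classes` key of the `ErrorInfo` metadata, and the Python client matches against that list. A minimal sketch of the idea (the class list below is illustrative, not the exact JVM hierarchy):
    
    ```python
    import json
    
    # Hypothetical metadata for a NoSuchDatabaseException thrown on the JVM side.
    metadata = {
        "classes": json.dumps(
            [
                "org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException",
                "org.apache.spark.sql.AnalysisException",
            ]
        )
    }
    
    classes = json.loads(metadata["classes"])
    # Checking membership over the whole hierarchy maps subclasses to a known parent.
    assert "org.apache.spark.sql.AnalysisException" in classes
    ```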
    
    In addition, this PR matches the exceptions being thrown to the regular PySpark exceptions defined:
    
    https://github.com/apache/spark/blob/04550edd49ee587656d215e59d6a072772d7d5ec/python/pyspark/errors/exceptions/captured.py#L108-L147
    
    ### Why are the changes needed?
    
    Right now, many exceptions (e.g., `NoSuchDatabaseException`, which inherits `AnalysisException`) cannot be handled on the Python side.
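    
    For example, with this change a call that fails with `NoSuchDatabaseException` on the server can be caught as a regular `AnalysisException` over Spark Connect (an illustrative sketch, assuming an active session bound to `spark`):
    
    ```python
    from pyspark.errors import AnalysisException
    
    try:
        # Raises NoSuchDatabaseException on the server, a subclass of AnalysisException.
        spark.catalog.setCurrentDatabase("does_not_exist")
    except AnalysisException as e:
        print("Caught:", e)
    ```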
    
    ### Does this PR introduce _any_ user-facing change?
    
    No to end users.
    Yes, the exceptions thrown from Spark Connect now match the regular PySpark exceptions.
    
    ### How was this patch tested?
    
    Unit tests fixed.
    
    Closes #39947 from HyukjinKwon/SPARK-41715.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../sql/connect/service/SparkConnectService.scala  | 92 ++++++++++++----------
 python/pyspark/errors/exceptions/connect.py        | 76 ++++++++++++------
 python/pyspark/sql/catalog.py                      |  8 +-
 python/pyspark/sql/connect/client.py               | 33 +-------
 python/pyspark/sql/dataframe.py                    |  6 +-
 .../sql/tests/connect/test_connect_basic.py        |  5 +-
 python/pyspark/sql/tests/pandas/test_pandas_udf.py |  6 +-
 python/pyspark/sql/tests/test_catalog.py           | 16 ++--
 8 files changed, 122 insertions(+), 120 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 25b7009860b..05aa2428140 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.connect.service
 
 import java.util.concurrent.TimeUnit
 
+import scala.annotation.tailrec
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
 import com.google.common.base.Ticker
@@ -31,12 +33,16 @@ import io.grpc.netty.NettyServerBuilder
 import io.grpc.protobuf.StatusProto
 import io.grpc.protobuf.services.ProtoReflectionService
 import io.grpc.stub.StreamObserver
+import org.apache.commons.lang3.StringUtils
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
 
-import org.apache.spark.{SparkEnv, SparkThrowable}
+import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
+import org.apache.spark.api.python.PythonException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, Dataset, SparkSession}
+import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT
 import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
 import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExplainMode, ExtendedMode, FormattedMode, SimpleMode}
@@ -53,8 +59,24 @@ class SparkConnectService(debug: Boolean)
     extends SparkConnectServiceGrpc.SparkConnectServiceImplBase
     with Logging {
 
-  private def buildStatusFromThrowable[A <: Throwable with SparkThrowable](st: A): RPCStatus = {
-    val t = Option(st.getCause).getOrElse(st)
+  private def allClasses(cl: Class[_]): Seq[Class[_]] = {
+    val classes = ArrayBuffer.empty[Class[_]]
+    if (cl != null && !cl.equals(classOf[java.lang.Object])) {
+      classes.append(cl) // Includes itself.
+    }
+
+    @tailrec
+    def appendSuperClasses(clazz: Class[_]): Unit = {
+      if (clazz == null || clazz.equals(classOf[java.lang.Object])) return
+      classes.append(clazz.getSuperclass)
+      appendSuperClasses(clazz.getSuperclass)
+    }
+
+    appendSuperClasses(cl)
+    classes.toSeq
+  }
+
+  private def buildStatusFromThrowable(st: Throwable): RPCStatus = {
     RPCStatus
       .newBuilder()
       .setCode(RPCCode.INTERNAL_VALUE)
@@ -62,13 +84,21 @@ class SparkConnectService(debug: Boolean)
         ProtoAny.pack(
           ErrorInfo
             .newBuilder()
-            .setReason(t.getClass.getName)
+            .setReason(st.getClass.getName)
             .setDomain("org.apache.spark")
+            .putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName))))
             .build()))
-      .setMessage(t.getLocalizedMessage)
+      .setMessage(StringUtils.abbreviate(st.getMessage, 2048))
       .build()
   }
 
+  private def isPythonExecutionException(se: SparkException): Boolean = {
+    // See also pyspark.errors.exceptions.captured.convert_exception in PySpark.
+    se.getCause != null && se.getCause
+      .isInstanceOf[PythonException] && se.getCause.getStackTrace
+      .exists(_.toString.contains("org.apache.spark.sql.execution.python"))
+  }
+
   /**
   * Common exception handling function for the Analysis and Execution methods. Closes the stream
    * after the error has been sent.
@@ -83,46 +113,22 @@ class SparkConnectService(debug: Boolean)
   private def handleError[V](
       opType: String,
       observer: StreamObserver[V]): PartialFunction[Throwable, Unit] = {
-    case ae: AnalysisException =>
-      logError(s"Error during: $opType", ae)
-      val status = RPCStatus
-        .newBuilder()
-        .setCode(RPCCode.INTERNAL_VALUE)
-        .addDetails(
-          ProtoAny.pack(
-            ErrorInfo
-              .newBuilder()
-              .setReason(ae.getClass.getName)
-              .setDomain("org.apache.spark")
-              .putMetadata("message", ae.getSimpleMessage)
-              .putMetadata("plan", Option(ae.plan).flatten.map(p => s"$p").getOrElse(""))
-              .build()))
-        .setMessage(ae.getLocalizedMessage)
-        .build()
-      observer.onError(StatusProto.toStatusRuntimeException(status))
-    case st: SparkThrowable =>
-      logError(s"Error during: $opType", st)
-      val status = buildStatusFromThrowable(st)
-      observer.onError(StatusProto.toStatusRuntimeException(status))
-    case NonFatal(nf) =>
-      logError(s"Error during: $opType", nf)
-      val status = RPCStatus
-        .newBuilder()
-        .setCode(RPCCode.INTERNAL_VALUE)
-        .addDetails(
-          ProtoAny.pack(
-            ErrorInfo
-              .newBuilder()
-              .setReason(nf.getClass.getName)
-              .setDomain("org.apache.spark")
-              .build()))
-        .setMessage(nf.getLocalizedMessage)
-        .build()
-      observer.onError(StatusProto.toStatusRuntimeException(status))
+    case se: SparkException if isPythonExecutionException(se) =>
+      logError(s"Error during: $opType", se)
+      observer.onError(
+        StatusProto.toStatusRuntimeException(buildStatusFromThrowable(se.getCause)))
+
+    case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) =>
+      logError(s"Error during: $opType", e)
+      observer.onError(StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e)))
+
     case e: Throwable =>
       logError(s"Error during: $opType", e)
       observer.onError(
-        Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
+        Status.UNKNOWN
+          .withCause(e)
+          .withDescription(StringUtils.abbreviate(e.getMessage, 2048))
+          .asRuntimeException())
   }
 
   /**
diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py
index ba3bc9f7576..f5f1d42ca5d 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
+from typing import Dict, Optional, TYPE_CHECKING
 
-from typing import Dict, Optional
 
 from pyspark.errors.exceptions.base import (
     AnalysisException as BaseAnalysisException,
@@ -23,9 +24,14 @@ from pyspark.errors.exceptions.base import (
     ParseException as BaseParseException,
     PySparkException,
     PythonException as BasePythonException,
-    TempTableAlreadyExistsException as BaseTempTableAlreadyExistsException,
+    StreamingQueryException as BaseStreamingQueryException,
+    QueryExecutionException as BaseQueryExecutionException,
+    SparkUpgradeException as BaseSparkUpgradeException,
 )
 
+if TYPE_CHECKING:
+    from google.rpc.error_details_pb2 import ErrorInfo
+
 
 class SparkConnectException(PySparkException):
     """
@@ -33,6 +39,33 @@ class SparkConnectException(PySparkException):
     """
 
 
+def convert_exception(info: "ErrorInfo", message: str) -> SparkConnectException:
+    classes = []
+    if "classes" in info.metadata:
+        classes = json.loads(info.metadata["classes"])
+
+    if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
+        return ParseException(message)
+    # Order matters. ParseException inherits AnalysisException.
+    elif "org.apache.spark.sql.AnalysisException" in classes:
+        return AnalysisException(message)
+    elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes:
+        return StreamingQueryException(message)
+    elif "org.apache.spark.sql.execution.QueryExecutionException" in classes:
+        return QueryExecutionException(message)
+    elif "java.lang.IllegalArgumentException" in classes:
+        return IllegalArgumentException(message)
+    elif "org.apache.spark.SparkUpgradeException" in classes:
+        return SparkUpgradeException(message)
+    elif "org.apache.spark.api.python.PythonException" in classes:
+        return PythonException(
+            "\n  An exception was thrown from the Python worker. "
+            "Please see the stack trace below.\n%s" % message
+        )
+    else:
+        return SparkConnectGrpcException(message, reason=info.reason)
+
+
 class SparkConnectGrpcException(SparkConnectException):
     """
     Base class to handle the errors from GRPC.
@@ -61,41 +94,34 @@ class AnalysisException(SparkConnectGrpcException, BaseAnalysisException):
     Failed to analyze a SQL query plan from Spark Connect server.
     """
 
-    def __init__(
-        self,
-        message: Optional[str] = None,
-        error_class: Optional[str] = None,
-        message_parameters: Optional[Dict[str, str]] = None,
-        plan: Optional[str] = None,
-        reason: Optional[str] = None,
-    ) -> None:
-        self.message = message  # type: ignore[assignment]
-        if plan is not None:
-            self.message = f"{self.message}\nPlan: {plan}"
 
-        super().__init__(
-            message=self.message,
-            error_class=error_class,
-            message_parameters=message_parameters,
-            reason=reason,
-        )
+class ParseException(SparkConnectGrpcException, BaseParseException):
+    """
+    Failed to parse a SQL command from Spark Connect server.
+    """
 
 
-class TempTableAlreadyExistsException(AnalysisException, BaseTempTableAlreadyExistsException):
+class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException):
     """
-    Failed to create temp view from Spark Connect server since it is already exists.
+    Passed an illegal or inappropriate argument from Spark Connect server.
     """
 
 
-class ParseException(SparkConnectGrpcException, BaseParseException):
+class StreamingQueryException(SparkConnectGrpcException, BaseStreamingQueryException):
     """
-    Failed to parse a SQL command from Spark Connect server.
+    Exception that stopped a :class:`StreamingQuery` from Spark Connect server.
     """
 
 
-class IllegalArgumentException(SparkConnectGrpcException, BaseIllegalArgumentException):
+class QueryExecutionException(SparkConnectGrpcException, BaseQueryExecutionException):
     """
-    Passed an illegal or inappropriate argument from Spark Connect server.
+    Failed to execute a query from Spark Connect server.
+    """
+
+
+class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException):
+    """
+    Exception thrown because of Spark upgrade from Spark Connect
     """
 
 
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 6deee786164..a7f3e761f3f 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -387,7 +387,7 @@ class Catalog:
 
         Throw an analysis exception when the table does not exist.
 
-        >>> spark.catalog.getTable("tbl1")  # doctest: +SKIP
+        >>> spark.catalog.getTable("tbl1")
         Traceback (most recent call last):
             ...
         AnalysisException: ...
@@ -548,7 +548,7 @@ class Catalog:
 
         Throw an analysis exception when the function does not exists.
 
-        >>> spark.catalog.getFunction("my_func2")  # doctest: +SKIP
+        >>> spark.catalog.getFunction("my_func2")
         Traceback (most recent call last):
             ...
         AnalysisException: ...
@@ -867,7 +867,7 @@ class Catalog:
 
         Throw an exception if the temporary view does not exists.
 
-        >>> spark.table("my_table")  # doctest: +SKIP
+        >>> spark.table("my_table")
         Traceback (most recent call last):
             ...
         AnalysisException: ...
@@ -907,7 +907,7 @@ class Catalog:
 
         Throw an exception if the global view does not exists.
 
-        >>> spark.table("global_temp.my_table")  # doctest: +SKIP
+        >>> spark.table("global_temp.my_table")
         Traceback (most recent call last):
             ...
         AnalysisException: ...
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 903981a015b..943a7e70464 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -60,13 +60,9 @@ import pyspark.sql.connect.proto as pb2
 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
 import pyspark.sql.connect.types as types
 from pyspark.errors.exceptions.connect import (
-    AnalysisException,
-    ParseException,
-    PythonException,
+    convert_exception,
     SparkConnectException,
     SparkConnectGrpcException,
-    TempTableAlreadyExistsException,
-    IllegalArgumentException,
 )
 from pyspark.sql.connect.expressions import (
     PythonUDF,
@@ -730,32 +726,7 @@ class SparkConnectClient(object):
                 if d.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
                     info = error_details_pb2.ErrorInfo()
                     d.Unpack(info)
-                    reason = info.reason
-                    if reason == "org.apache.spark.sql.AnalysisException":
-                        raise AnalysisException(
-                            info.metadata["message"], plan=info.metadata["plan"]
-                        ) from None
-                    elif reason == "org.apache.spark.sql.catalyst.parser.ParseException":
-                        raise ParseException(info.metadata["message"]) from None
-                    elif (
-                        reason
-                        == "org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException"
-                    ):
-                        raise TempTableAlreadyExistsException(
-                            info.metadata["message"], plan=info.metadata["plan"]
-                        ) from None
-                    elif reason == "java.lang.IllegalArgumentException":
-                        message = info.metadata["message"]
-                        message = message if message != "" else status.message
-                        raise IllegalArgumentException(message) from None
-                    elif reason == "org.apache.spark.api.python.PythonException":
-                        message = info.metadata["message"]
-                        message = message if message != "" else status.message
-                        raise PythonException(message) from None
-                    else:
-                        raise SparkConnectGrpcException(
-                            status.message, reason=info.reason
-                        ) from None
+                    raise convert_exception(info, status.message) from None
 
             raise SparkConnectGrpcException(status.message) from None
         else:
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index e794bb94e75..5649d362b8b 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -356,7 +356,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
 
         Throw an exception if the table already exists.
 
-        >>> df.createTempView("people")  # doctest: +IGNORE_EXCEPTION_DETAIL, +SKIP
+        >>> df.createTempView("people")  # doctest: +IGNORE_EXCEPTION_DETAIL
         Traceback (most recent call last):
         ...
         AnalysisException: "Temporary table 'people' already exists;"
@@ -439,7 +439,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
 
         Throws an exception if the global temporary view already exists.
 
-        >>> df.createGlobalTempView("people")  # doctest: +IGNORE_EXCEPTION_DETAIL, +SKIP
+        >>> df.createGlobalTempView("people")  # doctest: +IGNORE_EXCEPTION_DETAIL
         Traceback (most recent call last):
         ...
         AnalysisException: "Temporary table 'people' already exists;"
@@ -4598,7 +4598,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         Examples
         --------
         >>> df = spark.createDataFrame([(1, 11), (1, 11), (3, 10), (4, 8), (4, 8)], ["c1", "c2"])
-        >>> df.freqItems(["c1", "c2"]).show()  # doctest: +SKIP
+        >>> df.freqItems(["c1", "c2"]).show()
         +------------+------------+
         |c1_freqItems|c2_freqItems|
         +------------+------------+
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index b8e2c7b151a..b3b241b2d4e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -53,7 +53,6 @@ from pyspark.errors.exceptions.connect import (
     AnalysisException,
     ParseException,
     SparkConnectException,
-    TempTableAlreadyExistsException,
 )
 
 if should_test_connect:
@@ -1244,7 +1243,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
 
             # Test when creating a view which is already exists but
             self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
-            with self.assertRaises(TempTableAlreadyExistsException):
+            with self.assertRaises(AnalysisException):
                 self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
 
     def test_create_session_local_temp_view(self):
@@ -1256,7 +1255,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 0)
 
             # Test when creating a view which is already exists but
-            with self.assertRaises(TempTableAlreadyExistsException):
+            with self.assertRaises(AnalysisException):
                 self.connect.sql("SELECT 1 AS X LIMIT 0").createTempView("view_local_temp")
 
     def test_to_pandas(self):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index 1b3b4555d7f..0f927113130 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -171,7 +171,7 @@ class PandasUDFTestsMixin:
         def foo(x):
             raise StopIteration()
 
-        exc_message = "Caught StopIteration thrown from user's code; failing the task"
+        exc_message = "StopIteration"
         df = self.spark.range(0, 100)
 
         # plain udf (test for SPARK-23754)
@@ -193,7 +193,7 @@ class PandasUDFTestsMixin:
         def foofoo(x, y):
             raise StopIteration()
 
-        exc_message = "Caught StopIteration thrown from user's code; failing the task"
+        exc_message = "StopIteration"
         df = self.spark.range(0, 100)
 
         # pandas grouped map
@@ -215,7 +215,7 @@ class PandasUDFTestsMixin:
         def foo(x):
             raise StopIteration()
 
-        exc_message = "Caught StopIteration thrown from user's code; failing the task"
+        exc_message = "StopIteration"
         df = self.spark.range(0, 100)
 
         # pandas grouped agg
diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py
index 4ab11c46071..10f3ec12c9c 100644
--- a/python/pyspark/sql/tests/test_catalog.py
+++ b/python/pyspark/sql/tests/test_catalog.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+from pyspark.errors import AnalysisException
 from pyspark.sql.types import StructType, StructField, IntegerType
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
@@ -28,9 +28,7 @@ class CatalogTestsMixin:
             spark.catalog.setCurrentDatabase("some_db")
             self.assertEqual(spark.catalog.currentDatabase(), "some_db")
             self.assertRaisesRegex(
-                # TODO(SPARK-41715): Should catch specific exceptions for both
-                #  Spark Connect and PySpark
-                Exception,
+                AnalysisException,
                 "does_not_exist",
                 lambda: spark.catalog.setCurrentDatabase("does_not_exist"),
             )
@@ -181,7 +179,7 @@ class CatalogTestsMixin:
                         )
                     )
                     self.assertRaisesRegex(
-                        Exception,
+                        AnalysisException,
                         "does_not_exist",
                         lambda: spark.catalog.listTables("does_not_exist"),
                     )
@@ -236,7 +234,7 @@ class CatalogTestsMixin:
                 self.assertTrue("func1" not in newFunctionsSomeDb)
                 self.assertTrue("func2" in newFunctionsSomeDb)
                 self.assertRaisesRegex(
-                    Exception,
+                    AnalysisException,
                     "does_not_exist",
                     lambda: spark.catalog.listFunctions("does_not_exist"),
                 )
@@ -333,9 +331,11 @@ class CatalogTestsMixin:
                         isBucket=False,
                     ),
                 )
-                self.assertRaisesRegex(Exception, "tab2", lambda: spark.catalog.listColumns("tab2"))
                 self.assertRaisesRegex(
-                    Exception,
+                    AnalysisException, "tab2", lambda: spark.catalog.listColumns("tab2")
+                )
+                self.assertRaisesRegex(
+                    AnalysisException,
                     "does_not_exist",
                     lambda: spark.catalog.listColumns("does_not_exist"),
                 )


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
