ueshin commented on a change in pull request #34219:
URL: https://github.com/apache/spark/pull/34219#discussion_r727615112



##########
File path: python/pyspark/sql/utils.py
##########
@@ -16,24 +16,55 @@
 #
 
 import py4j
+from py4j.java_gateway import is_instance_of
 
 from pyspark import SparkContext
 
 
 class CapturedException(Exception):
-    def __init__(self, desc, stackTrace, cause=None):
-        self.desc = desc
-        self.stackTrace = stackTrace
+    def __init__(self, desc=None, stackTrace=None, cause=None, origin=None):
+        # desc & stackTrace vs origin are mutually exclusive.
+        # cause is optional.
+        assert ((origin is not None and desc is None and stackTrace is None)
+                or (origin is None and desc is not None and stackTrace is not None))
+
+        self.desc = desc if desc is not None else origin.getMessage()
+        self.stackTrace = (
+            stackTrace if stackTrace is not None
+            else SparkContext._jvm.org.apache.spark.util.Utils.exceptionString(origin)
+        )
         self.cause = convert_exception(cause) if cause is not None else None
+        if self.cause is None and origin is not None and origin.getCause() is not None:
+            self.cause = convert_exception(origin.getCause())
+        self._origin = origin
 
     def __str__(self):
-        sql_conf = SparkContext._jvm.org.apache.spark.sql.internal.SQLConf.get()
+        assert SparkContext._jvm is not None
+
+        jvm = SparkContext._jvm
+        sql_conf = jvm.org.apache.spark.sql.internal.SQLConf.get()
         debug_enabled = sql_conf.pysparkJVMStacktraceEnabled()
         desc = self.desc
         if debug_enabled:
             desc = desc + "\n\nJVM stacktrace:\n%s" % self.stackTrace
         return str(desc)
 
+    def getErrorClass(self):
+        assert SparkContext._gateway is not None
+
+        gw = SparkContext._gateway
+        if self._origin is not None and is_instance_of(
+                gw, self._origin, "org.apache.spark.SparkThrowable"):
+            return self._origin.getErrorClass()
+
+    def getSqlState(self):
+        assert SparkContext._gateway is not None
+
+        gw = SparkContext._gateway
+        if self._origin is not None and is_instance_of(
+                gw, self._origin, "org.apache.spark.SparkThrowable"):
+            return self._origin.getSqlState()

Review comment:
       ditto.
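
       As a hedged illustration only (not the author's change): "ditto" appears to ask for the same fix suggested in the next comment, i.e. making the fall-through of `getSqlState` explicit. A minimal sketch, assuming the class and imports shown in the diff above:

       ```python
       # Sketch only: method of CapturedException, assuming the imports from the diff above.
       def getSqlState(self):
           assert SparkContext._gateway is not None

           gw = SparkContext._gateway
           if self._origin is not None and is_instance_of(
                   gw, self._origin, "org.apache.spark.SparkThrowable"):
               return self._origin.getSqlState()
           # Make the "else" path explicit rather than relying on Python's implicit None.
           return None
       ```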

##########
File path: python/pyspark/sql/utils.py
##########
@@ -16,24 +16,55 @@
 #
 
 import py4j
+from py4j.java_gateway import is_instance_of
 
 from pyspark import SparkContext
 
 
 class CapturedException(Exception):
-    def __init__(self, desc, stackTrace, cause=None):
-        self.desc = desc
-        self.stackTrace = stackTrace
+    def __init__(self, desc=None, stackTrace=None, cause=None, origin=None):
+        # desc & stackTrace vs origin are mutually exclusive.
+        # cause is optional.
+        assert ((origin is not None and desc is None and stackTrace is None)
+                or (origin is None and desc is not None and stackTrace is not None))
+
+        self.desc = desc if desc is not None else origin.getMessage()
+        self.stackTrace = (
+            stackTrace if stackTrace is not None
+            else SparkContext._jvm.org.apache.spark.util.Utils.exceptionString(origin)
+        )
         self.cause = convert_exception(cause) if cause is not None else None
+        if self.cause is None and origin is not None and origin.getCause() is not None:
+            self.cause = convert_exception(origin.getCause())
+        self._origin = origin
 
     def __str__(self):
-        sql_conf = SparkContext._jvm.org.apache.spark.sql.internal.SQLConf.get()
+        assert SparkContext._jvm is not None
+
+        jvm = SparkContext._jvm
+        sql_conf = jvm.org.apache.spark.sql.internal.SQLConf.get()
         debug_enabled = sql_conf.pysparkJVMStacktraceEnabled()
         desc = self.desc
         if debug_enabled:
             desc = desc + "\n\nJVM stacktrace:\n%s" % self.stackTrace
         return str(desc)
 
+    def getErrorClass(self):
+        assert SparkContext._gateway is not None
+
+        gw = SparkContext._gateway
+        if self._origin is not None and is_instance_of(
+                gw, self._origin, "org.apache.spark.SparkThrowable"):
+            return self._origin.getErrorClass()

Review comment:
      We should explicitly return `None` or raise an error for the `else` case?
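
      If the explicit-`None` option is chosen, it could look like the sketch below (not the author's change; the surrounding class and imports are assumed to match the diff above). Raising, e.g. a `ValueError`, would be the stricter alternative:

      ```python
      # Sketch only: method of CapturedException, assuming the imports from the diff above.
      def getErrorClass(self):
          assert SparkContext._gateway is not None

          gw = SparkContext._gateway
          if self._origin is not None and is_instance_of(
                  gw, self._origin, "org.apache.spark.SparkThrowable"):
              return self._origin.getErrorClass()
          # Explicit "else": this captured exception does not carry a SparkThrowable error class.
          return None
      ```

      Returning `None` keeps call sites simple (callers just check for `None`), while raising would surface misuse earlier; either way the behavior of the `else` case becomes explicit.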




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


