This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cce32a8876bd [SPARK-54052][PYTHON] Add a bridge object to work around 
a Py4J limitation
cce32a8876bd is described below

commit cce32a8876bd3b98b501f7be61202d35a3d17b4d
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Sun Nov 2 16:02:31 2025 +0900

    [SPARK-54052][PYTHON] Add a bridge object to work around a Py4J limitation
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add a PythonErrorUtils object to work around a Py4J 
limitation: Py4J does not support default method access.
    
    ### Why are the changes needed?
    
    To make the change easier and less error-prone.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. Virtually a refactoring change.
    
    ### How was this patch tested?
    
    A unit test was added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #52755 from HyukjinKwon/bridge-class.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../apache/spark/api/python/PythonErrorUtils.scala | 41 ++++++++++++++++++
 .../apache/spark/deploy/PythonRunnerSuite.scala    | 30 ++++++++++++-
 python/pyspark/errors/exceptions/captured.py       | 49 ++++++----------------
 3 files changed, 82 insertions(+), 38 deletions(-)

diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonErrorUtils.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonErrorUtils.scala
new file mode 100644
index 000000000000..73c2a29ea409
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonErrorUtils.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.util
+
+import org.apache.spark.{BreakingChangeInfo, QueryContext, SparkThrowable}
+
+/**
+ * Utility object that provides convenient accessors for extracting
+ * detailed information from a [[SparkThrowable]] instance.
+ *
+ * This object is primarily used in PySpark
+ * to retrieve structured error metadata because Py4J does not work
+ * with default methods.
+ */
+private[spark] object PythonErrorUtils {
+  def getCondition(e: SparkThrowable): String = e.getCondition
+  def getErrorClass(e: SparkThrowable): String = e.getCondition
+  def getSqlState(e: SparkThrowable): String = e.getSqlState
+  def isInternalError(e: SparkThrowable): Boolean = e.isInternalError
+  def getBreakingChangeInfo(e: SparkThrowable): BreakingChangeInfo = 
e.getBreakingChangeInfo
+  def getMessageParameters(e: SparkThrowable): util.Map[String, String] = 
e.getMessageParameters
+  def getDefaultMessageTemplate(e: SparkThrowable): String = 
e.getDefaultMessageTemplate
+  def getQueryContext(e: SparkThrowable): Array[QueryContext] = 
e.getQueryContext
+}
diff --git 
a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala 
b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala
index 473a2d7b2a25..2cce3d306e60 100644
--- a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.deploy
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkThrowable}
+import org.apache.spark.api.python.PythonErrorUtils
 import org.apache.spark.util.Utils
 
 class PythonRunnerSuite extends SparkFunSuite {
@@ -64,4 +65,31 @@ class PythonRunnerSuite extends SparkFunSuite {
     intercept[IllegalArgumentException] { 
PythonRunner.formatPaths("hdfs:/some.py,foo.py") }
     intercept[IllegalArgumentException] { 
PythonRunner.formatPaths("foo.py,hdfs:/some.py") }
   }
+
+  test("SPARK-54052: PythonErrorUtils should have corresponding methods in 
SparkThrowable") {
+    // Find default methods in SparkThrowable
+    val defaultMethods = classOf[SparkThrowable]
+      .getMethods
+      .filter(m => m.getDeclaringClass == classOf[SparkThrowable])
+      .map(_.getName)
+      .toSet
+
+    // Find methods defined in PythonErrorUtils object
+    val utilsMethods = PythonErrorUtils.getClass
+      .getDeclaredMethods
+      .filterNot(_.isSynthetic)
+      .map(_.getName)
+      .filterNot(_.contains("$"))
+      .toSet
+
+    // Compare
+    assert(
+      utilsMethods == defaultMethods,
+      s"""
+         |PythonErrorUtils methods and SparkThrowable default methods differ!
+         |Missing in PythonErrorUtils: 
${defaultMethods.diff(utilsMethods).mkString(", ")}
+         |Extra in PythonErrorUtils: 
${utilsMethods.diff(defaultMethods).mkString(", ")}
+         |""".stripMargin
+    )
+  }
 }
diff --git a/python/pyspark/errors/exceptions/captured.py 
b/python/pyspark/errors/exceptions/captured.py
index 56892db91f3b..0f76e3b5f6a0 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -107,7 +107,8 @@ class CapturedException(PySparkException):
         if self._origin is not None and is_instance_of(
             gw, self._origin, "org.apache.spark.SparkThrowable"
         ):
-            return self._origin.getCondition()
+            utils = SparkContext._jvm.PythonErrorUtils  # type: 
ignore[union-attr]
+            return utils.getCondition(self._origin)
         else:
             return None
 
@@ -118,7 +119,6 @@ class CapturedException(PySparkException):
     def getMessageParameters(self) -> Optional[Dict[str, str]]:
         from pyspark import SparkContext
         from py4j.java_gateway import is_instance_of
-        from py4j.protocol import Py4JError
 
         assert SparkContext._gateway is not None
 
@@ -126,38 +126,28 @@ class CapturedException(PySparkException):
         if self._origin is not None and is_instance_of(
             gw, self._origin, "org.apache.spark.SparkThrowable"
         ):
-            try:
-                return dict(self._origin.getMessageParameters())
-            except Py4JError as e:
-                if "py4j.Py4JException" in str(e) and "Method 
getMessageParameters" in str(e):
-                    return None
-                raise e
+            utils = SparkContext._jvm.PythonErrorUtils  # type: 
ignore[union-attr]
+            return dict(utils.getMessageParameters(self._origin))
         else:
             return None
 
     def getSqlState(self) -> Optional[str]:
         from pyspark import SparkContext
         from py4j.java_gateway import is_instance_of
-        from py4j.protocol import Py4JError
 
         assert SparkContext._gateway is not None
         gw = SparkContext._gateway
         if self._origin is not None and is_instance_of(
             gw, self._origin, "org.apache.spark.SparkThrowable"
         ):
-            try:
-                return self._origin.getSqlState()
-            except Py4JError as e:
-                if "py4j.Py4JException" in str(e) and "Method getSqlState" in 
str(e):
-                    return None
-                raise e
+            utils = SparkContext._jvm.PythonErrorUtils  # type: 
ignore[union-attr]
+            return utils.getSqlState(self._origin)
         else:
             return None
 
     def getMessage(self) -> str:
         from pyspark import SparkContext
         from py4j.java_gateway import is_instance_of
-        from py4j.protocol import Py4JError
 
         assert SparkContext._gateway is not None
         gw = SparkContext._gateway
@@ -165,21 +155,12 @@ class CapturedException(PySparkException):
         if self._origin is not None and is_instance_of(
             gw, self._origin, "org.apache.spark.SparkThrowable"
         ):
-            try:
-                error_class = self._origin.getCondition()
-            except Py4JError as e:
-                if "py4j.Py4JException" in str(e) and "Method getCondition" in 
str(e):
-                    return ""
-                raise e
-            try:
-                message_parameters = self._origin.getMessageParameters()
-            except Py4JError as e:
-                if "py4j.Py4JException" in str(e) and "Method 
getMessageParameters" in str(e):
-                    return ""
-                raise e
+            utils = SparkContext._jvm.PythonErrorUtils  # type: 
ignore[union-attr]
+            errorClass = utils.getCondition(self._origin)
+            messageParameters = utils.getMessageParameters(self._origin)
 
             error_message = getattr(gw.jvm, 
"org.apache.spark.SparkThrowableHelper").getMessage(
-                error_class, message_parameters
+                errorClass, messageParameters
             )
 
             return error_message
@@ -189,7 +170,6 @@ class CapturedException(PySparkException):
     def getQueryContext(self) -> List[BaseQueryContext]:
         from pyspark import SparkContext
         from py4j.java_gateway import is_instance_of
-        from py4j.protocol import Py4JError
 
         assert SparkContext._gateway is not None
 
@@ -198,13 +178,8 @@ class CapturedException(PySparkException):
             gw, self._origin, "org.apache.spark.SparkThrowable"
         ):
             contexts: List[BaseQueryContext] = []
-            try:
-                context = self._origin.getQueryContext()
-            except Py4JError as e:
-                if "py4j.Py4JException" in str(e) and "Method getQueryContext" 
in str(e):
-                    return []
-                raise e
-            for q in context:
+            utils = SparkContext._jvm.PythonErrorUtils  # type: 
ignore[union-attr]
+            for q in utils.getQueryContext(self._origin):
                 if q.contextType().toString() == "SQL":
                     contexts.append(SQLQueryContext(q))
                 else:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to