This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cce32a8876bd [SPARK-54052][PYTHON] Add a bridge object to work around a Py4J limitation
cce32a8876bd is described below
commit cce32a8876bd3b98b501f7be61202d35a3d17b4d
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Sun Nov 2 16:02:31 2025 +0900
[SPARK-54052][PYTHON] Add a bridge object to work around a Py4J limitation
### What changes were proposed in this pull request?
This PR proposes to add a PythonErrorUtils object to work around a Py4J
limitation: Py4J does not support calling Java interface default methods.
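For illustration (not part of the patch), a minimal sketch of the call path
after this change, assuming a local PySpark session and a query that fails
analysis; the condition and SQLSTATE values in the comments are examples,
not guaranteed outputs:

```python
from pyspark.sql import SparkSession
from pyspark.errors.exceptions.captured import CapturedException

spark = SparkSession.builder.getOrCreate()

try:
    # Any query raising a JVM-side SparkThrowable works here.
    spark.sql("SELECT * FROM nonexistent_table")
except CapturedException as e:
    # These accessors now reach the JVM through the PythonErrorUtils bridge
    # instead of invoking SparkThrowable default methods directly over Py4J.
    print(e.getCondition())          # e.g. "TABLE_OR_VIEW_NOT_FOUND"
    print(e.getSqlState())           # e.g. "42P01"
    print(e.getMessageParameters())  # structured error parameters as a dict
```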
### Why are the changes needed?
To make the change simpler and less error-prone.
### Does this PR introduce _any_ user-facing change?
No. This is virtually a refactoring change.
### How was this patch tested?
A unit test was added.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52755 from HyukjinKwon/bridge-class.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../apache/spark/api/python/PythonErrorUtils.scala | 41 ++++++++++++++++++
.../apache/spark/deploy/PythonRunnerSuite.scala | 30 ++++++++++++-
python/pyspark/errors/exceptions/captured.py | 49 ++++++----------------
3 files changed, 82 insertions(+), 38 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonErrorUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonErrorUtils.scala
new file mode 100644
index 000000000000..73c2a29ea409
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonErrorUtils.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.util
+
+import org.apache.spark.{BreakingChangeInfo, QueryContext, SparkThrowable}
+
+/**
+ * Utility object that provides convenient accessors for extracting
+ * detailed information from a [[SparkThrowable]] instance.
+ *
+ * This object is primarily used in PySpark
+ * to retrieve structured error metadata because Py4J does not work
+ * with default methods.
+ */
+private[spark] object PythonErrorUtils {
+ def getCondition(e: SparkThrowable): String = e.getCondition
+ def getErrorClass(e: SparkThrowable): String = e.getCondition
+ def getSqlState(e: SparkThrowable): String = e.getSqlState
+ def isInternalError(e: SparkThrowable): Boolean = e.isInternalError
+ def getBreakingChangeInfo(e: SparkThrowable): BreakingChangeInfo = e.getBreakingChangeInfo
+ def getMessageParameters(e: SparkThrowable): util.Map[String, String] = e.getMessageParameters
+ def getDefaultMessageTemplate(e: SparkThrowable): String = e.getDefaultMessageTemplate
+ def getQueryContext(e: SparkThrowable): Array[QueryContext] = e.getQueryContext
+}
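(For context, a minimal sketch of how PySpark reaches this object, assuming
an active SparkContext whose Py4J gateway imports
org.apache.spark.api.python.* so the unqualified PythonErrorUtils reference
used in captured.py below resolves; `java_throwable` is a hypothetical
JVM-side SparkThrowable reference:)

```python
from pyspark import SparkContext

sc = SparkContext.getOrCreate()
jvm = sc._jvm  # Py4J JVMView; a private attribute, used here for illustration

# The Scala object's methods are exposed to Py4J as ordinary static-style
# calls, so each accessor takes the throwable explicitly instead of relying
# on Py4J to dispatch an interface default method.
utils = jvm.PythonErrorUtils
# cond = utils.getCondition(java_throwable)            # hypothetical reference
# params = dict(utils.getMessageParameters(java_throwable))
```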
diff --git a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala
index 473a2d7b2a25..2cce3d306e60 100644
--- a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala
@@ -17,7 +17,8 @@
package org.apache.spark.deploy
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkThrowable}
+import org.apache.spark.api.python.PythonErrorUtils
import org.apache.spark.util.Utils
class PythonRunnerSuite extends SparkFunSuite {
@@ -64,4 +65,31 @@ class PythonRunnerSuite extends SparkFunSuite {
intercept[IllegalArgumentException] { PythonRunner.formatPaths("hdfs:/some.py,foo.py") }
intercept[IllegalArgumentException] { PythonRunner.formatPaths("foo.py,hdfs:/some.py") }
}
+
+ test("SPARK-54052: PythonErrorUtils should have corresponding methods in
SparkThrowable") {
+ // Find default methods in SparkThrowable
+ val defaultMethods = classOf[SparkThrowable]
+ .getMethods
+ .filter(m => m.getDeclaringClass == classOf[SparkThrowable])
+ .map(_.getName)
+ .toSet
+
+ // Find methods defined in PythonErrorUtils object
+ val utilsMethods = PythonErrorUtils.getClass
+ .getDeclaredMethods
+ .filterNot(_.isSynthetic)
+ .map(_.getName)
+ .filterNot(_.contains("$"))
+ .toSet
+
+ // Compare
+ assert(
+ utilsMethods == defaultMethods,
+ s"""
+ |PythonErrorUtils methods and SparkThrowable default methods differ!
+ |Missing in PythonErrorUtils: ${defaultMethods.diff(utilsMethods).mkString(", ")}
+ |Extra in PythonErrorUtils: ${utilsMethods.diff(defaultMethods).mkString(", ")}
+ |""".stripMargin
+ )
+ }
}
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index 56892db91f3b..0f76e3b5f6a0 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -107,7 +107,8 @@ class CapturedException(PySparkException):
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
- return self._origin.getCondition()
+ utils = SparkContext._jvm.PythonErrorUtils  # type: ignore[union-attr]
+ return utils.getCondition(self._origin)
else:
return None
@@ -118,7 +119,6 @@ class CapturedException(PySparkException):
def getMessageParameters(self) -> Optional[Dict[str, str]]:
from pyspark import SparkContext
from py4j.java_gateway import is_instance_of
- from py4j.protocol import Py4JError
assert SparkContext._gateway is not None
@@ -126,38 +126,28 @@ class CapturedException(PySparkException):
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
- try:
- return dict(self._origin.getMessageParameters())
- except Py4JError as e:
- if "py4j.Py4JException" in str(e) and "Method
getMessageParameters" in str(e):
- return None
- raise e
+ utils = SparkContext._jvm.PythonErrorUtils  # type: ignore[union-attr]
+ return dict(utils.getMessageParameters(self._origin))
else:
return None
def getSqlState(self) -> Optional[str]:
from pyspark import SparkContext
from py4j.java_gateway import is_instance_of
- from py4j.protocol import Py4JError
assert SparkContext._gateway is not None
gw = SparkContext._gateway
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
- try:
- return self._origin.getSqlState()
- except Py4JError as e:
- if "py4j.Py4JException" in str(e) and "Method getSqlState" in
str(e):
- return None
- raise e
+ utils = SparkContext._jvm.PythonErrorUtils  # type: ignore[union-attr]
+ return utils.getSqlState(self._origin)
else:
return None
def getMessage(self) -> str:
from pyspark import SparkContext
from py4j.java_gateway import is_instance_of
- from py4j.protocol import Py4JError
assert SparkContext._gateway is not None
gw = SparkContext._gateway
@@ -165,21 +155,12 @@ class CapturedException(PySparkException):
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
- try:
- error_class = self._origin.getCondition()
- except Py4JError as e:
- if "py4j.Py4JException" in str(e) and "Method getCondition" in
str(e):
- return ""
- raise e
- try:
- message_parameters = self._origin.getMessageParameters()
- except Py4JError as e:
- if "py4j.Py4JException" in str(e) and "Method
getMessageParameters" in str(e):
- return ""
- raise e
+ utils = SparkContext._jvm.PythonErrorUtils  # type: ignore[union-attr]
+ errorClass = utils.getCondition(self._origin)
+ messageParameters = utils.getMessageParameters(self._origin)
error_message = getattr(gw.jvm, "org.apache.spark.SparkThrowableHelper").getMessage(
- error_class, message_parameters
+ errorClass, messageParameters
)
return error_message
@@ -189,7 +170,6 @@ class CapturedException(PySparkException):
def getQueryContext(self) -> List[BaseQueryContext]:
from pyspark import SparkContext
from py4j.java_gateway import is_instance_of
- from py4j.protocol import Py4JError
assert SparkContext._gateway is not None
@@ -198,13 +178,8 @@ class CapturedException(PySparkException):
gw, self._origin, "org.apache.spark.SparkThrowable"
):
contexts: List[BaseQueryContext] = []
- try:
- context = self._origin.getQueryContext()
- except Py4JError as e:
- if "py4j.Py4JException" in str(e) and "Method getQueryContext"
in str(e):
- return []
- raise e
- for q in context:
+ utils = SparkContext._jvm.PythonErrorUtils  # type: ignore[union-attr]
+ for q in utils.getQueryContext(self._origin):
if q.contextType().toString() == "SQL":
contexts.append(SQLQueryContext(q))
else:
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]