This is an automated email from the ASF dual-hosted git repository.
allisonwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d259132156e2 [SPARK-50858][PYTHON] Add configuration to hide Python
UDF stack trace
d259132156e2 is described below
commit d259132156e2e40c89fdc1d12911e12fed273c3e
Author: Haoyu Weng <[email protected]>
AuthorDate: Mon Jan 27 16:37:23 2025 -0800
[SPARK-50858][PYTHON] Add configuration to hide Python UDF stack trace
### What changes were proposed in this pull request?
This PR adds a new configuration
`spark.sql.execution.pyspark.udf.hideTraceback.enabled`. If set, when handling
an exception from a Python UDF, only the exception class and message are
included. The configuration is turned off by default.
This PR also adds a new optional parameter `hide_traceback` to
`handle_worker_exception` to override the configuration.
Suggested review order:
1. `python/pyspark/util.py`: logic changes
2. `python/pyspark/tests/test_util.py`: unit tests
3. other files: adding new configuration
### Why are the changes needed?
This allows library-provided UDFs to show only the relevant message without
an unnecessary stack trace.
### Does this PR introduce _any_ user-facing change?
If the configuration is turned off, there is no user-facing change.
Otherwise, the stack trace is not included in the error message when
handling an exception from a Python UDF.
<details>
<summary>Example that illustrates the difference</summary>
```py
from pyspark.errors.exceptions.base import PySparkRuntimeError
from pyspark.sql.types import IntegerType, StructField, StructType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
from pyspark.sql.functions import udtf
@udtf()
class PythonUDTF:
    @staticmethod
    def analyze(x: AnalyzeArgument) -> AnalyzeResult:
        raise PySparkRuntimeError("[XXX] My PySpark runtime error.")
    def eval(self, x: int):
        yield (x,)
spark.udtf.register("my_udtf", PythonUDTF)
spark.sql("select * from my_udtf(1)").show()
```
With the configuration turned off, the last line gives:
```
...
pyspark.errors.exceptions.captured.AnalysisException:
[TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON] Failed to analyze the
Python user defined table function: Traceback (most recent call last):
File "<stdin>", line 7, in analyze
pyspark.errors.exceptions.base.PySparkRuntimeError: [XXX] My PySpark
runtime error. SQLSTATE: 38000; line 1 pos 14
```
With the configuration turned on, the last line gives:
```
...
pyspark.errors.exceptions.captured.AnalysisException:
[TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON] Failed to analyze the
Python user defined table function:
pyspark.errors.exceptions.base.PySparkRuntimeError: [XXX] My PySpark runtime
error. SQLSTATE: 38000; line 1 pos 14
```
</details>
### How was this patch tested?
Added unit tests in `python/pyspark/tests/test_util.py`, covering the
`SPARK_HIDE_TRACEBACK` environment variable turned on and off, as well as an
explicit `hide_traceback` argument overriding the environment variable.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49535 from wengh/spark-50858-hide-udf-stack-trace.
Authored-by: Haoyu Weng <[email protected]>
Signed-off-by: Allison Wang <[email protected]>
---
.../org/apache/spark/api/python/PythonRunner.scala | 4 +++
python/pyspark/tests/test_util.py | 41 ++++++++++++++++++++++
python/pyspark/util.py | 30 ++++++++++++----
.../org/apache/spark/sql/internal/SQLConf.scala | 11 ++++++
.../ApplyInPandasWithStatePythonRunner.scala | 1 +
.../sql/execution/python/ArrowPythonRunner.scala | 1 +
.../execution/python/ArrowPythonUDTFRunner.scala | 1 +
.../python/CoGroupedArrowPythonRunner.scala | 1 +
.../sql/execution/python/PythonForeachWriter.scala | 1 +
.../sql/execution/python/PythonPlannerRunner.scala | 4 +++
.../sql/execution/python/PythonUDFRunner.scala | 1 +
11 files changed, 90 insertions(+), 6 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 28950e5b41d4..64e78dbccb2f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -122,6 +122,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
protected val faultHandlerEnabled: Boolean =
conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
+ protected val hideTraceback: Boolean = false
protected val simplifiedTraceback: Boolean = false
// All the Python functions should have the same exec, version and envvars.
@@ -199,6 +200,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
if (reuseWorker) {
envVars.put("SPARK_REUSE_WORKER", "1")
}
+ if (hideTraceback) {
+ envVars.put("SPARK_HIDE_TRACEBACK", "1")
+ }
if (simplifiedTraceback) {
envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
}
diff --git a/python/pyspark/tests/test_util.py
b/python/pyspark/tests/test_util.py
index ad0b106d229a..e1079ca7b4e8 100644
--- a/python/pyspark/tests/test_util.py
+++ b/python/pyspark/tests/test_util.py
@@ -16,6 +16,7 @@
#
import os
import unittest
+from unittest.mock import patch
from py4j.protocol import Py4JJavaError
@@ -125,6 +126,46 @@ class UtilTests(PySparkTestCase):
_parse_memory("2gs")
+class HandleWorkerExceptionTests(unittest.TestCase):
+ exception_bytes = b"ValueError: test_message"
+ traceback_bytes = b"Traceback (most recent call last):"
+
+ def run_handle_worker_exception(self, hide_traceback=None):
+ import io
+ from pyspark.util import handle_worker_exception
+
+ try:
+ raise ValueError("test_message")
+ except Exception as e:
+ with io.BytesIO() as stream:
+ handle_worker_exception(e, stream, hide_traceback)
+ return stream.getvalue()
+
+ @patch.dict(os.environ, {"SPARK_SIMPLIFIED_TRACEBACK": "",
"SPARK_HIDE_TRACEBACK": ""})
+ def test_env_full(self):
+ result = self.run_handle_worker_exception()
+ self.assertIn(self.exception_bytes, result)
+ self.assertIn(self.traceback_bytes, result)
+
+ @patch.dict(os.environ, {"SPARK_HIDE_TRACEBACK": "1"})
+ def test_env_hide_traceback(self):
+ result = self.run_handle_worker_exception()
+ self.assertIn(self.exception_bytes, result)
+ self.assertNotIn(self.traceback_bytes, result)
+
+ @patch.dict(os.environ, {"SPARK_HIDE_TRACEBACK": "1"})
+ def test_full(self):
+ # Should ignore the environment variable because hide_traceback is
explicitly set.
+ result = self.run_handle_worker_exception(False)
+ self.assertIn(self.exception_bytes, result)
+ self.assertIn(self.traceback_bytes, result)
+
+ def test_hide_traceback(self):
+ result = self.run_handle_worker_exception(True)
+ self.assertIn(self.exception_bytes, result)
+ self.assertNotIn(self.traceback_bytes, result)
+
+
if __name__ == "__main__":
from pyspark.tests.test_util import * # noqa: F401
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 3e9a68ccfe2e..f51706858182 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -462,22 +462,40 @@ def inheritable_thread_target(f: Optional[Union[Callable,
"SparkSession"]] = Non
return f # type: ignore[return-value]
-def handle_worker_exception(e: BaseException, outfile: IO) -> None:
+def handle_worker_exception(
+ e: BaseException, outfile: IO, hide_traceback: Optional[bool] = None
+) -> None:
"""
Handles exception for Python worker which writes
SpecialLengths.PYTHON_EXCEPTION_THROWN (-2)
and exception traceback info to outfile. JVM could then read from the
outfile and perform
exception handling there.
+
+ Parameters
+ ----------
+ e : BaseException
+ Exception handled
+ outfile : IO
+ IO object to write the exception info
+ hide_traceback : bool, optional
+ Whether to hide the traceback in the output.
+ By default, hides the traceback if environment variable
SPARK_HIDE_TRACEBACK is set.
"""
- try:
- exc_info = None
+
+ if hide_traceback is None:
+ hide_traceback = bool(os.environ.get("SPARK_HIDE_TRACEBACK", False))
+
+ def format_exception() -> str:
+ if hide_traceback:
+ return "".join(traceback.format_exception_only(type(e), e))
if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
tb = try_simplify_traceback(sys.exc_info()[-1]) # type:
ignore[arg-type]
if tb is not None:
e.__cause__ = None
- exc_info = "".join(traceback.format_exception(type(e), e, tb))
- if exc_info is None:
- exc_info = traceback.format_exc()
+ return "".join(traceback.format_exception(type(e), e, tb))
+ return traceback.format_exc()
+ try:
+ exc_info = format_exception()
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(exc_info.encode("utf-8"), outfile)
except IOError:
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4907e7ee6276..7b560002edeb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3475,6 +3475,15 @@ object SQLConf {
.checkValues(Set("legacy", "row", "dict"))
.createWithDefaultString("legacy")
+ val PYSPARK_HIDE_TRACEBACK =
+ buildConf("spark.sql.execution.pyspark.udf.hideTraceback.enabled")
+ .doc(
+ "When true, only show the message of the exception from Python UDFs, "
+
+ "hiding the stack trace. If this is enabled, simplifiedTraceback has
no effect.")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
val PYSPARK_SIMPLIFIED_TRACEBACK =
buildConf("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled")
.doc(
@@ -6286,6 +6295,8 @@ class SQLConf extends Serializable with Logging with
SqlApiConf {
def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE)
+ def pysparkHideTraceback: Boolean = getConf(PYSPARK_HIDE_TRACEBACK)
+
def pysparkSimplifiedTraceback: Boolean =
getConf(PYSPARK_SIMPLIFIED_TRACEBACK)
def pandasGroupedMapAssignColumnsByName: Boolean =
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index d704638b85e8..f598430df0ee 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -84,6 +84,7 @@ class ApplyInPandasWithStatePythonRunner(
override protected lazy val timeZoneId: String = _timeZoneId
override val errorOnDuplicatedFieldNames: Boolean = true
+ override val hideTraceback: Boolean = sqlConf.pysparkHideTraceback
override val simplifiedTraceback: Boolean =
sqlConf.pysparkSimplifiedTraceback
override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 579b49604685..1bddd81fbfe2 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -50,6 +50,7 @@ abstract class BaseArrowPythonRunner(
override val errorOnDuplicatedFieldNames: Boolean = true
+ override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean =
SQLConf.get.pysparkSimplifiedTraceback
// Use lazy val to initialize the fields before these are accessed in
[[PythonArrowInput]]'s
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 99a9e706c662..f42c4b6106cb 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -59,6 +59,7 @@ class ArrowPythonUDTFRunner(
override val errorOnDuplicatedFieldNames: Boolean = true
+ override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean =
SQLConf.get.pysparkSimplifiedTraceback
override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index c5e86d010938..59e8970b9c9b 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -60,6 +60,7 @@ class CoGroupedArrowPythonRunner(
override val faultHandlerEnabled: Boolean =
SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
+ override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean =
SQLConf.get.pysparkSimplifiedTraceback
protected def newWriter(
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
index ed7ff6a75348..4655f96425fd 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
@@ -100,6 +100,7 @@ class PythonForeachWriter(func: PythonFunction, schema:
StructType)
override val faultHandlerEnabled: Boolean =
SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
+ override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean =
SQLConf.get.pysparkSimplifiedTraceback
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
index 8cc2e1de7a4c..1974c393c472 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
@@ -52,6 +52,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
val localdir = env.blockManager.diskBlockManager.localDirs.map(f =>
f.getPath()).mkString(",")
val faultHandlerEnabled: Boolean =
SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
+ val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory
@@ -68,6 +69,9 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
if (reuseWorker) {
envVars.put("SPARK_REUSE_WORKER", "1")
}
+ if (hideTraceback) {
+ envVars.put("SPARK_HIDE_TRACEBACK", "1")
+ }
if (simplifiedTraceback) {
envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 167e1fd8b0f0..a322dfa10df5 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -42,6 +42,7 @@ abstract class BasePythonUDFRunner(
SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
funcs.head._1.funcs.head.pythonExec)
+ override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean =
SQLConf.get.pysparkSimplifiedTraceback
override val faultHandlerEnabled: Boolean =
SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]