This is an automated email from the ASF dual-hosted git repository.

allisonwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d259132156e2 [SPARK-50858][PYTHON] Add configuration to hide Python UDF stack trace
d259132156e2 is described below

commit d259132156e2e40c89fdc1d12911e12fed273c3e
Author: Haoyu Weng <[email protected]>
AuthorDate: Mon Jan 27 16:37:23 2025 -0800

    [SPARK-50858][PYTHON] Add configuration to hide Python UDF stack trace
    
    ### What changes were proposed in this pull request?
    
    This PR adds a new configuration `spark.sql.execution.pyspark.udf.hideTraceback.enabled`.
    When enabled, only the exception class and message are included when handling an
    exception from a Python UDF, omitting the stack trace. The configuration is turned
    off by default.
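
    For illustration, a minimal sketch of enabling the new behavior at runtime,
    assuming an active `SparkSession` bound to `spark`:

    ```py
    # Hide Python UDF stack traces: errors keep only the exception class and message.
    spark.conf.set("spark.sql.execution.pyspark.udf.hideTraceback.enabled", "true")
    ```
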
    This PR also adds a new optional parameter `hide_traceback` to
    `handle_worker_exception` that overrides the configuration.
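
    A sketch of the per-call override, mirroring the new unit tests (the sample
    exception and buffer are illustrative):

    ```py
    import io

    from pyspark.util import handle_worker_exception

    try:
        raise ValueError("test_message")
    except ValueError as e:
        with io.BytesIO() as stream:
            # The explicit argument takes precedence over SPARK_HIDE_TRACEBACK.
            handle_worker_exception(e, stream, True)
            # Output contains "ValueError: test_message" but no traceback lines.
            print(stream.getvalue())
    ```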
    
    Suggested review order:
    1. `python/pyspark/util.py`: logic changes
    2. `python/pyspark/tests/test_util.py`: unit tests
    3. other files: adding new configuration
    
    ### Why are the changes needed?
    
    This allows library-provided UDFs to surface only the relevant error message,
    without an unnecessary stack trace.
    
    ### Does this PR introduce _any_ user-facing change?
    
    If the configuration is turned off (the default), there is no user-facing change.
    Otherwise, the stack trace is no longer included in the error message when an
    exception from a Python UDF is handled.
    
    <details>
    <summary>Example that illustrates the difference</summary>
    
    ```py
    from pyspark.errors.exceptions.base import PySparkRuntimeError
    from pyspark.sql.types import IntegerType, StructField, StructType
    from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
    from pyspark.sql.functions import udtf
    
    @udtf()
    class PythonUDTF:
        @staticmethod
        def analyze(x: AnalyzeArgument) -> AnalyzeResult:
            raise PySparkRuntimeError("[XXX] My PySpark runtime error.")
    
        def eval(self, x: int):
            yield (x,)
    
    spark.udtf.register("my_udtf", PythonUDTF)
    spark.sql("select * from my_udtf(1)").show()
    ```
    
    With the configuration turned off, the last line gives:
    ```
    ...
    pyspark.errors.exceptions.captured.AnalysisException: [TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON] Failed to analyze the Python user defined table function: Traceback (most recent call last):
      File "<stdin>", line 7, in analyze
    pyspark.errors.exceptions.base.PySparkRuntimeError: [XXX] My PySpark runtime error. SQLSTATE: 38000; line 1 pos 14
    ```
    
    With the configuration turned on, the last line gives:
    ```
    ...
    pyspark.errors.exceptions.captured.AnalysisException: [TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON] Failed to analyze the Python user defined table function: pyspark.errors.exceptions.base.PySparkRuntimeError: [XXX] My PySpark runtime error. SQLSTATE: 38000; line 1 pos 14
    ```
    
    </details>
    
    ### How was this patch tested?
    
    Added unit tests in `python/pyspark/tests/test_util.py`, covering the traceback
    handling with the configuration turned on and off, both via the
    `SPARK_HIDE_TRACEBACK` environment variable and via the explicit `hide_traceback`
    parameter.
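
    The new tests can also be run directly with the standard unittest runner, e.g.
    `python -m unittest pyspark.tests.test_util.HandleWorkerExceptionTests`
    (assuming a development checkout with `python/` on `PYTHONPATH`).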
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #49535 from wengh/spark-50858-hide-udf-stack-trace.
    
    Authored-by: Haoyu Weng <[email protected]>
    Signed-off-by: Allison Wang <[email protected]>
---
 .../org/apache/spark/api/python/PythonRunner.scala |  4 +++
 python/pyspark/tests/test_util.py                  | 41 ++++++++++++++++++++++
 python/pyspark/util.py                             | 30 ++++++++++++----
 .../org/apache/spark/sql/internal/SQLConf.scala    | 11 ++++++
 .../ApplyInPandasWithStatePythonRunner.scala       |  1 +
 .../sql/execution/python/ArrowPythonRunner.scala   |  1 +
 .../execution/python/ArrowPythonUDTFRunner.scala   |  1 +
 .../python/CoGroupedArrowPythonRunner.scala        |  1 +
 .../sql/execution/python/PythonForeachWriter.scala |  1 +
 .../sql/execution/python/PythonPlannerRunner.scala |  4 +++
 .../sql/execution/python/PythonUDFRunner.scala     |  1 +
 11 files changed, 90 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 28950e5b41d4..64e78dbccb2f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -122,6 +122,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
   private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
   protected val faultHandlerEnabled: Boolean = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
+  protected val hideTraceback: Boolean = false
   protected val simplifiedTraceback: Boolean = false
 
   // All the Python functions should have the same exec, version and envvars.
@@ -199,6 +200,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     if (reuseWorker) {
       envVars.put("SPARK_REUSE_WORKER", "1")
     }
+    if (hideTraceback) {
+      envVars.put("SPARK_HIDE_TRACEBACK", "1")
+    }
     if (simplifiedTraceback) {
       envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
     }
diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py
index ad0b106d229a..e1079ca7b4e8 100644
--- a/python/pyspark/tests/test_util.py
+++ b/python/pyspark/tests/test_util.py
@@ -16,6 +16,7 @@
 #
 import os
 import unittest
+from unittest.mock import patch
 
 from py4j.protocol import Py4JJavaError
 
@@ -125,6 +126,46 @@ class UtilTests(PySparkTestCase):
             _parse_memory("2gs")
 
 
+class HandleWorkerExceptionTests(unittest.TestCase):
+    exception_bytes = b"ValueError: test_message"
+    traceback_bytes = b"Traceback (most recent call last):"
+
+    def run_handle_worker_exception(self, hide_traceback=None):
+        import io
+        from pyspark.util import handle_worker_exception
+
+        try:
+            raise ValueError("test_message")
+        except Exception as e:
+            with io.BytesIO() as stream:
+                handle_worker_exception(e, stream, hide_traceback)
+                return stream.getvalue()
+
+    @patch.dict(os.environ, {"SPARK_SIMPLIFIED_TRACEBACK": "", "SPARK_HIDE_TRACEBACK": ""})
+    def test_env_full(self):
+        result = self.run_handle_worker_exception()
+        self.assertIn(self.exception_bytes, result)
+        self.assertIn(self.traceback_bytes, result)
+
+    @patch.dict(os.environ, {"SPARK_HIDE_TRACEBACK": "1"})
+    def test_env_hide_traceback(self):
+        result = self.run_handle_worker_exception()
+        self.assertIn(self.exception_bytes, result)
+        self.assertNotIn(self.traceback_bytes, result)
+
+    @patch.dict(os.environ, {"SPARK_HIDE_TRACEBACK": "1"})
+    def test_full(self):
+        # Should ignore the environment variable because hide_traceback is explicitly set.
+        result = self.run_handle_worker_exception(False)
+        self.assertIn(self.exception_bytes, result)
+        self.assertIn(self.traceback_bytes, result)
+
+    def test_hide_traceback(self):
+        result = self.run_handle_worker_exception(True)
+        self.assertIn(self.exception_bytes, result)
+        self.assertNotIn(self.traceback_bytes, result)
+
+
 if __name__ == "__main__":
     from pyspark.tests.test_util import *  # noqa: F401
 
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 3e9a68ccfe2e..f51706858182 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -462,22 +462,40 @@ def inheritable_thread_target(f: Optional[Union[Callable, "SparkSession"]] = Non
         return f  # type: ignore[return-value]
 
 
-def handle_worker_exception(e: BaseException, outfile: IO) -> None:
+def handle_worker_exception(
+    e: BaseException, outfile: IO, hide_traceback: Optional[bool] = None
+) -> None:
     """
     Handles exception for Python worker which writes SpecialLengths.PYTHON_EXCEPTION_THROWN (-2)
     and exception traceback info to outfile. JVM could then read from the outfile and perform
     exception handling there.
+
+    Parameters
+    ----------
+    e : BaseException
+        Exception handled
+    outfile : IO
+        IO object to write the exception info
+    hide_traceback : bool, optional
+        Whether to hide the traceback in the output.
+        By default, hides the traceback if environment variable SPARK_HIDE_TRACEBACK is set.
     """
-    try:
-        exc_info = None
+
+    if hide_traceback is None:
+        hide_traceback = bool(os.environ.get("SPARK_HIDE_TRACEBACK", False))
+
+    def format_exception() -> str:
+        if hide_traceback:
+            return "".join(traceback.format_exception_only(type(e), e))
         if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
             tb = try_simplify_traceback(sys.exc_info()[-1])  # type: ignore[arg-type]
             if tb is not None:
                 e.__cause__ = None
-                exc_info = "".join(traceback.format_exception(type(e), e, tb))
-        if exc_info is None:
-            exc_info = traceback.format_exc()
+                return "".join(traceback.format_exception(type(e), e, tb))
+        return traceback.format_exc()
 
+    try:
+        exc_info = format_exception()
         write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
         write_with_length(exc_info.encode("utf-8"), outfile)
     except IOError:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4907e7ee6276..7b560002edeb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3475,6 +3475,15 @@ object SQLConf {
       .checkValues(Set("legacy", "row", "dict"))
       .createWithDefaultString("legacy")
 
+  val PYSPARK_HIDE_TRACEBACK =
+    buildConf("spark.sql.execution.pyspark.udf.hideTraceback.enabled")
+      .doc(
+        "When true, only show the message of the exception from Python UDFs, " 
+
+        "hiding the stack trace. If this is enabled, simplifiedTraceback has 
no effect.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val PYSPARK_SIMPLIFIED_TRACEBACK =
     buildConf("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled")
       .doc(
@@ -6286,6 +6295,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE)
 
+  def pysparkHideTraceback: Boolean = getConf(PYSPARK_HIDE_TRACEBACK)
+
   def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)
 
   def pandasGroupedMapAssignColumnsByName: Boolean =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index d704638b85e8..f598430df0ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -84,6 +84,7 @@ class ApplyInPandasWithStatePythonRunner(
   override protected lazy val timeZoneId: String = _timeZoneId
   override val errorOnDuplicatedFieldNames: Boolean = true
 
+  override val hideTraceback: Boolean = sqlConf.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback
 
   override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 579b49604685..1bddd81fbfe2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -50,6 +50,7 @@ abstract class BaseArrowPythonRunner(
 
   override val errorOnDuplicatedFieldNames: Boolean = true
 
+  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
   // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 99a9e706c662..f42c4b6106cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -59,6 +59,7 @@ class ArrowPythonUDTFRunner(
 
   override val errorOnDuplicatedFieldNames: Boolean = true
 
+  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
   override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index c5e86d010938..59e8970b9c9b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -60,6 +60,7 @@ class CoGroupedArrowPythonRunner(
 
   override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
 
+  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
   protected def newWriter(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
index ed7ff6a75348..4655f96425fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
@@ -100,6 +100,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType)
 
       override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
 
+      override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
       override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
index 8cc2e1de7a4c..1974c393c472 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
@@ -52,6 +52,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
     val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
     val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
     val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
+    val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
     val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
     val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory
 
@@ -68,6 +69,9 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
     if (reuseWorker) {
       envVars.put("SPARK_REUSE_WORKER", "1")
     }
+    if (hideTraceback) {
+      envVars.put("SPARK_HIDE_TRACEBACK", "1")
+    }
     if (simplifiedTraceback) {
       envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 167e1fd8b0f0..a322dfa10df5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -42,6 +42,7 @@ abstract class BasePythonUDFRunner(
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
       funcs.head._1.funcs.head.pythonExec)
 
+  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
   override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled

