HyukjinKwon commented on code in PR #46063:
URL: https://github.com/apache/spark/pull/46063#discussion_r1568564309


##########
python/pyspark/errors/utils.py:
##########
@@ -119,3 +127,74 @@ def get_message_template(self, error_class: str) -> str:
             message_template = main_message_template + " " + sub_message_template
 
         return message_template
+
+
+def _capture_call_site(
+    spark_session: "SparkSession", pyspark_origin: "JavaClass", fragment: str
+) -> None:
+    """
+    Capture the call site information, including file name, line number, and function name.
+    This function updates the thread-local storage on the JVM side (PySparkCurrentOrigin)
+    with the current call site information when a PySpark API function is called.
+
+    Parameters
+    ----------
+    spark_session : SparkSession
+        Current active Spark session.
+    pyspark_origin : py4j.JavaClass
+        PySparkCurrentOrigin from current active Spark session.
+    fragment : str
+        The name of the PySpark API function being captured.
+
+    Notes
+    -----
+    The call site information is used to enhance error messages with the exact location
+    in the user code that led to the error.
+    """
+    stack = list(reversed(inspect.stack()))
+    depth = int(
+        spark_session.conf.get("spark.sql.stackTracesInDataFrameContext")  # type: ignore[arg-type]
+    )
+    selected_frames = stack[:depth]
+    call_sites = [f"{frame.filename}:{frame.lineno}" for frame in selected_frames]
+    call_sites_str = "\n".join(call_sites)
+
+    pyspark_origin.set(fragment, call_sites_str)
+
+
+def _with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
+    """
+    A decorator to capture and provide the call site information to the server side
+    when PySpark API functions are invoked.
+    when PySpark API functions are invoked.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        from pyspark.sql import SparkSession
+
+        spark = SparkSession.getActiveSession()
+        if spark is not None:
+            assert spark._jvm is not None
+            pyspark_origin = spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
+
+            # Update call site when the function is called
+            _capture_call_site(spark, pyspark_origin, func.__name__)
+
+            try:
+                return func(*args, **kwargs)
+            finally:
+                pyspark_origin.clear()
+        else:
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+def with_origin_to_class(cls: Type[T]) -> Type[T]:
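
For readers tracing the frame selection above, here is a standalone sketch of what `_capture_call_site` computes, with Spark taken out of the picture. The names `demo_call_sites` and `user_code` are made up for illustration, and `depth` stands in for the `spark.sql.stackTracesInDataFrameContext` value read from the session conf:

```python
import inspect


def demo_call_sites(depth: int) -> str:
    # Reverse the stack so the outermost caller comes first, then keep the
    # first `depth` frames as "file:line" strings, mirroring the hunk above.
    stack = list(reversed(inspect.stack()))
    selected_frames = stack[:depth]
    return "\n".join(f"{frame.filename}:{frame.lineno}" for frame in selected_frames)


def user_code() -> str:
    return demo_call_sites(depth=2)


if __name__ == "__main__":
    # Prints the two outermost call sites: the module-level call below and
    # the line inside user_code, one per line.
    print(user_code())
```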

Review Comment:
   One last thing. Can we disable this if `PYSPARK_PIN_THREAD` is set to `false`? This wouldn't work when that's disabled, since the JVM-side thread-local state wouldn't be tied to the calling Python thread. See the sketch below for what I mean.
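
A minimal sketch of the guard, assuming PySpark's usual way of reading `PYSPARK_PIN_THREAD` from the environment; the helper name `_is_pinned_thread_mode` is made up for illustration, and `_capture_call_site` is the function from this hunk:

```python
import functools
import os
from typing import Any, Callable


def _is_pinned_thread_mode() -> bool:
    # Assumption: mirrors how PySpark toggles pinned-thread mode via the
    # PYSPARK_PIN_THREAD environment variable (enabled by default).
    return os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true"


def _with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        from pyspark.sql import SparkSession

        spark = SparkSession.getActiveSession()
        if spark is not None and _is_pinned_thread_mode():
            assert spark._jvm is not None
            pyspark_origin = spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
            _capture_call_site(spark, pyspark_origin, func.__name__)
            try:
                return func(*args, **kwargs)
            finally:
                pyspark_origin.clear()
        # Without pinned threads the JVM thread-local is not tied to this
        # Python thread, so skip the capture and call through directly.
        return func(*args, **kwargs)

    return wrapper
```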



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

