HyukjinKwon commented on code in PR #46063:
URL: https://github.com/apache/spark/pull/46063#discussion_r1568564309
##########
python/pyspark/errors/utils.py:
##########
@@ -119,3 +127,74 @@ def get_message_template(self, error_class: str) -> str:
message_template = main_message_template + " " +
sub_message_template
return message_template
+
+
+def _capture_call_site(
+ spark_session: "SparkSession", pyspark_origin: "JavaClass", fragment: str
+) -> None:
+ """
+ Capture the call site information including file name, line number, and
function name.
+ This function updates the thread-local storage from JVM side
(PySparkCurrentOrigin)
+ with the current call site information when a PySpark API function is
called.
+
+ Parameters
+ ----------
+ spark_session : SparkSession
+ Current active Spark session.
+ pyspark_origin : py4j.JavaClass
+ PySparkCurrentOrigin from current active Spark session.
+ fragment : str
+ The name of the PySpark API function being captured.
+
+ Notes
+ -----
+ The call site information is used to enhance error messages with the exact
location
+ in the user code that led to the error.
+ """
+ stack = list(reversed(inspect.stack()))
+ depth = int(
+ spark_session.conf.get("spark.sql.stackTracesInDataFrameContext") #
type: ignore[arg-type]
+ )
+ selected_frames = stack[:depth]
+ call_sites = [f"{frame.filename}:{frame.lineno}" for frame in
selected_frames]
+ call_sites_str = "\n".join(call_sites)
+
+ pyspark_origin.set(fragment, call_sites_str)
+
+
+def _with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
+ """
+ A decorator to capture and provide the call site information to the server
side
+ when PySpark API functions are invoked.
+ """
+
+ @functools.wraps(func)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ from pyspark.sql import SparkSession
+
+ spark = SparkSession.getActiveSession()
+ if spark is not None:
+ assert spark._jvm is not None
+ pyspark_origin =
spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
+
+ # Update call site when the function is called
+ _capture_call_site(spark, pyspark_origin, func.__name__)
+
+ try:
+ return func(*args, **kwargs)
+ finally:
+ pyspark_origin.clear()
+ else:
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def with_origin_to_class(cls: Type[T]) -> Type[T]:
Review Comment:
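One last thing. Can we disable this if `PYSPARK_PIN_THREAD` is set to `false`?
This wouldn't work when pinned thread mode is disabled: `PySparkCurrentOrigin` lives in JVM thread-local storage, and without pinning, a Python thread is not tied to a single JVM thread.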
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]