This is an automated email from the ASF dual-hosted git repository.
huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 01a79c857639 [SPARK-54314][PYTHON][CONNECT] Improve Server-Side debuggability in Spark Connect by capturing client application's file name and line numbers
01a79c857639 is described below
commit 01a79c857639afb3fbafbb510520e9deaf28db11
Author: susheel-aroskar <[email protected]>
AuthorDate: Thu Jan 22 14:03:36 2026 -0800
[SPARK-54314][PYTHON][CONNECT] Improve Server-Side debuggability in Spark Connect by capturing client application's file name and line numbers
### What changes were proposed in this pull request?
Optionally transmit client-side code location details (function name, file name, and line number) along with actions.
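For context, a condensed sketch of the client-side capture logic this PR adds (the real implementation lives in `SparkConnectClient` in `core.py`; the helper name below is illustrative only):
```python
import os
import traceback

import pyspark

# Frames that originate under the pyspark package root are library code, not user code.
PYSPARK_ROOT = os.path.dirname(pyspark.__file__)


def _user_call_sites():
    """Collect (function, file, line) tuples for the user-code portion of the stack."""
    sites = []
    for filename, lineno, func, _ in traceback.extract_stack():
        if filename.startswith(PYSPARK_ROOT):
            # Stop at the first PySpark-internal frame; the user's frames come before it.
            break
        sites.append((func, filename, lineno))
    return sites
```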
### Why are the changes needed?
Currently, no information is sent to the Spark Connect server that helps pinpoint the location of a call (i.e. a Spark DataFrame action) in the client application code. With this change, client application call stack details are sent to the server as a list of (function name, file name, line number) tuples, where they can be logged in the server logs, included as attributes in the corresponding OpenTelemetry spans, etc. This will help users looking from the server-side UI or Cons [...]
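As an illustration, a consumer of the request (e.g. server-side logging) could unpack these details with the same protobuf types the client packs them with; the helper below is a hypothetical sketch, not part of this patch:
```python
import pyspark.sql.connect.proto as pb2


def log_client_call_sites(user_context):
    """Print any client call sites packed into user_context.extensions."""
    for ext in user_context.extensions:
        error = pb2.FetchErrorDetailsResponse.Error()
        # Unpack() returns False if the Any holds a different message type.
        if ext.Unpack(error):
            for elem in error.stack_trace:
                print(f"{elem.file_name}:{elem.line_number} in {elem.method_name}")
```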
### Does this PR introduce _any_ user-facing change?
It adds a new environment variable, `SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK`, which users can set to `true` / `1` to opt in to transmitting client application code locations to the server. When opted in, the client application's call stack details are included in the `user_context.extensions` field of the Spark Connect protobufs.
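For example, opting in from a client application could look like the following minimal sketch (the connect endpoint is a placeholder):
```python
import os

# Opt in before the Spark Connect client builds requests.
os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"] = "1"

from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
# The call site of this action (function name, file name, line number) is packed
# into user_context.extensions on the resulting request.
spark.range(10).count()
```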
### How was this patch tested?
By adding a new unit test, `test_client_call_stack_trace.py`.
### Was this patch authored or co-authored using generative AI tooling?
Yes.
Some of the unit tests were Generated-by: Cursor
Closes #53076 from susheel-aroskar/sarsokar-SPARK-54314-pyspark-connect-debug.
Lead-authored-by: susheel-aroskar <[email protected]>
Co-authored-by: Susheel Aroskar <[email protected]>
Signed-off-by: Huaxin Gao <[email protected]>
---
dev/sparktestsupport/modules.py | 1 +
python/pyspark/sql/connect/client/core.py | 63 ++++
.../connect/client/test_client_call_stack_trace.py | 318 +++++++++++++++++++++
3 files changed, 382 insertions(+)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 1e91432f4abc..cf017c03ec11 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1169,6 +1169,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.client.test_artifact",
"pyspark.sql.tests.connect.client.test_artifact_localcluster",
"pyspark.sql.tests.connect.client.test_client",
+ "pyspark.sql.tests.connect.client.test_client_call_stack_trace",
"pyspark.sql.tests.connect.client.test_reattach",
"pyspark.sql.tests.connect.test_resources",
"pyspark.sql.tests.connect.shell.test_progress",
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 83249995bb7e..58ae6a0eea98 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -22,6 +22,8 @@ __all__ = [
import atexit
+import pyspark
+from pyspark.sql.connect.proto.base_pb2 import FetchErrorDetailsResponse
from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
@@ -35,6 +37,7 @@ import urllib.parse
import uuid
import sys
import time
+import traceback
from typing import (
Iterable,
Iterator,
@@ -65,6 +68,7 @@ from google.rpc import error_details_pb2
from pyspark.util import is_remote_only
from pyspark.accumulators import SpecialAccumulatorIds
from pyspark.version import __version__
+from pyspark.traceback_utils import CallSite
from pyspark.resource.information import ResourceInformation
from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, ObservedMetrics
from pyspark.sql.connect.client.artifact import ArtifactManager
@@ -115,6 +119,9 @@ if TYPE_CHECKING:
from pyspark.sql.datasource import DataSource
+PYSPARK_ROOT = os.path.dirname(pyspark.__file__)
+
+
def _import_zstandard_if_available() -> Optional[Any]:
"""
Import zstandard if available, otherwise return None.
@@ -621,6 +628,11 @@ class ConfigResult:
)
+def _is_pyspark_source(filename: str) -> bool:
+ """Check if the given filename is from the pyspark package."""
+ return filename.startswith(PYSPARK_ROOT)
+
+
class SparkConnectClient(object):
"""
Conceptually the remote spark session that communicates with the server
@@ -812,6 +824,50 @@ class SparkConnectClient(object):
"""
return list(self._retry_policies)
+ @classmethod
+ def _retrieve_stack_frames(cls) -> List[CallSite]:
+ """
+ Return a list of CallSites representing the relevant stack frames in the callstack.
+ """
+ frames = traceback.extract_stack()
+
+ filtered_stack_frames = []
+ for i, frame in enumerate(frames):
+ filename, lineno, func, _ = frame
+ if _is_pyspark_source(filename):
+ # Do not include PySpark internal frames as they are not user application code
+ break
+ if i + 1 < len(frames):
+ _, _, func, _ = frames[i + 1]
+ filtered_stack_frames.append(CallSite(function=func, file=filename, linenum=lineno))
+
+ return filtered_stack_frames
+
+ @classmethod
+ def _build_call_stack_trace(cls) -> Optional[any_pb2.Any]:
+ """
+ Build a call stack trace for the current Spark Connect action
+ Returns
+ -------
+ FetchErrorDetailsResponse.Error: An Error object containing list of stack frames
+ of the user code packed as Any protobuf
+ """
+ if os.getenv("SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK", "false").lower() in ("true", "1"):
+ stack_frames = cls._retrieve_stack_frames()
+ call_stack = FetchErrorDetailsResponse.Error()
+ for call_site in stack_frames:
+ stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement()
+ stack_trace_element.declaring_class = "" # unknown information
+ stack_trace_element.method_name = call_site.function
+ stack_trace_element.file_name = call_site.file
+ stack_trace_element.line_number = call_site.linenum
+ call_stack.stack_trace.append(stack_trace_element)
+ if len(call_stack.stack_trace) > 0:
+ call_stack_details = any_pb2.Any()
+ call_stack_details.Pack(call_stack)
+ return call_stack_details
+ return None
+
def register_udf(
self,
function: Any,
@@ -1291,6 +1347,9 @@ class SparkConnectClient(object):
)
req.operation_id = operation_id
self._update_request_with_user_context_extensions(req)
+
+ if call_stack_trace := self.__class__._build_call_stack_trace():
+ req.user_context.extensions.append(call_stack_trace)
return req
def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
@@ -1302,6 +1361,8 @@ class SparkConnectClient(object):
if self._user_id:
req.user_context.user_id = self._user_id
self._update_request_with_user_context_extensions(req)
+ if call_stack_trace := self.__class__._build_call_stack_trace():
+ req.user_context.extensions.append(call_stack_trace)
return req
def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
@@ -1717,6 +1778,8 @@ class SparkConnectClient(object):
if self._user_id:
req.user_context.user_id = self._user_id
self._update_request_with_user_context_extensions(req)
+ if call_stack_trace := self.__class__._build_call_stack_trace():
+ req.user_context.extensions.append(call_stack_trace)
return req
def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
diff --git a/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py
new file mode 100644
index 000000000000..cf5acb7ca88f
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py
@@ -0,0 +1,318 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+from unittest.mock import patch
+
+import pyspark
+from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+
+if should_test_connect:
+ import pyspark.sql.connect.proto as pb2
+ from pyspark.sql.connect.client import SparkConnectClient, core
+ from pyspark.sql.connect.client.core import _is_pyspark_source
+
+ # The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster)
+ # and it blocks the test process exiting because it is registered as the atexit handler
+ # in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test.
+ SparkConnectClient._cleanup_ml_cache = lambda _: None
+
+# SPARK-54314: Improve Server-Side debuggability in Spark Connect by capturing client application's
+# file name and line numbers in PySpark
+# https://issues.apache.org/jira/browse/SPARK-54314
+
+
[email protected](not should_test_connect, connect_requirement_message)
+class CallStackTraceTestCase(unittest.TestCase):
+ """Test cases for call stack trace functionality in Spark Connect
client."""
+
+ def setUp(self):
+ # Since this test itself is under the pyspark module path, stack frames for test functions
+ # inside this file - for example, user_function() - will normally be filtered out. So here
+ # we set PYSPARK_ROOT to the more specific pyspark.sql.connect path, which doesn't include this
+ # test file, to ensure that the stack frames for user functions inside this test file are
+ # not filtered out.
+ self.original_pyspark_root = core.PYSPARK_ROOT
+ core.PYSPARK_ROOT = os.path.dirname(pyspark.sql.connect.__file__)
+
+ def tearDown(self):
+ # Restore the original PYSPARK_ROOT
+ core.PYSPARK_ROOT = self.original_pyspark_root
+
+ def test_is_pyspark_source_with_pyspark_file(self):
+ """Test that _is_pyspark_source correctly identifies PySpark files."""
+ # Get a known pyspark file path
+ from pyspark import sql
+
+ pyspark_file = sql.connect.client.__file__
+ self.assertTrue(_is_pyspark_source(pyspark_file))
+
+ def test_is_pyspark_source_with_non_pyspark_file(self):
+ """Test that _is_pyspark_source correctly identifies non-PySpark
files."""
+ # Use the current test file which is in pyspark but we'll simulate a
non-pyspark path
+ non_pyspark_file = "/tmp/user_script.py"
+ self.assertFalse(_is_pyspark_source(non_pyspark_file))
+
+ # Test with stdlib file
+ stdlib_file = os.__file__
+ self.assertFalse(_is_pyspark_source(stdlib_file))
+
+ def test_retrieve_stack_frames_includes_user_frames(self):
+ """Test that _retrieve_stack_frames includes user code frames."""
+
+ def user_function():
+ """Simulate a user function."""
+ return SparkConnectClient._retrieve_stack_frames()
+
+ def another_user_function():
+ """Another level of user code."""
+ return user_function()
+
+ stack_frames = another_user_function()
+
+ # We should have at least some frames from the test
+ self.assertGreater(len(stack_frames), 0)
+
+ # Check that we have frames with function names we expect
+ function_names = [frame.function for frame in stack_frames]
+ # At least one of our test functions should be in the stack
+ self.assertTrue(
+ "user_function" in function_names,
+ f"Expected user function names not found in: {function_names}",
+ )
+ self.assertTrue(
+ "another_user_function" in function_names,
+ f"Expected user function names not found in: {function_names}",
+ )
+ self.assertTrue(
+ "test_retrieve_stack_frames_includes_user_frames" in
function_names,
+ f"Expected user function names not found in: {function_names}",
+ )
+
+ def test_build_call_stack_trace_without_env_var(self):
+ """Test that _build_call_stack_trace returns empty list when env var
is not set."""
+ # Make sure the env var is not set
+ with patch.dict(os.environ, {}, clear=False):
+ if "SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK" in os.environ:
+ del os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"]
+ call_stack = SparkConnectClient._build_call_stack_trace()
+ self.assertIsNone(call_stack, "Expected None when env var is not
set")
+
+ """Test that _build_call_stack_trace returns empty list when env var
is empty string."""
+ with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK":
""}):
+ call_stack = SparkConnectClient._build_call_stack_trace()
+ self.assertIsNone(call_stack, "Expected empty list when env var is
empty string")
+
+ def test_build_call_stack_trace_with_env_var_set(self):
+ """Test that _build_call_stack_trace builds trace when env var is
set."""
+ # Set the env var to enable call stack tracing
+ with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}):
+ stack_trace_details = SparkConnectClient._build_call_stack_trace()
+ self.assertIsNotNone(
+ stack_trace_details, "Expected non-None call stack when env var is set"
+ )
+ error = pb2.FetchErrorDetailsResponse.Error()
+ if not stack_trace_details.Unpack(error):
+ self.assertTrue(False, "Expected to unpack stack trace details into Error object")
+ # Should have at least one frame (this test function)
+ self.assertGreater(
+ len(error.stack_trace), 0, "Expected > 0 call stack frames when env var is set"
+ )
+
+
[email protected](not should_test_connect, connect_requirement_message)
+class CallStackTraceIntegrationTestCase(unittest.TestCase):
+ """Integration tests for call stack trace in client request methods."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ self.client = SparkConnectClient("sc://localhost:15002", use_reattachable_execute=False)
+ # Since this test itself is under the pyspark module path, stack frames for test functions
+ # inside this file - for example, user_function() - will normally be filtered out. So here
+ # we set PYSPARK_ROOT to the more specific pyspark.sql.connect path, which doesn't include this
+ # test file, to ensure that the stack frames for user functions inside this test file are
+ # not filtered out.
+ self.original_pyspark_root = core.PYSPARK_ROOT
+ core.PYSPARK_ROOT = os.path.dirname(pyspark.sql.connect.__file__)
+
+ def tearDown(self):
+ # Restore the original PYSPARK_ROOT
+ core.PYSPARK_ROOT = self.original_pyspark_root
+
+ def test_execute_plan_request_includes_call_stack(self):
+ """Test that _execute_plan_request_with_metadata includes call stack
with env var."""
+ with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK":
"1"}):
+ req = self.client._execute_plan_request_with_metadata()
+
+ # Should have extensions when env var is set
+ self.assertGreater(
+ len(req.user_context.extensions),
+ 0,
+ "Expected extensions with env var set",
+ )
+
+ # Verify each extension can be unpacked as Error containing StackTraceElements
+ files = set()
+ functions = set()
+ for extension in req.user_context.extensions:
+ error = pb2.FetchErrorDetailsResponse.Error()
+ if extension.Unpack(error):
+ # Process stack trace elements within the Error
+ for stack_trace_element in error.stack_trace:
+ functions.add(stack_trace_element.method_name)
+ files.add(stack_trace_element.file_name)
+ self.assertIsInstance(stack_trace_element.method_name, str)
+ self.assertIsInstance(stack_trace_element.file_name, str)
+
+ self.assertTrue(
+ "test_execute_plan_request_includes_call_stack" in functions,
+ f"Expected user function names not found in: {functions}",
+ )
+ self.assertTrue(
+ __file__ in files, f"Expected user file name not found in: {files}"
+ )
+
+ def test_analyze_plan_request_includes_call_stack(self):
+ """Test that _analyze_plan_request_with_metadata includes call stack
with env var."""
+ with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK":
"1"}):
+ req = self.client._analyze_plan_request_with_metadata()
+
+ # Should have extensions when env var is set
+ self.assertGreater(
+ len(req.user_context.extensions),
+ 0,
+ "Expected extensions with env var set",
+ )
+
+ # Verify each extension can be unpacked as Error containing StackTraceElements
+ files = set()
+ functions = set()
+ for extension in req.user_context.extensions:
+ error = pb2.FetchErrorDetailsResponse.Error()
+ if extension.Unpack(error):
+ # Process stack trace elements within the Error
+ for stack_trace_element in error.stack_trace:
+ functions.add(stack_trace_element.method_name)
+ files.add(stack_trace_element.file_name)
+ self.assertIsInstance(stack_trace_element.method_name, str)
+ self.assertIsInstance(stack_trace_element.file_name, str)
+
+ self.assertTrue(
+ "test_analyze_plan_request_includes_call_stack" in functions,
+ f"Expected user function names not found in: {functions}",
+ )
+ self.assertTrue(
+ __file__ in files, f"Expected user file name not found in: {files}"
+ )
+
+ def test_config_request_includes_call_stack_with_env_var(self):
+ """Test that _config_request_with_metadata includes call stack with
env var."""
+ with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK":
"1"}):
+ req = self.client._config_request_with_metadata()
+
+ # Should have extensions when env var is set
+ self.assertGreater(
+ len(req.user_context.extensions),
+ 0,
+ "Expected extensions with env var set",
+ )
+
+ # Verify each extension can be unpacked as Error containing StackTraceElements
+ files = set()
+ functions = set()
+ for extension in req.user_context.extensions:
+ error = pb2.FetchErrorDetailsResponse.Error()
+ if extension.Unpack(error):
+ # Process stack trace elements within the Error
+ for stack_trace_element in error.stack_trace:
+ functions.add(stack_trace_element.method_name)
+ files.add(stack_trace_element.file_name)
+ self.assertIsInstance(stack_trace_element.method_name, str)
+ self.assertIsInstance(stack_trace_element.file_name, str)
+
+ self.assertTrue(
+ "test_config_request_includes_call_stack_with_env_var" in
functions,
+ f"Expected user function names not found in: {functions}",
+ )
+ self.assertTrue(
+ __file__ in files, f"Expected user file name not found in: {files}"
+ )
+
+ def test_call_stack_trace_captures_correct_calling_context(self):
+ """Test that call stack trace captures the correct calling context."""
+
+ def level3():
+ """Third level function."""
+ with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}):
+ req = self.client._execute_plan_request_with_metadata()
+ return req
+
+ def level2():
+ """Second level function."""
+ return level3()
+
+ def level1():
+ """First level function."""
+ return level2()
+
+ req = level1()
+
+ # Verify we captured frames from our nested functions
+ self.assertGreater(len(req.user_context.extensions), 0)
+
+ # Unpack and check that we have function names from our call chain
+ functions = set()
+ files = set()
+ for extension in req.user_context.extensions:
+ error = pb2.FetchErrorDetailsResponse.Error()
+ if extension.Unpack(error):
+ # Process stack trace elements within the Error
+ for stack_trace_element in error.stack_trace:
+ functions.add(stack_trace_element.method_name)
+ files.add(stack_trace_element.file_name)
+ self.assertGreater(
+ stack_trace_element.line_number,
+ 0,
+ (
+ f"Expected line number to be greater than 0,"
+ f"got: {stack_trace_element.line_number}"
+ ),
+ )
+
+ self.assertTrue(
+ "level1" in functions, f"Expected user function names not found
in: {functions}"
+ )
+ self.assertTrue(
+ "level2" in functions, f"Expected user function names not found
in: {functions}"
+ )
+ self.assertTrue(
+ "level3" in functions, f"Expected user function names not found
in: {functions}"
+ )
+ self.assertTrue(__file__ in files, f"Expected user file name not found in: {files}")
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.connect.client.test_client_call_stack_trace import *  # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]