This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 01a79c857639 [SPARK-54314][PYTHON][CONNECT] Improve Server-Side debuggability in Spark Connect by capturing client application's file name and line numbers
01a79c857639 is described below

commit 01a79c857639afb3fbafbb510520e9deaf28db11
Author: susheel-aroskar <[email protected]>
AuthorDate: Thu Jan 22 14:03:36 2026 -0800

    [SPARK-54314][PYTHON][CONNECT] Improve Server-Side debuggability in Spark Connect by capturing client application's file name and line numbers
    
    ### What changes were proposed in this pull request?
    Optionally transmit client-side code location details (function name, file name and line number) along with actions.
    
    ### Why are the changes needed?
    Right now there is no information sent to the Spark Connect server that helps pinpoint the location of the call (i.e. the Spark DataFrame action) in the client application code. With this change, client application call stack details are sent to the server as a list of (function name, file name, line number) tuples, where they can be logged in the server logs, included in the corresponding OpenTelemetry spans as attributes, etc. This will help users looking from the server-side UI or Cons [...]
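    
    As an illustration only (not part of this patch), a server-side or intermediary component that receives one of these requests could unpack the extension roughly as sketched below; `request` is a hypothetical stand-in for any Spark Connect request message that carries a `user_context`:
    
    ```python
    import pyspark.sql.connect.proto as pb2
    
    def log_client_call_sites(request) -> None:
        # Each entry in user_context.extensions is a google.protobuf.Any; Unpack()
        # returns False when the Any does not hold a FetchErrorDetailsResponse.Error,
        # so unrelated extensions are skipped safely.
        for extension in request.user_context.extensions:
            error = pb2.FetchErrorDetailsResponse.Error()
            if extension.Unpack(error):
                for frame in error.stack_trace:
                    print(
                        f"client call site: {frame.file_name}:"
                        f"{frame.line_number} in {frame.method_name}"
                    )
    ```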
    
    ### Does this PR introduce _any_ user-facing change?
    It introduces a new environment variable `SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK`, which users can set to true / 1 to opt into transmitting client application code locations to the server. If opted in, the client application's call stack trace details are included in the `user_context.extensions` field of the Spark Connect protobufs.
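    
    For example, a client application could opt in like this (a minimal sketch; the `sc://localhost:15002` endpoint is only a placeholder):
    
    ```python
    import os
    
    # Must be set before the action runs; accepted values are "true" or "1".
    os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"] = "1"
    
    from pyspark.sql import SparkSession
    
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    # The client call stack for this action is packed into the request's
    # user_context.extensions and can be surfaced in server-side logs or spans.
    spark.range(10).collect()
    ```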
    
    ### How was this patch tested?
    By adding a new unit test, `test_client_call_stack_trace.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    Yes.
    Some of the unit tests were Generated-by: Cursor
    
    Closes #53076 from susheel-aroskar/sarsokar-SPARK-54314-pyspark-connect-debug.
    
    Lead-authored-by: susheel-aroskar <[email protected]>
    Co-authored-by: Susheel Aroskar <[email protected]>
    Signed-off-by: Huaxin Gao <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   1 +
 python/pyspark/sql/connect/client/core.py          |  63 ++++
 .../connect/client/test_client_call_stack_trace.py | 318 +++++++++++++++++++++
 3 files changed, 382 insertions(+)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 1e91432f4abc..cf017c03ec11 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1169,6 +1169,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.client.test_artifact",
         "pyspark.sql.tests.connect.client.test_artifact_localcluster",
         "pyspark.sql.tests.connect.client.test_client",
+        "pyspark.sql.tests.connect.client.test_client_call_stack_trace",
         "pyspark.sql.tests.connect.client.test_reattach",
         "pyspark.sql.tests.connect.test_resources",
         "pyspark.sql.tests.connect.shell.test_progress",
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 83249995bb7e..58ae6a0eea98 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -22,6 +22,8 @@ __all__ = [
 
 import atexit
 
+import pyspark
+from pyspark.sql.connect.proto.base_pb2 import FetchErrorDetailsResponse
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -35,6 +37,7 @@ import urllib.parse
 import uuid
 import sys
 import time
+import traceback
 from typing import (
     Iterable,
     Iterator,
@@ -65,6 +68,7 @@ from google.rpc import error_details_pb2
 from pyspark.util import is_remote_only
 from pyspark.accumulators import SpecialAccumulatorIds
 from pyspark.version import __version__
+from pyspark.traceback_utils import CallSite
 from pyspark.resource.information import ResourceInformation
 from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, ObservedMetrics
 from pyspark.sql.connect.client.artifact import ArtifactManager
@@ -115,6 +119,9 @@ if TYPE_CHECKING:
     from pyspark.sql.datasource import DataSource
 
 
+PYSPARK_ROOT = os.path.dirname(pyspark.__file__)
+
+
 def _import_zstandard_if_available() -> Optional[Any]:
     """
     Import zstandard if available, otherwise return None.
@@ -621,6 +628,11 @@ class ConfigResult:
         )
 
 
+def _is_pyspark_source(filename: str) -> bool:
+    """Check if the given filename is from the pyspark package."""
+    return filename.startswith(PYSPARK_ROOT)
+
+
 class SparkConnectClient(object):
     """
     Conceptually the remote spark session that communicates with the server
@@ -812,6 +824,50 @@ class SparkConnectClient(object):
         """
         return list(self._retry_policies)
 
+    @classmethod
+    def _retrieve_stack_frames(cls) -> List[CallSite]:
+        """
+        Return a list of CallSites representing the relevant stack frames in the callstack.
+        """
+        frames = traceback.extract_stack()
+
+        filtered_stack_frames = []
+        for i, frame in enumerate(frames):
+            filename, lineno, func, _ = frame
+            if _is_pyspark_source(filename):
+                # Do not include PySpark internal frames as they are not user application code
+                break
+            if i + 1 < len(frames):
+                _, _, func, _ = frames[i + 1]
+            filtered_stack_frames.append(CallSite(function=func, file=filename, linenum=lineno))
+
+        return filtered_stack_frames
+
+    @classmethod
+    def _build_call_stack_trace(cls) -> Optional[any_pb2.Any]:
+        """
+        Build a call stack trace for the current Spark Connect action
+        Returns
+        -------
+        FetchErrorDetailsResponse.Error: An Error object containing a list of stack frames
+        of the user code, packed as an Any protobuf
+        """
+        if os.getenv("SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK", "false").lower() in ("true", "1"):
+            stack_frames = cls._retrieve_stack_frames()
+            call_stack = FetchErrorDetailsResponse.Error()
+            for call_site in stack_frames:
+                stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement()
+                stack_trace_element.declaring_class = ""  # unknown information
+                stack_trace_element.method_name = call_site.function
+                stack_trace_element.file_name = call_site.file
+                stack_trace_element.line_number = call_site.linenum
+                call_stack.stack_trace.append(stack_trace_element)
+            if len(call_stack.stack_trace) > 0:
+                call_stack_details = any_pb2.Any()
+                call_stack_details.Pack(call_stack)
+                return call_stack_details
+        return None
+
     def register_udf(
         self,
         function: Any,
@@ -1291,6 +1347,9 @@ class SparkConnectClient(object):
                 )
             req.operation_id = operation_id
         self._update_request_with_user_context_extensions(req)
+
+        if call_stack_trace := self.__class__._build_call_stack_trace():
+            req.user_context.extensions.append(call_stack_trace)
         return req
 
     def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
@@ -1302,6 +1361,8 @@ class SparkConnectClient(object):
         if self._user_id:
             req.user_context.user_id = self._user_id
         self._update_request_with_user_context_extensions(req)
+        if call_stack_trace := self.__class__._build_call_stack_trace():
+            req.user_context.extensions.append(call_stack_trace)
         return req
 
     def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
@@ -1717,6 +1778,8 @@ class SparkConnectClient(object):
         if self._user_id:
             req.user_context.user_id = self._user_id
         self._update_request_with_user_context_extensions(req)
+        if call_stack_trace := self.__class__._build_call_stack_trace():
+            req.user_context.extensions.append(call_stack_trace)
         return req
 
     def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
diff --git a/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py
new file mode 100644
index 000000000000..cf5acb7ca88f
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/client/test_client_call_stack_trace.py
@@ -0,0 +1,318 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+from unittest.mock import patch
+
+import pyspark
+from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+
+if should_test_connect:
+    import pyspark.sql.connect.proto as pb2
+    from pyspark.sql.connect.client import SparkConnectClient, core
+    from pyspark.sql.connect.client.core import _is_pyspark_source
+
+    # The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster)
+    # and it blocks the test process exiting because it is registered as the atexit handler
+    # in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test.
+    SparkConnectClient._cleanup_ml_cache = lambda _: None
+
+# SPARK-54314: Improve Server-Side debuggability in Spark Connect by capturing client application's
+# file name and line numbers in PySpark
+# https://issues.apache.org/jira/browse/SPARK-54314
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class CallStackTraceTestCase(unittest.TestCase):
+    """Test cases for call stack trace functionality in Spark Connect 
client."""
+
+    def setUp(self):
+        # Since this test itself is under pyspark module path, stack frames for test functions
+        # inside this file - for example, user_function() - will normally be filtered out. So here
+        # we set the PYSPARK_ROOT to the more specific pyspark.sql.connect path that doesn't
+        # include this test file, to ensure that the stack frames for user functions inside this
+        # test file are not filtered out.
+        self.original_pyspark_root = core.PYSPARK_ROOT
+        core.PYSPARK_ROOT = os.path.dirname(pyspark.sql.connect.__file__)
+
+    def tearDown(self):
+        # Restore the original PYSPARK_ROOT
+        core.PYSPARK_ROOT = self.original_pyspark_root
+
+    def test_is_pyspark_source_with_pyspark_file(self):
+        """Test that _is_pyspark_source correctly identifies PySpark files."""
+        # Get a known pyspark file path
+        from pyspark import sql
+
+        pyspark_file = sql.connect.client.__file__
+        self.assertTrue(_is_pyspark_source(pyspark_file))
+
+    def test_is_pyspark_source_with_non_pyspark_file(self):
+        """Test that _is_pyspark_source correctly identifies non-PySpark 
files."""
+        # Use the current test file which is in pyspark but we'll simulate a 
non-pyspark path
+        non_pyspark_file = "/tmp/user_script.py"
+        self.assertFalse(_is_pyspark_source(non_pyspark_file))
+
+        # Test with stdlib file
+        stdlib_file = os.__file__
+        self.assertFalse(_is_pyspark_source(stdlib_file))
+
+    def test_retrieve_stack_frames_includes_user_frames(self):
+        """Test that _retrieve_stack_frames includes user code frames."""
+
+        def user_function():
+            """Simulate a user function."""
+            return SparkConnectClient._retrieve_stack_frames()
+
+        def another_user_function():
+            """Another level of user code."""
+            return user_function()
+
+        stack_frames = another_user_function()
+
+        # We should have at least some frames from the test
+        self.assertGreater(len(stack_frames), 0)
+
+        # Check that we have frames with function names we expect
+        function_names = [frame.function for frame in stack_frames]
+        # At least one of our test functions should be in the stack
+        self.assertTrue(
+            "user_function" in function_names,
+            f"Expected user function names not found in: {function_names}",
+        )
+        self.assertTrue(
+            "another_user_function" in function_names,
+            f"Expected user function names not found in: {function_names}",
+        )
+        self.assertTrue(
+            "test_retrieve_stack_frames_includes_user_frames" in 
function_names,
+            f"Expected user function names not found in: {function_names}",
+        )
+
+    def test_build_call_stack_trace_without_env_var(self):
+        """Test that _build_call_stack_trace returns empty list when env var 
is not set."""
+        # Make sure the env var is not set
+        with patch.dict(os.environ, {}, clear=False):
+            if "SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK" in os.environ:
+                del os.environ["SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK"]
+            call_stack = SparkConnectClient._build_call_stack_trace()
+            self.assertIsNone(call_stack, "Expected None when env var is not set")
+
+        """Test that _build_call_stack_trace returns empty list when env var 
is empty string."""
+        with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": 
""}):
+            call_stack = SparkConnectClient._build_call_stack_trace()
+            self.assertIsNone(call_stack, "Expected empty list when env var is 
empty string")
+
+    def test_build_call_stack_trace_with_env_var_set(self):
+        """Test that _build_call_stack_trace builds trace when env var is 
set."""
+        # Set the env var to enable call stack tracing
+        with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}):
+            stack_trace_details = SparkConnectClient._build_call_stack_trace()
+            self.assertIsNotNone(
+                stack_trace_details, "Expected non-None call stack when env var is set"
+            )
+            error = pb2.FetchErrorDetailsResponse.Error()
+            if not stack_trace_details.Unpack(error):
+                self.assertTrue(False, "Expected to unpack stack trace details into Error object")
+            # Should have at least one frame (this test function)
+            self.assertGreater(
+                len(error.stack_trace), 0, "Expected > 0 call stack frames when env var is set"
+            )
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class CallStackTraceIntegrationTestCase(unittest.TestCase):
+    """Integration tests for call stack trace in client request methods."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.client = SparkConnectClient("sc://localhost:15002", use_reattachable_execute=False)
+        # Since this test itself is under pyspark module path, stack frames for test functions
+        # inside this file - for example, user_function() - will normally be filtered out. So here
+        # we set the PYSPARK_ROOT to the more specific pyspark.sql.connect path that doesn't
+        # include this test file, to ensure that the stack frames for user functions inside this
+        # test file are not filtered out.
+        self.original_pyspark_root = core.PYSPARK_ROOT
+        core.PYSPARK_ROOT = os.path.dirname(pyspark.sql.connect.__file__)
+
+    def tearDown(self):
+        # Restore the original PYSPARK_ROOT
+        core.PYSPARK_ROOT = self.original_pyspark_root
+
+    def test_execute_plan_request_includes_call_stack(self):
+        """Test that _execute_plan_request_with_metadata includes call stack 
with env var."""
+        with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": 
"1"}):
+            req = self.client._execute_plan_request_with_metadata()
+
+            # Should have extensions when env var is set
+            self.assertGreater(
+                len(req.user_context.extensions),
+                0,
+                "Expected extensions with env var set",
+            )
+
+            # Verify each extension can be unpacked as Error containing StackTraceElements
+            files = set()
+            functions = set()
+            for extension in req.user_context.extensions:
+                error = pb2.FetchErrorDetailsResponse.Error()
+                if extension.Unpack(error):
+                    # Process stack trace elements within the Error
+                    for stack_trace_element in error.stack_trace:
+                        functions.add(stack_trace_element.method_name)
+                        files.add(stack_trace_element.file_name)
+                        self.assertIsInstance(stack_trace_element.method_name, str)
+                        self.assertIsInstance(stack_trace_element.file_name, str)
+
+            self.assertTrue(
+                "test_execute_plan_request_includes_call_stack" in functions,
+                f"Expected user function names not found in: {functions}",
+            )
+            self.assertTrue(
+                __file__ in files, f"Expected user function names not found in: {files}"
+            )
+
+    def test_analyze_plan_request_includes_call_stack(self):
+        """Test that _analyze_plan_request_with_metadata includes call stack 
with env var."""
+        with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": 
"1"}):
+            req = self.client._analyze_plan_request_with_metadata()
+
+            # Should have extensions when env var is set
+            self.assertGreater(
+                len(req.user_context.extensions),
+                0,
+                "Expected extensions with env var set",
+            )
+
+            # Verify each extension can be unpacked as Error containing StackTraceElements
+            files = set()
+            functions = set()
+            for extension in req.user_context.extensions:
+                error = pb2.FetchErrorDetailsResponse.Error()
+                if extension.Unpack(error):
+                    # Process stack trace elements within the Error
+                    for stack_trace_element in error.stack_trace:
+                        functions.add(stack_trace_element.method_name)
+                        files.add(stack_trace_element.file_name)
+                        self.assertIsInstance(stack_trace_element.method_name, str)
+                        self.assertIsInstance(stack_trace_element.file_name, str)
+
+            self.assertTrue(
+                "test_analyze_plan_request_includes_call_stack" in functions,
+                f"Expected user function names not found in: {functions}",
+            )
+            self.assertTrue(
+                __file__ in files, f"Expected user function names not found in: {files}"
+            )
+
+    def test_config_request_includes_call_stack_with_env_var(self):
+        """Test that _config_request_with_metadata includes call stack with 
env var."""
+        with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": 
"1"}):
+            req = self.client._config_request_with_metadata()
+
+            # Should have extensions when env var is set
+            self.assertGreater(
+                len(req.user_context.extensions),
+                0,
+                "Expected extensions with env var set",
+            )
+
+            # Verify each extension can be unpacked as Error containing StackTraceElements
+            files = set()
+            functions = set()
+            for extension in req.user_context.extensions:
+                error = pb2.FetchErrorDetailsResponse.Error()
+                if extension.Unpack(error):
+                    # Process stack trace elements within the Error
+                    for stack_trace_element in error.stack_trace:
+                        functions.add(stack_trace_element.method_name)
+                        files.add(stack_trace_element.file_name)
+                        self.assertIsInstance(stack_trace_element.method_name, str)
+                        self.assertIsInstance(stack_trace_element.file_name, str)
+
+            self.assertTrue(
+                "test_config_request_includes_call_stack_with_env_var" in 
functions,
+                f"Expected user function names not found in: {functions}",
+            )
+            self.assertTrue(
+                __file__ in files, f"Expected user function names not found 
in: {files}"
+            )
+
+    def test_call_stack_trace_captures_correct_calling_context(self):
+        """Test that call stack trace captures the correct calling context."""
+
+        def level3():
+            """Third level function."""
+            with patch.dict(os.environ, {"SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK": "1"}):
+                req = self.client._execute_plan_request_with_metadata()
+                return req
+
+        def level2():
+            """Second level function."""
+            return level3()
+
+        def level1():
+            """First level function."""
+            return level2()
+
+        req = level1()
+
+        # Verify we captured frames from our nested functions
+        self.assertGreater(len(req.user_context.extensions), 0)
+
+        # Unpack and check that we have function names from our call chain
+        functions = set()
+        files = set()
+        for extension in req.user_context.extensions:
+            error = pb2.FetchErrorDetailsResponse.Error()
+            if extension.Unpack(error):
+                # Process stack trace elements within the Error
+                for stack_trace_element in error.stack_trace:
+                    functions.add(stack_trace_element.method_name)
+                    files.add(stack_trace_element.file_name)
+                    self.assertGreater(
+                        stack_trace_element.line_number,
+                        0,
+                        (
+                            f"Expected line number to be greater than 0,"
+                            f"got: {stack_trace_element.line_number}"
+                        ),
+                    )
+
+        self.assertTrue(
+            "level1" in functions, f"Expected user function names not found 
in: {functions}"
+        )
+        self.assertTrue(
+            "level2" in functions, f"Expected user function names not found 
in: {functions}"
+        )
+        self.assertTrue(
+            "level3" in functions, f"Expected user function names not found 
in: {functions}"
+        )
+        self.assertTrue(__file__ in files, f"Expected user function names not 
found in: {files}")
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.client.test_client_call_stack_trace import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
