This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 22094fef57b2 [SPARK-55161][PYTHON] Support profilers on python data source
22094fef57b2 is described below

commit 22094fef57b2a664495de710315707c14fa2b215
Author: Tian Gao <[email protected]>
AuthorDate: Tue Feb 3 08:18:16 2026 +0900

    [SPARK-55161][PYTHON] Support profilers on python data source
    
    ### What changes were proposed in this pull request?
    
    Make it possible to enable the perf/memory profiler on Python data sources with the new SQL config `spark.sql.pyspark.dataSource.profiler`.
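
    As a minimal usage sketch (the data source class below is illustrative and adapted from the tests in this patch; it is not part of the change itself):

    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql.datasource import DataSource, DataSourceReader

    class MyReader(DataSourceReader):
        def read(self, partition):
            # Emit a few rows matching the declared schema.
            yield from ((1,), (2,), (3,))

    class MyDataSource(DataSource):
        def schema(self):
            return "id long"

        def reader(self, schema):
            return MyReader()

    spark = SparkSession.builder.getOrCreate()

    # "perf" enables the cProfile-based profiler; "memory" selects the memory profiler.
    spark.conf.set("spark.sql.pyspark.dataSource.profiler", "perf")

    spark.dataSource.register(MyDataSource)
    spark.read.format("MyDataSource").load().collect()

    # Inspect and then clear the collected results on the driver.
    spark.profile.show(type="perf")
    spark.profile.clear()
    ```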
    
    ### Why are the changes needed?
    
    Python data sources need the same profiler capabilities that UDFs already have.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. A new config/feature is introduced.
    
    Also note that the UDF profiler config (`spark.sql.pyspark.udf.profiler`) previously happened to profile data source reads/writes as well, because they are executed through UDFs. It no longer does; data source profiling now requires the new config.
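
    Data source profile results are keyed by the worker module name (for example `create_data_source` or `plan_data_source_read`) rather than a numeric UDF id, so they can be addressed individually. Continuing the sketch above (the output directory is a placeholder):

    ```python
    # Show only the planning-phase profile; the key is the worker module name.
    spark.profile.show("plan_data_source_read", type="perf")

    # Dump all collected profiles (numeric UDF ids and data source keys alike).
    spark.profile.dump("/tmp/profiles", type="perf")
    ```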
    
    ### How was this patch tested?
    
    Three new tests were added, and they pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #53945 from gaogaotiantian/datasource-profiler.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/profiler.py                         |  12 +++
 python/pyspark/sql/_typing.pyi                     |   2 +-
 python/pyspark/sql/profiler.py                     | 118 ++++++++++++++-------
 python/pyspark/sql/tests/test_python_datasource.py |  44 ++++++++
 python/pyspark/sql/tests/test_udf_profiler.py      |  30 ------
 python/pyspark/sql/worker/utils.py                 |  30 +++++-
 python/pyspark/worker.py                           |  46 +++-----
 python/pyspark/worker_util.py                      |  34 +++++-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  13 +++
 .../python/PythonStreamingSinkCommitRunner.scala   |   7 ++
 .../v2/python/UserDefinedPythonDataSource.scala    |  36 +++++++
 .../sql/execution/python/ArrowPythonRunner.scala   |   8 +-
 .../sql/execution/python/PythonPlannerRunner.scala |   3 +
 13 files changed, 272 insertions(+), 111 deletions(-)

diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
index 1e9e398b25f8..200c14870348 100644
--- a/python/pyspark/profiler.py
+++ b/python/pyspark/profiler.py
@@ -45,6 +45,7 @@ try:
 except Exception:
     has_memory_profiler = False
 
+import pyspark
 from pyspark.accumulators import AccumulatorParam
 from pyspark.errors import PySparkRuntimeError, PySparkValueError
 
@@ -477,6 +478,17 @@ class MemoryProfiler(Profiler):
             stream.write(header + "\n")
             stream.write("=" * len(header) + "\n")
 
+            if "pyspark.zip/pyspark/" in filename:
+                # if the original filename is in pyspark.zip file, we try to find the actual
+                # file in pyspark module by concatenating pyspark module directory and the
+                # rest of the filename
+                # Eventually we should ask the data provider to provide the actual lines
+                # because there's no guarantee that we can always find the actual file
+                # on driver side
+                filename = os.path.join(
+                    os.path.dirname(pyspark.__file__), filename.rsplit("pyspark.zip/pyspark/", 1)[1]
+                )
+
             all_lines = linecache.getlines(filename)
             if len(all_lines) == 0:
                 raise PySparkValueError(
diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi
index fafc9bdf15fc..cf0a1d522118 100644
--- a/python/pyspark/sql/_typing.pyi
+++ b/python/pyspark/sql/_typing.pyi
@@ -85,4 +85,4 @@ class UserDefinedFunctionLike(Protocol):
     def __call__(self, *args: ColumnOrName) -> Column: ...
     def asNondeterministic(self) -> UserDefinedFunctionLike: ...
 
-ProfileResults = Dict[int, Tuple[Optional[pstats.Stats], Optional[CodeMapDict]]]
+ProfileResults = Dict[Union[int, str], Tuple[Optional[pstats.Stats], Optional[CodeMapDict]]]
diff --git a/python/pyspark/sql/profiler.py b/python/pyspark/sql/profiler.py
index 2aee60eeb41d..87efff44363f 100644
--- a/python/pyspark/sql/profiler.py
+++ b/python/pyspark/sql/profiler.py
@@ -21,7 +21,18 @@ import os
 import pstats
 from threading import RLock
 from types import CodeType, TracebackType
-from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union, TYPE_CHECKING, overload
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    TYPE_CHECKING,
+    overload,
+)
 import warnings
 
 from pyspark.accumulators import (
@@ -82,10 +93,12 @@ class WorkerPerfProfiler:
     PerfProfiler is a profiler for performance profiling.
     """
 
-    def __init__(self, accumulator: Accumulator["ProfileResults"], result_id: int) -> None:
+    def __init__(
+        self, accumulator: Accumulator[Optional["ProfileResults"]], result_key: Union[int, str]
+    ) -> None:
         self._accumulator = accumulator
         self._profiler = cProfile.Profile()
-        self._result_id = result_id
+        self._result_key = result_key
 
     def start(self) -> None:
         self._profiler.enable()
@@ -98,7 +111,7 @@ class WorkerPerfProfiler:
         # make it picklable
         st.stream = None  # type: ignore[attr-defined]
         st.strip_dirs()
-        self._accumulator.add({self._result_id: (st, None)})
+        self._accumulator.add({self._result_key: (st, None)})
 
     def __enter__(self) -> "WorkerPerfProfiler":
         self.start()
@@ -121,8 +134,8 @@ class WorkerMemoryProfiler:
 
     def __init__(
         self,
-        accumulator: Accumulator["ProfileResults"],
-        result_id: int,
+        accumulator: Accumulator[Optional["ProfileResults"]],
+        result_key: Union[int, str],
         func_or_code: Union[Callable, CodeType],
     ) -> None:
         from pyspark.profiler import UDFLineProfilerV2
@@ -133,7 +146,7 @@ class WorkerMemoryProfiler:
             self._profiler.add_code(func_or_code)
         else:
             self._profiler.add_function(func_or_code)
-        self._result_id = result_id
+        self._result_key = result_key
 
     def start(self) -> None:
         self._profiler.enable_by_count()
@@ -146,7 +159,7 @@ class WorkerMemoryProfiler:
             filename: list(line_iterator)
             for filename, line_iterator in self._profiler.code_map.items()
         }
-        self._accumulator.add({self._result_id: (None, codemap_dict)})
+        self._accumulator.add({self._result_key: (None, codemap_dict)})
 
     def __enter__(self) -> "WorkerMemoryProfiler":
         self.start()
@@ -173,7 +186,12 @@ class ProfilerCollector(ABC):
     def __init__(self) -> None:
         self._lock = RLock()
 
-    def show_perf_profiles(self, id: Optional[int] = None) -> None:
+    def _sorted_keys(self, keys: Iterable[Union[int, str]]) -> list[Union[int, str]]:
+        int_keys = sorted(x for x in keys if isinstance(x, int))
+        str_keys = sorted(x for x in keys if isinstance(x, str))
+        return str_keys + int_keys
+
+    def show_perf_profiles(self, id: Optional[Union[int, str]] = None) -> None:
         """
         Show the perf profile results.
 
@@ -187,22 +205,25 @@ class ProfilerCollector(ABC):
         with self._lock:
             stats = self._perf_profile_results
 
-        def show(id: int) -> None:
+        def show(id: Union[int, str]) -> None:
             s = stats.get(id)
             if s is not None:
                 print("=" * 60)
-                print(f"Profile of UDF<id={id}>")
+                if isinstance(id, str):
+                    print(f"Profile of {id}")
+                else:
+                    print(f"Profile of UDF<id={id}>")
                 print("=" * 60)
                 s.sort_stats("time", "cumulative").print_stats()
 
         if id is not None:
             show(id)
         else:
-            for id in sorted(stats.keys()):
+            for id in self._sorted_keys(stats.keys()):
                 show(id)
 
     @property
-    def _perf_profile_results(self) -> Dict[int, pstats.Stats]:
+    def _perf_profile_results(self) -> Dict[Union[int, str], pstats.Stats]:
         with self._lock:
             return {
                 result_id: perf
@@ -210,7 +231,7 @@ class ProfilerCollector(ABC):
                 if perf is not None
             }
 
-    def show_memory_profiles(self, id: Optional[int] = None) -> None:
+    def show_memory_profiles(self, id: Optional[Union[int, str]] = None) -> None:
         """
         Show the memory profile results.
 
@@ -230,22 +251,25 @@ class ProfilerCollector(ABC):
                 UserWarning,
             )
 
-        def show(id: int) -> None:
+        def show(id: Union[int, str]) -> None:
             cm = code_map.get(id)
             if cm is not None:
                 print("=" * 60)
-                print(f"Profile of UDF<id={id}>")
+                if isinstance(id, str):
+                    print(f"Profile of {id}")
+                else:
+                    print(f"Profile of UDF<id={id}>")
                 print("=" * 60)
                 MemoryProfiler._show_results(cm)
 
         if id is not None:
             show(id)
         else:
-            for id in sorted(code_map.keys()):
+            for id in self._sorted_keys(code_map.keys()):
                 show(id)
 
     @property
-    def _memory_profile_results(self) -> Dict[int, CodeMapDict]:
+    def _memory_profile_results(self) -> Dict[Union[int, str], CodeMapDict]:
         with self._lock:
             return {
                 result_id: mem
@@ -261,7 +285,7 @@ class ProfilerCollector(ABC):
         """
         ...
 
-    def dump_perf_profiles(self, path: str, id: Optional[int] = None) -> None:
+    def dump_perf_profiles(self, path: str, id: Optional[Union[int, str]] = None) -> None:
         """
         Dump the perf profile results into directory `path`.
 
@@ -271,13 +295,13 @@ class ProfilerCollector(ABC):
         ----------
         path: str
             A directory in which to dump the perf profile.
-        id : int, optional
+        id : int or str, optional
             A UDF ID to be shown. If not specified, all the results will be shown.
         """
         with self._lock:
             stats = self._perf_profile_results
 
-        def dump(id: int) -> None:
+        def dump(id: Union[int, str]) -> None:
             s = stats.get(id)
 
             if s is not None:
@@ -288,10 +312,10 @@ class ProfilerCollector(ABC):
         if id is not None:
             dump(id)
         else:
-            for id in sorted(stats.keys()):
+            for id in self._sorted_keys(stats.keys()):
                 dump(id)
 
-    def dump_memory_profiles(self, path: str, id: Optional[int] = None) -> None:
+    def dump_memory_profiles(self, path: str, id: Optional[Union[int, str]] = None) -> None:
         """
         Dump the memory profile results into directory `path`.
 
@@ -301,7 +325,7 @@ class ProfilerCollector(ABC):
         ----------
         path: str
             A directory in which to dump the memory profile.
-        id : int, optional
+        id : int or str, optional
             A UDF ID to be shown. If not specified, all the results will be shown.
         """
         with self._lock:
@@ -313,7 +337,7 @@ class ProfilerCollector(ABC):
                 UserWarning,
             )
 
-        def dump(id: int) -> None:
+        def dump(id: Union[int, str]) -> None:
             cm = code_map.get(id)
 
             if cm is not None:
@@ -326,10 +350,10 @@ class ProfilerCollector(ABC):
         if id is not None:
             dump(id)
         else:
-            for id in sorted(code_map.keys()):
+            for id in self._sorted_keys(code_map.keys()):
                 dump(id)
 
-    def clear_perf_profiles(self, id: Optional[int] = None) -> None:
+    def clear_perf_profiles(self, id: Optional[Union[int, str]] = None) -> None:
         """
         Clear the perf profile results.
 
@@ -337,7 +361,7 @@ class ProfilerCollector(ABC):
 
         Parameters
         ----------
-        id : int, optional
+        id : int or str, optional
             The UDF ID whose profiling results should be cleared.
             If not specified, all the results will be cleared.
         """
@@ -354,7 +378,7 @@ class ProfilerCollector(ABC):
                     if mem is None:
                         self._profile_results.pop(id, None)
 
-    def clear_memory_profiles(self, id: Optional[int] = None) -> None:
+    def clear_memory_profiles(self, id: Optional[Union[int, str]] = None) -> None:
         """
         Clear the memory profile results.
 
@@ -362,7 +386,7 @@ class ProfilerCollector(ABC):
 
         Parameters
         ----------
-        id : int, optional
+        id : int or str, optional
             The UDF ID whose profiling results should be cleared.
             If not specified, all the results will be cleared.
         """
@@ -407,7 +431,7 @@ class Profile:
     def __init__(self, profiler_collector: ProfilerCollector):
         self.profiler_collector = profiler_collector
 
-    def show(self, id: Optional[int] = None, *, type: Optional[str] = None) -> None:
+    def show(self, id: Optional[Union[int, str]] = None, *, type: Optional[str] = None) -> None:
         """
         Show the profile results.
 
@@ -415,7 +439,7 @@ class Profile:
 
         Parameters
         ----------
-        id : int, optional
+        id : int or str, optional
             A UDF ID to be shown. If not specified, all the results will be shown.
         type : str, optional
             The profiler type, which can be either "perf" or "memory".
@@ -441,7 +465,9 @@ class Profile:
                 },
             )
 
-    def dump(self, path: str, id: Optional[int] = None, *, type: Optional[str] = None) -> None:
+    def dump(
+        self, path: str, id: Optional[Union[int, str]] = None, *, type: Optional[str] = None
+    ) -> None:
         """
         Dump the profile results into directory `path`.
 
@@ -451,7 +477,7 @@ class Profile:
         ----------
         path: str
             A directory in which to dump the profile.
-        id : int, optional
+        id : int or str, optional
             A UDF ID to be shown. If not specified, all the results will be shown.
         type : str, optional
             The profiler type, which can be either "perf" or "memory".
@@ -472,24 +498,34 @@ class Profile:
             )
 
     @overload
-    def render(self, id: int, *, type: Optional[str] = None, renderer: Optional[str] = None) -> Any:
+    def render(
+        self, id: Union[int, str], *, type: Optional[str] = None, renderer: Optional[str] = None
+    ) -> Any:
         ...
 
     @overload
     def render(
-        self, id: int, *, type: Optional[Literal["perf"]], renderer: Callable[[pstats.Stats], Any]
+        self,
+        id: Union[int, str],
+        *,
+        type: Optional[Literal["perf"]],
+        renderer: Callable[[pstats.Stats], Any],
     ) -> Any:
         ...
 
     @overload
     def render(
-        self, id: int, *, type: Literal["memory"], renderer: Callable[[CodeMapDict], Any]
+        self,
+        id: Union[int, str],
+        *,
+        type: Literal["memory"],
+        renderer: Callable[[CodeMapDict], Any],
     ) -> Any:
         ...
 
     def render(
         self,
-        id: int,
+        id: Union[int, str],
         *,
         type: Optional[str] = None,
         renderer: Optional[
@@ -503,7 +539,7 @@ class Profile:
 
         Parameters
         ----------
-        id : int
+        id : int or str
             The UDF ID whose profiling results should be rendered.
         type : str, optional
             The profiler type to render results for, which can be either "perf" or "memory".
@@ -550,7 +586,7 @@ class Profile:
         if result is not None:
             return render(result)  # type:ignore[arg-type]
 
-    def clear(self, id: Optional[int] = None, *, type: Optional[str] = None) -> None:
+    def clear(self, id: Optional[Union[int, str]] = None, *, type: Optional[str] = None) -> None:
         """
         Clear the profile results.
 
@@ -558,7 +594,7 @@ class Profile:
 
         Parameters
         ----------
-        id : int, optional
+        id : int or str, optional
             The UDF ID whose profiling results should be cleared.
             If not specified, all the results will be cleared.
         type : str, optional
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index a680fbdb8ef0..eefabb3e7ea0 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import contextlib
+import io
 import os
 import platform
 import tempfile
@@ -26,6 +28,7 @@ from decimal import Decimal
 from typing import Callable, Iterable, List, Union, Iterator, Tuple
 
 from pyspark.errors import AnalysisException, PythonException
+from pyspark.profiler import has_memory_profiler
 from pyspark.sql.datasource import (
     CaseInsensitiveDict,
     DataSource,
@@ -1260,6 +1263,47 @@ class BasePythonDataSourceTestsMixin:
                     ],
                 )
 
+    def test_data_source_perf_profiler(self):
+        with self.sql_conf({"spark.sql.pyspark.dataSource.profiler": "perf"}):
+            self.test_custom_json_data_source_read()
+            with contextlib.redirect_stdout(io.StringIO()) as stdout_io:
+                self.spark.profile.show(type="perf")
+            self.spark.profile.clear()
+            stdout = stdout_io.getvalue()
+            self.assertIn("Profile of create_data_source", stdout)
+            self.assertIn("Profile of plan_data_source_read", stdout)
+            self.assertIn("ncalls", stdout)
+            self.assertIn("tottime", stdout)
+            # We should also find UDF profile results for data source read
+            self.assertIn("UDF<id=", stdout)
+
+    @unittest.skipIf(
+        "COVERAGE_PROCESS_START" in os.environ, "Fails with coverage enabled, 
skipping for now."
+    )
+    @unittest.skipIf(not has_memory_profiler, "Must have memory-profiler 
installed.")
+    def test_data_source_memory_profiler(self):
+        with self.sql_conf({"spark.sql.pyspark.dataSource.profiler": "memory"}):
+            self.test_custom_json_data_source_read()
+            with contextlib.redirect_stdout(io.StringIO()) as stdout_io:
+                self.spark.profile.show(type="memory")
+            self.spark.profile.clear()
+            stdout = stdout_io.getvalue()
+            self.assertIn("Profile of create_data_source", stdout)
+            self.assertIn("Profile of plan_data_source_read", stdout)
+            self.assertIn("Mem usage", stdout)
+            # We should also find UDF profile results for data source read
+            self.assertIn("UDF<id=", stdout)
+
+    def test_data_source_read_with_udf_perf_profiler(self):
+        """udf profiler config should not enable data source profiling"""
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            self.test_custom_json_data_source_read()
+            with contextlib.redirect_stdout(io.StringIO()) as stdout_io:
+                self.spark.profile.show(type="perf")
+            self.spark.profile.clear()
+            stdout = stdout_io.getvalue()
+            self.assertEqual(stdout, "")
+
 
 class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
     ...
diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py
index 4389559c40d8..759aa9c68363 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -28,7 +28,6 @@ from typing import Iterator
 from pyspark import SparkConf
 from pyspark.errors import PySparkValueError
 from pyspark.sql import SparkSession
-from pyspark.sql.datasource import DataSource, DataSourceReader
 from pyspark.sql.functions import col, arrow_udf, pandas_udf, udf
 from pyspark.sql.window import Window
 from pyspark.profiler import UDFBasicProfiler
@@ -663,35 +662,6 @@ class UDFProfiler2TestsMixin:
         for id in self.profile_results:
             self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
 
-    @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
-    def test_perf_profiler_data_source(self):
-        class TestDataSourceReader(DataSourceReader):
-            def __init__(self, schema):
-                self.schema = schema
-
-            def partitions(self):
-                raise NotImplementedError
-
-            def read(self, partition):
-                yield from ((1,), (2,), (3,))
-
-        class TestDataSource(DataSource):
-            def schema(self):
-                return "id long"
-
-            def reader(self, schema) -> "DataSourceReader":
-                return TestDataSourceReader(schema)
-
-        self.spark.dataSource.register(TestDataSource)
-
-        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
-            self.spark.read.format("TestDataSource").load().collect()
-
-        self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
-
-        for id in self.profile_results:
-            self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=4)
-
     def test_perf_profiler_render(self):
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
             _do_computation(self.spark)
diff --git a/python/pyspark/sql/worker/utils.py b/python/pyspark/sql/worker/utils.py
index bd5c6ffda9ee..8a99abe3e4e9 100644
--- a/python/pyspark/sql/worker/utils.py
+++ b/python/pyspark/sql/worker/utils.py
@@ -17,9 +17,14 @@
 
 import os
 import sys
-from typing import Callable, IO
+from typing import Callable, IO, Optional
 
-from pyspark.accumulators import _accumulatorRegistry
+from pyspark.accumulators import (
+    _accumulatorRegistry,
+    _deserialize_accumulator,
+    SpecialAccumulatorIds,
+)
+from pyspark.sql.profiler import ProfileResultsParam, WorkerPerfProfiler, WorkerMemoryProfiler
 from pyspark.serializers import (
     read_int,
     write_int,
@@ -36,9 +41,16 @@ from pyspark.worker_util import (
     setup_memory_limits,
     setup_spark_files,
     setup_broadcasts,
+    Conf,
 )
 
 
+class RunnerConf(Conf):
+    @property
+    def profiler(self) -> Optional[str]:
+        return self.get("spark.sql.pyspark.dataSource.profiler", None)
+
+
 @with_faulthandler
 def worker_run(main: Callable, infile: IO, outfile: IO) -> None:
     try:
@@ -51,10 +63,22 @@ def worker_run(main: Callable, infile: IO, outfile: IO) -> None:
 
         setup_spark_files(infile)
         setup_broadcasts(infile)
+        conf = RunnerConf(infile)
 
         _accumulatorRegistry.clear()
+        accumulator = _deserialize_accumulator(
+            SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
+        )
 
-        main(infile, outfile)
+        worker_module = main.__module__.split(".")[-1]
+        if conf.profiler == "perf":
+            with WorkerPerfProfiler(accumulator, worker_module):
+                main(infile, outfile)
+        elif conf.profiler == "memory":
+            with WorkerMemoryProfiler(accumulator, worker_module, main):
+                main(infile, outfile)
+        else:
+            main(infile, outfile)
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 03bc1366e875..238868fbced1 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -102,6 +102,7 @@ from pyspark.worker_util import (
     setup_memory_limits,
     setup_spark_files,
     utf8_deserializer,
+    Conf,
 )
 from pyspark.logger.worker_io import capture_outputs
 
@@ -113,37 +114,7 @@ except Exception:
     has_memory_profiler = False
 
 
-class RunnerConf:
-    def __init__(self, infile=None):
-        self._conf = {}
-        if infile is not None:
-            self.load(infile)
-
-    def load(self, infile):
-        num_conf = read_int(infile)
-        # We do a sanity check here to reduce the possibility to stuck indefinitely
-        # due to an invalid messsage. If the numer of configurations is obviously
-        # wrong, we just raise an error directly.
-        # We hand-pick the configurations to send to the worker so the number should
-        # be very small (less than 100).
-        if num_conf < 0 or num_conf > 10000:
-            raise PySparkRuntimeError(
-                errorClass="PROTOCOL_ERROR",
-                messageParameters={
-                    "failure": f"Invalid number of configurations: {num_conf}",
-                },
-            )
-        for _ in range(num_conf):
-            k = utf8_deserializer.loads(infile)
-            v = utf8_deserializer.loads(infile)
-            self._conf[k] = v
-
-    def get(self, key: str, default=""):
-        val = self._conf.get(key, default)
-        if isinstance(val, str):
-            return val.lower()
-        return val
-
+class RunnerConf(Conf):
     @property
     def assign_cols_by_name(self) -> bool:
         return (
@@ -201,9 +172,13 @@ class RunnerConf:
         return int(self.get("spark.sql.execution.pythonUDF.arrow.concurrency.level", -1))
 
     @property
-    def profiler(self) -> Optional[str]:
+    def udf_profiler(self) -> Optional[str]:
         return self.get("spark.sql.pyspark.udf.profiler", None)
 
+    @property
+    def data_source_profiler(self) -> Optional[str]:
+        return self.get("spark.sql.pyspark.dataSource.profiler", None)
+
 
 def report_times(outfile, boot, init, finish, processing_time_ms):
     write_int(SpecialLengths.TIMING_DATA, outfile)
@@ -1424,7 +1399,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
 
     result_id = read_long(infile)
 
-    profiler = runner_conf.profiler
+    # If chained_func is from pyspark.sql.worker, it is to read/write data source.
+    # In this case, we check the data_source_profiler config.
+    if getattr(chained_func, "__module__", "").startswith("pyspark.sql.worker."):
+        profiler = runner_conf.data_source_profiler
+    else:
+        profiler = runner_conf.udf_profiler
     if profiler == "perf":
         profiling_func = wrap_perf_profiler(chained_func, eval_type, result_id)
     elif profiler == "memory":
diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py
index ccca0630d2af..6c6db7e2a53d 100644
--- a/python/pyspark/worker_util.py
+++ b/python/pyspark/worker_util.py
@@ -22,7 +22,7 @@ import importlib
 from inspect import currentframe, getframeinfo
 import os
 import sys
-from typing import Any, IO
+from typing import Any, IO, Optional
 import warnings
 
 # 'resource' is a Unix specific module.
@@ -194,3 +194,35 @@ def send_accumulator_updates(outfile: IO) -> None:
     write_int(len(_accumulatorRegistry), outfile)
     for aid, accum in _accumulatorRegistry.items():
         pickleSer._write_with_length((aid, accum._value), outfile)
+
+
+class Conf:
+    def __init__(self, infile: Optional[IO] = None) -> None:
+        self._conf: dict[str, Any] = {}
+        if infile is not None:
+            self.load(infile)
+
+    def load(self, infile: IO) -> None:
+        num_conf = read_int(infile)
+        # We do a sanity check here to reduce the possibility to stuck indefinitely
+        # due to an invalid messsage. If the numer of configurations is obviously
+        # wrong, we just raise an error directly.
+        # We hand-pick the configurations to send to the worker so the number should
+        # be very small (less than 100).
+        if num_conf < 0 or num_conf > 10000:
+            raise PySparkRuntimeError(
+                errorClass="PROTOCOL_ERROR",
+                messageParameters={
+                    "failure": f"Invalid number of configurations: {num_conf}",
+                },
+            )
+        for _ in range(num_conf):
+            k = utf8_deserializer.loads(infile)
+            v = utf8_deserializer.loads(infile)
+            self._conf[k] = v
+
+    def get(self, key: str, default: Any = "") -> Any:
+        val = self._conf.get(key, default)
+        if isinstance(val, str):
+            return val.lower()
+        return val
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index cd892936c9b7..1bed0b3bebc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4086,6 +4086,17 @@ object SQLConf {
       // show full stacktrace in tests but hide in production by default.
       .createWithDefault(Utils.isTesting)
 
+  val PYTHON_DATA_SOURCE_PROFILER =
+    buildConf("spark.sql.pyspark.dataSource.profiler")
+      .doc("Configure the Python Data Source profiler by enabling or disabling 
it " +
+        "with the option to choose between \"perf\" and \"memory\" types, " +
+        "or unsetting the config disables the profiler. This is disabled by 
default.")
+      .version("4.2.0")
+      .stringConf
+      .transform(_.toLowerCase(Locale.ROOT))
+      .checkValues(Set("perf", "memory"))
+      .createOptional
+
   val PYTHON_UDF_PROFILER =
     buildConf("spark.sql.pyspark.udf.profiler")
       .doc("Configure the Python/Pandas UDF profiler by enabling or disabling 
it " +
@@ -7719,6 +7730,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def pythonUDFProfiler: Option[String] = getConf(PYTHON_UDF_PROFILER)
 
+  def pythonDataSourceProfiler: Option[String] = getConf(PYTHON_DATA_SOURCE_PROFILER)
+
   def pythonUDFWorkerFaulthandlerEnabled: Boolean = getConf(PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED)
 
   def pythonUDFWorkerIdleTimeoutSeconds: Long = getConf(PYTHON_UDF_WORKER_IDLE_TIMEOUT_SECONDS)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala
index b04ebe92910a..3d814606f531 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala
@@ -25,6 +25,7 @@ import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, SpecialLe
 import org.apache.spark.sql.connector.write.WriterCommitMessage
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.python.PythonPlannerRunner
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -43,6 +44,12 @@ class PythonStreamingSinkCommitRunner(
     abort: Boolean) extends PythonPlannerRunner[Unit](dataSourceCls) {
   override val workerModule: String = "pyspark.sql.worker.python_streaming_sink_runner"
 
+  override protected def runnerConf: Map[String, String] = {
+    super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+      Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+    ).getOrElse(Map.empty)
+  }
+
   override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
     // Send the user function to python process.
     PythonWorkerUtils.writePythonFunction(dataSourceCls, dataOut)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index c147030037cd..7611cf676499 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -232,6 +232,12 @@ private class UserDefinedPythonDataSourceLookupRunner(lookupSources: PythonFunct
 
   override val workerModule = "pyspark.sql.worker.lookup_data_sources"
 
+  override protected def runnerConf: Map[String, String] = {
+    super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+      Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+    ).getOrElse(Map.empty)
+  }
+
   override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
     // No input needed.
   }
@@ -282,6 +288,12 @@ private class UserDefinedPythonDataSourceRunner(
 
   override val workerModule = "pyspark.sql.worker.create_data_source"
 
+  override protected def runnerConf: Map[String, String] = {
+    super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+      Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+    ).getOrElse(Map.empty)
+  }
+
   override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
     // Send python data source
     PythonWorkerUtils.writePythonFunction(dataSourceCls, dataOut)
@@ -446,6 +458,12 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
   // See the logic in `pyspark.sql.worker.data_source_pushdown_filters.py`.
   override val workerModule = "pyspark.sql.worker.data_source_pushdown_filters"
 
+  override protected def runnerConf: Map[String, String] = {
+    super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+      Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+    ).getOrElse(Map.empty)
+  }
+
   def isAnyFilterSupported: Boolean = serializedFilters.nonEmpty
 
   override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
@@ -538,6 +556,12 @@ private class UserDefinedPythonDataSourceReadRunner(
   // See the logic in `pyspark.sql.worker.plan_data_source_read.py`.
   override val workerModule = "pyspark.sql.worker.plan_data_source_read"
 
+  override protected def runnerConf: Map[String, String] = {
+    super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+      Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+    ).getOrElse(Map.empty)
+  }
+
   override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
     // Send Python data source
     PythonWorkerUtils.writePythonFunction(func, dataOut)
@@ -580,6 +604,12 @@ private class UserDefinedPythonDataSourceWriteRunner(
 
   override val workerModule: String = "pyspark.sql.worker.write_into_data_source"
 
+  override protected def runnerConf: Map[String, String] = {
+    super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+      Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+    ).getOrElse(Map.empty)
+  }
+
   override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
     // Send the Python data source class.
     PythonWorkerUtils.writePythonFunction(dataSourceCls, dataOut)
@@ -639,6 +669,12 @@ private class UserDefinedPythonDataSourceCommitRunner(
     abort: Boolean) extends PythonPlannerRunner[Unit](dataSourceCls) {
   override val workerModule: String = "pyspark.sql.worker.commit_data_source_write"
 
+  override protected def runnerConf: Map[String, String] = {
+    super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+      Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+    ).getOrElse(Map.empty)
+  }
+
   override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
     // Send the Python data source writer.
     PythonWorkerUtils.writeBytes(writer, dataOut)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 7a12dbd556bf..a5536621c531 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -168,12 +168,16 @@ object ArrowPythonRunner {
     val binaryAsBytes = Seq(
       SQLConf.PYSPARK_BINARY_AS_BYTES.key ->
       conf.pysparkBinaryAsBytes.toString)
-    val profiler = conf.pythonUDFProfiler.map(p =>
+    val udfProfiler = conf.pythonUDFProfiler.map(p =>
       Seq(SQLConf.PYTHON_UDF_PROFILER.key -> p)
     ).getOrElse(Seq.empty)
+    val dataSourceProfiler = conf.pythonDataSourceProfiler.map(p =>
+      Seq(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+    ).getOrElse(Seq.empty)
     Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++
       arrowAyncParallelism ++ useLargeVarTypes ++
       intToDecimalCoercion ++ binaryAsBytes ++
-      legacyPandasConversion ++ legacyPandasConversionUDF ++ profiler: _*)
+      legacyPandasConversion ++ legacyPandasConversionUDF ++
+      udfProfiler ++ dataSourceProfiler: _*)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
index 92e99cdc11d9..dd1a869ddf7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
@@ -44,6 +44,8 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) extends Logging {
 
   protected val workerModule: String
 
+  protected def runnerConf: Map[String, String] = Map.empty
+
   protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit
 
   protected def receiveFromPython(dataIn: DataInputStream): T
@@ -123,6 +125,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) extends Logging {
       PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
       PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
       PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
+      PythonWorkerUtils.writeConf(runnerConf, dataOut)
 
       writeToPython(dataOut, pickler)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

