This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 22094fef57b2 [SPARK-55161][PYTHON] Support profilers on python data source
22094fef57b2 is described below
commit 22094fef57b2a664495de710315707c14fa2b215
Author: Tian Gao <[email protected]>
AuthorDate: Tue Feb 3 08:18:16 2026 +0900
[SPARK-55161][PYTHON] Support profilers on python data source
### What changes were proposed in this pull request?
Make it possible to enable the perf/memory profiler on Python data sources with the new
`spark.sql.pyspark.dataSource.profiler` config.
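For example (a minimal usage sketch; the data source name `myjson` and the input path are placeholders, not part of this patch):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.pyspark.dataSource.profiler", "perf")  # or "memory"

# Read through a registered Python data source ("myjson" is hypothetical).
spark.read.format("myjson").load("/tmp/people.json").show()

# Planner stages are keyed by worker module name (e.g. "create_data_source",
# "plan_data_source_read"); executor-side reads still appear as UDF<id=...>.
spark.profile.show(type="perf")
spark.profile.show(id="plan_data_source_read", type="perf")
```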
### Why are the changes needed?
Python data sources need the same profiling capabilities that UDFs already have.
### Does this PR introduce _any_ user-facing change?
Yes. A new config, `spark.sql.pyspark.dataSource.profiler`, is introduced.
Also note that the UDF profiler config used to incidentally capture data source
reads/writes, because they are executed through UDFs; it no longer does.
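For illustration (a hedged sketch, reusing the placeholder `myjson` source from above):

```python
spark.conf.set("spark.sql.pyspark.udf.profiler", "perf")
spark.read.format("myjson").load("/tmp/people.json").collect()
spark.profile.show(type="perf")  # prints no data source profiles anymore;
                                 # set spark.sql.pyspark.dataSource.profiler instead
```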
### How was this patch tested?
3 new tests were added and they pass.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53945 from gaogaotiantian/datasource-profiler.
Authored-by: Tian Gao <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/profiler.py | 12 +++
python/pyspark/sql/_typing.pyi | 2 +-
python/pyspark/sql/profiler.py | 118 ++++++++++++++-------
python/pyspark/sql/tests/test_python_datasource.py | 44 ++++++++
python/pyspark/sql/tests/test_udf_profiler.py | 30 ------
python/pyspark/sql/worker/utils.py | 30 +++++-
python/pyspark/worker.py | 46 +++-----
python/pyspark/worker_util.py | 34 +++++-
.../org/apache/spark/sql/internal/SQLConf.scala | 13 +++
.../python/PythonStreamingSinkCommitRunner.scala | 7 ++
.../v2/python/UserDefinedPythonDataSource.scala | 36 +++++++
.../sql/execution/python/ArrowPythonRunner.scala | 8 +-
.../sql/execution/python/PythonPlannerRunner.scala | 3 +
13 files changed, 272 insertions(+), 111 deletions(-)
diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
index 1e9e398b25f8..200c14870348 100644
--- a/python/pyspark/profiler.py
+++ b/python/pyspark/profiler.py
@@ -45,6 +45,7 @@ try:
except Exception:
has_memory_profiler = False
+import pyspark
from pyspark.accumulators import AccumulatorParam
from pyspark.errors import PySparkRuntimeError, PySparkValueError
@@ -477,6 +478,17 @@ class MemoryProfiler(Profiler):
stream.write(header + "\n")
stream.write("=" * len(header) + "\n")
+ if "pyspark.zip/pyspark/" in filename:
+ # if the original filename is in pyspark.zip file, we try to
find the actual
+ # file in pyspark module by concatenating pyspark module
directory and the
+ # rest of the filename
+ # Eventually we should ask the data provider to provide the
actual lines
+ # because there's no guarantee that we can always find the
actual file
+ # on driver side
+ filename = os.path.join(
+ os.path.dirname(pyspark.__file__),
filename.rsplit("pyspark.zip/pyspark/", 1)[1]
+ )
+
all_lines = linecache.getlines(filename)
if len(all_lines) == 0:
raise PySparkValueError(
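A standalone sketch of the path remapping above, with a hypothetical worker-side filename:

```python
import os
import pyspark

# A filename as recorded on an executor that ran from pyspark.zip (hypothetical).
filename = "/tmp/spark-app/pyspark.zip/pyspark/sql/datasource.py"

if "pyspark.zip/pyspark/" in filename:
    # Map it onto the driver-local pyspark installation so linecache can read it.
    filename = os.path.join(
        os.path.dirname(pyspark.__file__), filename.rsplit("pyspark.zip/pyspark/", 1)[1]
    )
print(filename)  # e.g. <site-packages>/pyspark/sql/datasource.py
```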
diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi
index fafc9bdf15fc..cf0a1d522118 100644
--- a/python/pyspark/sql/_typing.pyi
+++ b/python/pyspark/sql/_typing.pyi
@@ -85,4 +85,4 @@ class UserDefinedFunctionLike(Protocol):
def __call__(self, *args: ColumnOrName) -> Column: ...
def asNondeterministic(self) -> UserDefinedFunctionLike: ...
-ProfileResults = Dict[int, Tuple[Optional[pstats.Stats], Optional[CodeMapDict]]]
+ProfileResults = Dict[Union[int, str], Tuple[Optional[pstats.Stats], Optional[CodeMapDict]]]
diff --git a/python/pyspark/sql/profiler.py b/python/pyspark/sql/profiler.py
index 2aee60eeb41d..87efff44363f 100644
--- a/python/pyspark/sql/profiler.py
+++ b/python/pyspark/sql/profiler.py
@@ -21,7 +21,18 @@ import os
import pstats
from threading import RLock
from types import CodeType, TracebackType
-from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union, TYPE_CHECKING, overload
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Literal,
+ Optional,
+ Tuple,
+ Union,
+ TYPE_CHECKING,
+ overload,
+)
import warnings
from pyspark.accumulators import (
@@ -82,10 +93,12 @@ class WorkerPerfProfiler:
PerfProfiler is a profiler for performance profiling.
"""
- def __init__(self, accumulator: Accumulator["ProfileResults"], result_id: int) -> None:
+ def __init__(
+ self, accumulator: Accumulator[Optional["ProfileResults"]], result_key: Union[int, str]
+ ) -> None:
self._accumulator = accumulator
self._profiler = cProfile.Profile()
- self._result_id = result_id
+ self._result_key = result_key
def start(self) -> None:
self._profiler.enable()
@@ -98,7 +111,7 @@ class WorkerPerfProfiler:
# make it picklable
st.stream = None # type: ignore[attr-defined]
st.strip_dirs()
- self._accumulator.add({self._result_id: (st, None)})
+ self._accumulator.add({self._result_key: (st, None)})
def __enter__(self) -> "WorkerPerfProfiler":
self.start()
@@ -121,8 +134,8 @@ class WorkerMemoryProfiler:
def __init__(
self,
- accumulator: Accumulator["ProfileResults"],
- result_id: int,
+ accumulator: Accumulator[Optional["ProfileResults"]],
+ result_key: Union[int, str],
func_or_code: Union[Callable, CodeType],
) -> None:
from pyspark.profiler import UDFLineProfilerV2
@@ -133,7 +146,7 @@ class WorkerMemoryProfiler:
self._profiler.add_code(func_or_code)
else:
self._profiler.add_function(func_or_code)
- self._result_id = result_id
+ self._result_key = result_key
def start(self) -> None:
self._profiler.enable_by_count()
@@ -146,7 +159,7 @@ class WorkerMemoryProfiler:
filename: list(line_iterator)
for filename, line_iterator in self._profiler.code_map.items()
}
- self._accumulator.add({self._result_id: (None, codemap_dict)})
+ self._accumulator.add({self._result_key: (None, codemap_dict)})
def __enter__(self) -> "WorkerMemoryProfiler":
self.start()
@@ -173,7 +186,12 @@ class ProfilerCollector(ABC):
def __init__(self) -> None:
self._lock = RLock()
- def show_perf_profiles(self, id: Optional[int] = None) -> None:
+ def _sorted_keys(self, keys: Iterable[Union[int, str]]) -> list[Union[int, str]]:
+ int_keys = sorted(x for x in keys if isinstance(x, int))
+ str_keys = sorted(x for x in keys if isinstance(x, str))
+ return str_keys + int_keys
+
+ def show_perf_profiles(self, id: Optional[Union[int, str]] = None) -> None:
"""
Show the perf profile results.
@@ -187,22 +205,25 @@ class ProfilerCollector(ABC):
with self._lock:
stats = self._perf_profile_results
- def show(id: int) -> None:
+ def show(id: Union[int, str]) -> None:
s = stats.get(id)
if s is not None:
print("=" * 60)
- print(f"Profile of UDF<id={id}>")
+ if isinstance(id, str):
+ print(f"Profile of {id}")
+ else:
+ print(f"Profile of UDF<id={id}>")
print("=" * 60)
s.sort_stats("time", "cumulative").print_stats()
if id is not None:
show(id)
else:
- for id in sorted(stats.keys()):
+ for id in self._sorted_keys(stats.keys()):
show(id)
@property
- def _perf_profile_results(self) -> Dict[int, pstats.Stats]:
+ def _perf_profile_results(self) -> Dict[Union[int, str], pstats.Stats]:
with self._lock:
return {
result_id: perf
@@ -210,7 +231,7 @@ class ProfilerCollector(ABC):
if perf is not None
}
- def show_memory_profiles(self, id: Optional[int] = None) -> None:
+ def show_memory_profiles(self, id: Optional[Union[int, str]] = None) -> None:
"""
Show the memory profile results.
@@ -230,22 +251,25 @@ class ProfilerCollector(ABC):
UserWarning,
)
- def show(id: int) -> None:
+ def show(id: Union[int, str]) -> None:
cm = code_map.get(id)
if cm is not None:
print("=" * 60)
- print(f"Profile of UDF<id={id}>")
+ if isinstance(id, str):
+ print(f"Profile of {id}")
+ else:
+ print(f"Profile of UDF<id={id}>")
print("=" * 60)
MemoryProfiler._show_results(cm)
if id is not None:
show(id)
else:
- for id in sorted(code_map.keys()):
+ for id in self._sorted_keys(code_map.keys()):
show(id)
@property
- def _memory_profile_results(self) -> Dict[int, CodeMapDict]:
+ def _memory_profile_results(self) -> Dict[Union[int, str], CodeMapDict]:
with self._lock:
return {
result_id: mem
@@ -261,7 +285,7 @@ class ProfilerCollector(ABC):
"""
...
- def dump_perf_profiles(self, path: str, id: Optional[int] = None) -> None:
+ def dump_perf_profiles(self, path: str, id: Optional[Union[int, str]] = None) -> None:
"""
Dump the perf profile results into directory `path`.
@@ -271,13 +295,13 @@ class ProfilerCollector(ABC):
----------
path: str
A directory in which to dump the perf profile.
- id : int, optional
+ id : int or str, optional
A UDF ID to be shown. If not specified, all the results will be shown.
"""
with self._lock:
stats = self._perf_profile_results
- def dump(id: int) -> None:
+ def dump(id: Union[int, str]) -> None:
s = stats.get(id)
if s is not None:
@@ -288,10 +312,10 @@ class ProfilerCollector(ABC):
if id is not None:
dump(id)
else:
- for id in sorted(stats.keys()):
+ for id in self._sorted_keys(stats.keys()):
dump(id)
- def dump_memory_profiles(self, path: str, id: Optional[int] = None) -> None:
+ def dump_memory_profiles(self, path: str, id: Optional[Union[int, str]] = None) -> None:
"""
Dump the memory profile results into directory `path`.
@@ -301,7 +325,7 @@ class ProfilerCollector(ABC):
----------
path: str
A directory in which to dump the memory profile.
- id : int, optional
+ id : int or str, optional
A UDF ID to be shown. If not specified, all the results will be shown.
"""
with self._lock:
@@ -313,7 +337,7 @@ class ProfilerCollector(ABC):
UserWarning,
)
- def dump(id: int) -> None:
+ def dump(id: Union[int, str]) -> None:
cm = code_map.get(id)
if cm is not None:
@@ -326,10 +350,10 @@ class ProfilerCollector(ABC):
if id is not None:
dump(id)
else:
- for id in sorted(code_map.keys()):
+ for id in self._sorted_keys(code_map.keys()):
dump(id)
- def clear_perf_profiles(self, id: Optional[int] = None) -> None:
+ def clear_perf_profiles(self, id: Optional[Union[int, str]] = None) -> None:
"""
Clear the perf profile results.
@@ -337,7 +361,7 @@ class ProfilerCollector(ABC):
Parameters
----------
- id : int, optional
+ id : int or str, optional
The UDF ID whose profiling results should be cleared.
If not specified, all the results will be cleared.
"""
@@ -354,7 +378,7 @@ class ProfilerCollector(ABC):
if mem is None:
self._profile_results.pop(id, None)
- def clear_memory_profiles(self, id: Optional[int] = None) -> None:
+ def clear_memory_profiles(self, id: Optional[Union[int, str]] = None) -> None:
"""
Clear the memory profile results.
@@ -362,7 +386,7 @@ class ProfilerCollector(ABC):
Parameters
----------
- id : int, optional
+ id : int or str, optional
The UDF ID whose profiling results should be cleared.
If not specified, all the results will be cleared.
"""
@@ -407,7 +431,7 @@ class Profile:
def __init__(self, profiler_collector: ProfilerCollector):
self.profiler_collector = profiler_collector
- def show(self, id: Optional[int] = None, *, type: Optional[str] = None) -> None:
+ def show(self, id: Optional[Union[int, str]] = None, *, type: Optional[str] = None) -> None:
"""
Show the profile results.
@@ -415,7 +439,7 @@ class Profile:
Parameters
----------
- id : int, optional
+ id : int or str, optional
A UDF ID to be shown. If not specified, all the results will be shown.
type : str, optional
The profiler type, which can be either "perf" or "memory".
@@ -441,7 +465,9 @@ class Profile:
},
)
- def dump(self, path: str, id: Optional[int] = None, *, type: Optional[str] = None) -> None:
+ def dump(
+ self, path: str, id: Optional[Union[int, str]] = None, *, type: Optional[str] = None
+ ) -> None:
"""
Dump the profile results into directory `path`.
@@ -451,7 +477,7 @@ class Profile:
----------
path: str
A directory in which to dump the profile.
- id : int, optional
+ id : int or str, optional
A UDF ID to be shown. If not specified, all the results will be shown.
type : str, optional
The profiler type, which can be either "perf" or "memory".
@@ -472,24 +498,34 @@ class Profile:
)
@overload
- def render(self, id: int, *, type: Optional[str] = None, renderer: Optional[str] = None) -> Any:
+ def render(
+ self, id: Union[int, str], *, type: Optional[str] = None, renderer: Optional[str] = None
+ ) -> Any:
...
@overload
def render(
- self, id: int, *, type: Optional[Literal["perf"]], renderer: Callable[[pstats.Stats], Any]
+ self,
+ id: Union[int, str],
+ *,
+ type: Optional[Literal["perf"]],
+ renderer: Callable[[pstats.Stats], Any],
) -> Any:
...
@overload
def render(
- self, id: int, *, type: Literal["memory"], renderer:
Callable[[CodeMapDict], Any]
+ self,
+ id: Union[int, str],
+ *,
+ type: Literal["memory"],
+ renderer: Callable[[CodeMapDict], Any],
) -> Any:
...
def render(
self,
- id: int,
+ id: Union[int, str],
*,
type: Optional[str] = None,
renderer: Optional[
@@ -503,7 +539,7 @@ class Profile:
Parameters
----------
- id : int
+ id : int or str
The UDF ID whose profiling results should be rendered.
type : str, optional
The profiler type to render results for, which can be either "perf" or "memory".
@@ -550,7 +586,7 @@ class Profile:
if result is not None:
return render(result) # type:ignore[arg-type]
- def clear(self, id: Optional[int] = None, *, type: Optional[str] = None) -> None:
+ def clear(self, id: Optional[Union[int, str]] = None, *, type: Optional[str] = None) -> None:
"""
Clear the profile results.
@@ -558,7 +594,7 @@ class Profile:
Parameters
----------
- id : int, optional
+ id : int or str, optional
The UDF ID whose profiling results should be cleared.
If not specified, all the results will be cleared.
type : str, optional
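A standalone sketch of the new key handling in this file, with hypothetical keys: profile results may now be keyed by a UDF id (int) or a planner stage name (str), and string keys are listed first:

```python
from typing import Iterable, List, Union

def sorted_keys(keys: Iterable[Union[int, str]]) -> List[Union[int, str]]:
    # Mirrors ProfilerCollector._sorted_keys: stage names (str) first, then UDF ids (int).
    int_keys = sorted(k for k in keys if isinstance(k, int))
    str_keys = sorted(k for k in keys if isinstance(k, str))
    return str_keys + int_keys

print(sorted_keys([7, "plan_data_source_read", 2, "create_data_source"]))
# ['create_data_source', 'plan_data_source_read', 2, 7]
```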
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index a680fbdb8ef0..eefabb3e7ea0 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import contextlib
+import io
import os
import platform
import tempfile
@@ -26,6 +28,7 @@ from decimal import Decimal
from typing import Callable, Iterable, List, Union, Iterator, Tuple
from pyspark.errors import AnalysisException, PythonException
+from pyspark.profiler import has_memory_profiler
from pyspark.sql.datasource import (
CaseInsensitiveDict,
DataSource,
@@ -1260,6 +1263,47 @@ class BasePythonDataSourceTestsMixin:
],
)
+ def test_data_source_perf_profiler(self):
+ with self.sql_conf({"spark.sql.pyspark.dataSource.profiler": "perf"}):
+ self.test_custom_json_data_source_read()
+ with contextlib.redirect_stdout(io.StringIO()) as stdout_io:
+ self.spark.profile.show(type="perf")
+ self.spark.profile.clear()
+ stdout = stdout_io.getvalue()
+ self.assertIn("Profile of create_data_source", stdout)
+ self.assertIn("Profile of plan_data_source_read", stdout)
+ self.assertIn("ncalls", stdout)
+ self.assertIn("tottime", stdout)
+ # We should also find UDF profile results for data source read
+ self.assertIn("UDF<id=", stdout)
+
+ @unittest.skipIf(
+ "COVERAGE_PROCESS_START" in os.environ, "Fails with coverage enabled,
skipping for now."
+ )
+ @unittest.skipIf(not has_memory_profiler, "Must have memory-profiler
installed.")
+ def test_data_source_memory_profiler(self):
+ with self.sql_conf({"spark.sql.pyspark.dataSource.profiler":
"memory"}):
+ self.test_custom_json_data_source_read()
+ with contextlib.redirect_stdout(io.StringIO()) as stdout_io:
+ self.spark.profile.show(type="memory")
+ self.spark.profile.clear()
+ stdout = stdout_io.getvalue()
+ self.assertIn("Profile of create_data_source", stdout)
+ self.assertIn("Profile of plan_data_source_read", stdout)
+ self.assertIn("Mem usage", stdout)
+ # We should also find UDF profile results for data source read
+ self.assertIn("UDF<id=", stdout)
+
+ def test_data_source_read_with_udf_perf_profiler(self):
+ """udf profiler config should not enable data source profiling"""
+ with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+ self.test_custom_json_data_source_read()
+ with contextlib.redirect_stdout(io.StringIO()) as stdout_io:
+ self.spark.profile.show(type="perf")
+ self.spark.profile.clear()
+ stdout = stdout_io.getvalue()
+ self.assertEqual(stdout, "")
+
class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
...
diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py
index 4389559c40d8..759aa9c68363 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -28,7 +28,6 @@ from typing import Iterator
from pyspark import SparkConf
from pyspark.errors import PySparkValueError
from pyspark.sql import SparkSession
-from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.functions import col, arrow_udf, pandas_udf, udf
from pyspark.sql.window import Window
from pyspark.profiler import UDFBasicProfiler
@@ -663,35 +662,6 @@ class UDFProfiler2TestsMixin:
for id in self.profile_results:
self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
- @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
- def test_perf_profiler_data_source(self):
- class TestDataSourceReader(DataSourceReader):
- def __init__(self, schema):
- self.schema = schema
-
- def partitions(self):
- raise NotImplementedError
-
- def read(self, partition):
- yield from ((1,), (2,), (3,))
-
- class TestDataSource(DataSource):
- def schema(self):
- return "id long"
-
- def reader(self, schema) -> "DataSourceReader":
- return TestDataSourceReader(schema)
-
- self.spark.dataSource.register(TestDataSource)
-
- with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
- self.spark.read.format("TestDataSource").load().collect()
-
- self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
-
- for id in self.profile_results:
- self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=4)
-
def test_perf_profiler_render(self):
with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
_do_computation(self.spark)
diff --git a/python/pyspark/sql/worker/utils.py b/python/pyspark/sql/worker/utils.py
index bd5c6ffda9ee..8a99abe3e4e9 100644
--- a/python/pyspark/sql/worker/utils.py
+++ b/python/pyspark/sql/worker/utils.py
@@ -17,9 +17,14 @@
import os
import sys
-from typing import Callable, IO
+from typing import Callable, IO, Optional
-from pyspark.accumulators import _accumulatorRegistry
+from pyspark.accumulators import (
+ _accumulatorRegistry,
+ _deserialize_accumulator,
+ SpecialAccumulatorIds,
+)
+from pyspark.sql.profiler import ProfileResultsParam, WorkerPerfProfiler, WorkerMemoryProfiler
from pyspark.serializers import (
read_int,
write_int,
@@ -36,9 +41,16 @@ from pyspark.worker_util import (
setup_memory_limits,
setup_spark_files,
setup_broadcasts,
+ Conf,
)
+class RunnerConf(Conf):
+ @property
+ def profiler(self) -> Optional[str]:
+ return self.get("spark.sql.pyspark.dataSource.profiler", None)
+
+
@with_faulthandler
def worker_run(main: Callable, infile: IO, outfile: IO) -> None:
try:
@@ -51,10 +63,22 @@ def worker_run(main: Callable, infile: IO, outfile: IO) -> None:
setup_spark_files(infile)
setup_broadcasts(infile)
+ conf = RunnerConf(infile)
_accumulatorRegistry.clear()
+ accumulator = _deserialize_accumulator(
+ SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
+ )
- main(infile, outfile)
+ worker_module = main.__module__.split(".")[-1]
+ if conf.profiler == "perf":
+ with WorkerPerfProfiler(accumulator, worker_module):
+ main(infile, outfile)
+ elif conf.profiler == "memory":
+ with WorkerMemoryProfiler(accumulator, worker_module, main):
+ main(infile, outfile)
+ else:
+ main(infile, outfile)
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
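For orientation, a self-contained sketch of the wrapping pattern that `worker_run` now applies around `main` (dummy `main`, `cProfile` in place of `WorkerPerfProfiler`, no accumulator; purely illustrative, not the actual worker API):

```python
import cProfile
import pstats

def main() -> None:
    # Stand-in for a planner entry point such as plan_data_source_read.main.
    sum(i * i for i in range(100_000))

profiler = "perf"  # would come from the spark.sql.pyspark.dataSource.profiler conf
result_key = main.__module__.split(".")[-1]  # worker module name used as the profile key

if profiler == "perf":
    prof = cProfile.Profile()
    prof.enable()
    try:
        main()
    finally:
        prof.disable()
    # The real worker adds these stats to a profile accumulator under result_key;
    # here we just print them.
    print(f"Profile of {result_key}")
    pstats.Stats(prof).strip_dirs().sort_stats("time", "cumulative").print_stats(5)
else:
    main()
```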
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 03bc1366e875..238868fbced1 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -102,6 +102,7 @@ from pyspark.worker_util import (
setup_memory_limits,
setup_spark_files,
utf8_deserializer,
+ Conf,
)
from pyspark.logger.worker_io import capture_outputs
@@ -113,37 +114,7 @@ except Exception:
has_memory_profiler = False
-class RunnerConf:
- def __init__(self, infile=None):
- self._conf = {}
- if infile is not None:
- self.load(infile)
-
- def load(self, infile):
- num_conf = read_int(infile)
- # We do a sanity check here to reduce the possibility to stuck indefinitely
- # due to an invalid messsage. If the numer of configurations is obviously
- # wrong, we just raise an error directly.
- # We hand-pick the configurations to send to the worker so the number should
- # be very small (less than 100).
- if num_conf < 0 or num_conf > 10000:
- raise PySparkRuntimeError(
- errorClass="PROTOCOL_ERROR",
- messageParameters={
- "failure": f"Invalid number of configurations: {num_conf}",
- },
- )
- for _ in range(num_conf):
- k = utf8_deserializer.loads(infile)
- v = utf8_deserializer.loads(infile)
- self._conf[k] = v
-
- def get(self, key: str, default=""):
- val = self._conf.get(key, default)
- if isinstance(val, str):
- return val.lower()
- return val
-
+class RunnerConf(Conf):
@property
def assign_cols_by_name(self) -> bool:
return (
@@ -201,9 +172,13 @@ class RunnerConf:
return int(self.get("spark.sql.execution.pythonUDF.arrow.concurrency.level", -1))
@property
- def profiler(self) -> Optional[str]:
+ def udf_profiler(self) -> Optional[str]:
return self.get("spark.sql.pyspark.udf.profiler", None)
+ @property
+ def data_source_profiler(self) -> Optional[str]:
+ return self.get("spark.sql.pyspark.dataSource.profiler", None)
+
def report_times(outfile, boot, init, finish, processing_time_ms):
write_int(SpecialLengths.TIMING_DATA, outfile)
@@ -1424,7 +1399,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
result_id = read_long(infile)
- profiler = runner_conf.profiler
+ # If chained_func is from pyspark.sql.worker, it is to read/write data source.
+ # In this case, we check the data_source_profiler config.
+ if getattr(chained_func, "__module__", "").startswith("pyspark.sql.worker."):
+ profiler = runner_conf.data_source_profiler
+ else:
+ profiler = runner_conf.udf_profiler
if profiler == "perf":
profiling_func = wrap_perf_profiler(chained_func, eval_type, result_id)
elif profiler == "memory":
diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py
index ccca0630d2af..6c6db7e2a53d 100644
--- a/python/pyspark/worker_util.py
+++ b/python/pyspark/worker_util.py
@@ -22,7 +22,7 @@ import importlib
from inspect import currentframe, getframeinfo
import os
import sys
-from typing import Any, IO
+from typing import Any, IO, Optional
import warnings
# 'resource' is a Unix specific module.
@@ -194,3 +194,35 @@ def send_accumulator_updates(outfile: IO) -> None:
write_int(len(_accumulatorRegistry), outfile)
for aid, accum in _accumulatorRegistry.items():
pickleSer._write_with_length((aid, accum._value), outfile)
+
+
+class Conf:
+ def __init__(self, infile: Optional[IO] = None) -> None:
+ self._conf: dict[str, Any] = {}
+ if infile is not None:
+ self.load(infile)
+
+ def load(self, infile: IO) -> None:
+ num_conf = read_int(infile)
+ # We do a sanity check here to reduce the possibility to stuck indefinitely
+ # due to an invalid messsage. If the numer of configurations is obviously
+ # wrong, we just raise an error directly.
+ # We hand-pick the configurations to send to the worker so the number should
+ # be very small (less than 100).
+ if num_conf < 0 or num_conf > 10000:
+ raise PySparkRuntimeError(
+ errorClass="PROTOCOL_ERROR",
+ messageParameters={
+ "failure": f"Invalid number of configurations: {num_conf}",
+ },
+ )
+ for _ in range(num_conf):
+ k = utf8_deserializer.loads(infile)
+ v = utf8_deserializer.loads(infile)
+ self._conf[k] = v
+
+ def get(self, key: str, default: Any = "") -> Any:
+ val = self._conf.get(key, default)
+ if isinstance(val, str):
+ return val.lower()
+ return val
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index cd892936c9b7..1bed0b3bebc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4086,6 +4086,17 @@ object SQLConf {
// show full stacktrace in tests but hide in production by default.
.createWithDefault(Utils.isTesting)
+ val PYTHON_DATA_SOURCE_PROFILER =
+ buildConf("spark.sql.pyspark.dataSource.profiler")
+ .doc("Configure the Python Data Source profiler by enabling or disabling
it " +
+ "with the option to choose between \"perf\" and \"memory\" types, " +
+ "or unsetting the config disables the profiler. This is disabled by
default.")
+ .version("4.2.0")
+ .stringConf
+ .transform(_.toLowerCase(Locale.ROOT))
+ .checkValues(Set("perf", "memory"))
+ .createOptional
+
val PYTHON_UDF_PROFILER =
buildConf("spark.sql.pyspark.udf.profiler")
.doc("Configure the Python/Pandas UDF profiler by enabling or disabling
it " +
@@ -7719,6 +7730,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def pythonUDFProfiler: Option[String] = getConf(PYTHON_UDF_PROFILER)
+ def pythonDataSourceProfiler: Option[String] = getConf(PYTHON_DATA_SOURCE_PROFILER)
+
def pythonUDFWorkerFaulthandlerEnabled: Boolean = getConf(PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED)
def pythonUDFWorkerIdleTimeoutSeconds: Long = getConf(PYTHON_UDF_WORKER_IDLE_TIMEOUT_SECONDS)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala
index b04ebe92910a..3d814606f531 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingSinkCommitRunner.scala
@@ -25,6 +25,7 @@ import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, SpecialLe
import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.python.PythonPlannerRunner
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
/**
@@ -43,6 +44,12 @@ class PythonStreamingSinkCommitRunner(
abort: Boolean) extends PythonPlannerRunner[Unit](dataSourceCls) {
override val workerModule: String = "pyspark.sql.worker.python_streaming_sink_runner"
+ override protected def runnerConf: Map[String, String] = {
+ super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+ Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+ ).getOrElse(Map.empty)
+ }
+
override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
// Send the user function to python process.
PythonWorkerUtils.writePythonFunction(dataSourceCls, dataOut)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index c147030037cd..7611cf676499 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -232,6 +232,12 @@ private class UserDefinedPythonDataSourceLookupRunner(lookupSources: PythonFunct
override val workerModule = "pyspark.sql.worker.lookup_data_sources"
+ override protected def runnerConf: Map[String, String] = {
+ super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+ Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+ ).getOrElse(Map.empty)
+ }
+
override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
// No input needed.
}
@@ -282,6 +288,12 @@ private class UserDefinedPythonDataSourceRunner(
override val workerModule = "pyspark.sql.worker.create_data_source"
+ override protected def runnerConf: Map[String, String] = {
+ super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+ Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+ ).getOrElse(Map.empty)
+ }
+
override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
// Send python data source
PythonWorkerUtils.writePythonFunction(dataSourceCls, dataOut)
@@ -446,6 +458,12 @@ private class UserDefinedPythonDataSourceFilterPushdownRunner(
// See the logic in `pyspark.sql.worker.data_source_pushdown_filters.py`.
override val workerModule = "pyspark.sql.worker.data_source_pushdown_filters"
+ override protected def runnerConf: Map[String, String] = {
+ super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+ Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+ ).getOrElse(Map.empty)
+ }
+
def isAnyFilterSupported: Boolean = serializedFilters.nonEmpty
override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
@@ -538,6 +556,12 @@ private class UserDefinedPythonDataSourceReadRunner(
// See the logic in `pyspark.sql.worker.plan_data_source_read.py`.
override val workerModule = "pyspark.sql.worker.plan_data_source_read"
+ override protected def runnerConf: Map[String, String] = {
+ super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+ Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+ ).getOrElse(Map.empty)
+ }
+
override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
// Send Python data source
PythonWorkerUtils.writePythonFunction(func, dataOut)
@@ -580,6 +604,12 @@ private class UserDefinedPythonDataSourceWriteRunner(
override val workerModule: String = "pyspark.sql.worker.write_into_data_source"
+ override protected def runnerConf: Map[String, String] = {
+ super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+ Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+ ).getOrElse(Map.empty)
+ }
+
override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
// Send the Python data source class.
PythonWorkerUtils.writePythonFunction(dataSourceCls, dataOut)
@@ -639,6 +669,12 @@ private class UserDefinedPythonDataSourceCommitRunner(
abort: Boolean) extends PythonPlannerRunner[Unit](dataSourceCls) {
override val workerModule: String = "pyspark.sql.worker.commit_data_source_write"
+ override protected def runnerConf: Map[String, String] = {
+ super.runnerConf ++ SQLConf.get.pythonDataSourceProfiler.map(p =>
+ Map(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+ ).getOrElse(Map.empty)
+ }
+
override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
// Send the Python data source writer.
PythonWorkerUtils.writeBytes(writer, dataOut)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 7a12dbd556bf..a5536621c531 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -168,12 +168,16 @@ object ArrowPythonRunner {
val binaryAsBytes = Seq(
SQLConf.PYSPARK_BINARY_AS_BYTES.key -> conf.pysparkBinaryAsBytes.toString)
- val profiler = conf.pythonUDFProfiler.map(p =>
+ val udfProfiler = conf.pythonUDFProfiler.map(p =>
Seq(SQLConf.PYTHON_UDF_PROFILER.key -> p)
).getOrElse(Seq.empty)
+ val dataSourceProfiler = conf.pythonDataSourceProfiler.map(p =>
+ Seq(SQLConf.PYTHON_DATA_SOURCE_PROFILER.key -> p)
+ ).getOrElse(Seq.empty)
Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++
arrowAyncParallelism ++ useLargeVarTypes ++
intToDecimalCoercion ++ binaryAsBytes ++
- legacyPandasConversion ++ legacyPandasConversionUDF ++ profiler: _*)
+ legacyPandasConversion ++ legacyPandasConversionUDF ++
+ udfProfiler ++ dataSourceProfiler: _*)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
index 92e99cdc11d9..dd1a869ddf7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
@@ -44,6 +44,8 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) extends Logging {
protected val workerModule: String
+ protected def runnerConf: Map[String, String] = Map.empty
+
protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit
protected def receiveFromPython(dataIn: DataInputStream): T
@@ -123,6 +125,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) extends Logging {
PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
+ PythonWorkerUtils.writeConf(runnerConf, dataOut)
writeToPython(dataOut, pickler)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]