This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f7ce07aadd64 [SPARK-55171][PYTHON] Fix memory profiler on iter UDF
f7ce07aadd64 is described below

commit f7ce07aadd64cbb65698fa0cc413eb36d36ceb20
Author: Tian Gao <[email protected]>
AuthorDate: Mon Jan 26 08:09:31 2026 +0900

    [SPARK-55171][PYTHON] Fix memory profiler on iter UDF
    
    ### What changes were proposed in this pull request?
    
    For iterator-based UDFs, we pass the function `f` directly to the memory
    profiler to track. However, `f` might not be the function we actually need
    to track: it may simply return another generator. We should instead take
    the generator returned by `f()` and track its code object.
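
    As an illustration (not part of this patch; `make_gen` and `wrapper` are
    made-up names), the wrapper's code object is not the generator's:

    ```python
    def make_gen(batches):        # the generator doing the real work
        for b in batches:
            yield b

    def wrapper(batches):         # merely returns the generator
        return make_gen(batches)

    g = wrapper([1, 2, 3])
    # Tracking wrapper.__code__ would miss the loop inside make_gen;
    # g.gi_code is make_gen's code object, which is what must be profiled.
    assert g.gi_code is make_gen.__code__
    assert g.gi_code is not wrapper.__code__
    ```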
    
    ### Why are the changes needed?
    
    For Python Data Sources, the profiler cannot track the correct function.
    Users may also write UDFs with a similar structure that we cannot track.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    A new test is added; it failed before the fix and passes after it.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #53954 from gaogaotiantian/fix-iter-memory-profiler.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/profiler.py                   |  5 +++++
 python/pyspark/sql/profiler.py               | 12 +++++++++---
 python/pyspark/tests/test_memory_profiler.py | 22 ++++++++++++++++++++++
 python/pyspark/worker.py                     |  5 +++--
 4 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
index f4fc83e66dd5..1e9e398b25f8 100644
--- a/python/pyspark/profiler.py
+++ b/python/pyspark/profiler.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+from types import CodeType
 from typing import (
     Any,
     Callable,
@@ -278,6 +279,10 @@ if has_memory_profiler:
             backend = kw.get("backend", "psutil")
             self.code_map = CodeMapForUDFV2(include_children=include_children, backend=backend)
 
+        def add_code(self, code: CodeType) -> None:
+            """Record line profiling information for the given code object."""
+            self.code_map.add(code)
+
 
 class PStatsParam(AccumulatorParam[Optional[pstats.Stats]]):
     """PStatsParam is used to merge pstats.Stats"""
diff --git a/python/pyspark/sql/profiler.py b/python/pyspark/sql/profiler.py
index 8455aacafc45..2aee60eeb41d 100644
--- a/python/pyspark/sql/profiler.py
+++ b/python/pyspark/sql/profiler.py
@@ -20,7 +20,7 @@ import cProfile
 import os
 import pstats
 from threading import RLock
-from types import TracebackType
+from types import CodeType, TracebackType
 from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union, TYPE_CHECKING, overload
 import warnings
 
@@ -120,13 +120,19 @@ class WorkerMemoryProfiler:
     """
 
     def __init__(
-        self, accumulator: Accumulator["ProfileResults"], result_id: int, func: Callable
+        self,
+        accumulator: Accumulator["ProfileResults"],
+        result_id: int,
+        func_or_code: Union[Callable, CodeType],
     ) -> None:
         from pyspark.profiler import UDFLineProfilerV2
 
         self._accumulator = accumulator
         self._profiler = UDFLineProfilerV2()
-        self._profiler.add_function(func)
+        if isinstance(func_or_code, CodeType):
+            self._profiler.add_code(func_or_code)
+        else:
+            self._profiler.add_function(func_or_code)
         self._result_id = result_id
 
     def start(self) -> None:
diff --git a/python/pyspark/tests/test_memory_profiler.py b/python/pyspark/tests/test_memory_profiler.py
index 5a77c751e6a0..c23ea5e5bc32 100644
--- a/python/pyspark/tests/test_memory_profiler.py
+++ b/python/pyspark/tests/test_memory_profiler.py
@@ -384,6 +384,28 @@ class MemoryProfiler2TestsMixin:
         for id in self.profile_results:
             self.assert_udf_memory_profile_present(udf_id=id)
 
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_memory_profiler_different_function(self):
+        df = self.spark.createDataFrame([(1,), (2,), (3,)], ["x"])
+
+        def ident(batches):
+            for b in batches:
+                yield b
+
+        def func(batches):
+            return ident(batches)
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "memory"}):
+            df.mapInArrow(func, schema="y long").show()
+
+        self.assertEqual(1, len(self.profile_results))
+
+        for id in self.profile_results:
+            self.assert_udf_memory_profile_present(udf_id=id)
+
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
         pandas_requirement_message or pyarrow_requirement_message,
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 2e79c981d818..d093beffda95 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1376,11 +1376,12 @@ def wrap_memory_profiler(f, eval_type, result_id):
     if _is_iter_based(eval_type):
 
         def profiling_func(*args, **kwargs):
-            iterator = iter(f(*args, **kwargs))
+            g = f(*args, **kwargs)
+            iterator = iter(g)
 
             while True:
                 try:
-                    with WorkerMemoryProfiler(accumulator, result_id, f):
+                    with WorkerMemoryProfiler(accumulator, result_id, g.gi_code):
                         item = next(iterator)
                     yield item
                 except StopIteration:
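
For context, here is a minimal standalone sketch of the pattern the fixed
`profiling_func` follows. The `profile` context manager is a made-up stand-in
for `WorkerMemoryProfiler`, and `g` is assumed to be a generator, as it is for
iter-based UDFs:

```python
from contextlib import contextmanager

@contextmanager
def profile(code):
    # Stand-in for WorkerMemoryProfiler: the fix passes the generator's
    # code object (g.gi_code) here instead of the wrapper function f.
    print(f"profiling {code.co_name}")
    yield

def wrap(f):
    def profiling_func(*args, **kwargs):
        g = f(*args, **kwargs)      # may be a generator produced by a wrapper
        iterator = iter(g)
        while True:
            try:
                with profile(g.gi_code):
                    item = next(iterator)
                yield item
            except StopIteration:
                return
    return profiling_func

def ident(batches):
    for b in batches:
        yield b

# Profiles ident (not the lambda) on every next() call.
print(list(wrap(lambda bs: ident(bs))([1, 2, 3])))   # [1, 2, 3]
```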

