This is an automated email from the ASF dual-hosted git repository.

joshrosen pushed a commit to branch python-udf-accumulator
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 9213a85a40499fc7f0e24ea14c5051c45a022ef2
Author: Josh Rosen <joshro...@databricks.com>
AuthorDate: Wed Oct 20 16:17:44 2021 -0700

    hacky wip towards python udf profiling
---
 python/pyspark/profiler.py |  1 +
 python/pyspark/sql/udf.py  | 15 ++++++++++++---
 python/pyspark/worker.py   | 37 +++++++++++++++++++------------------
 3 files changed, 32 insertions(+), 21 deletions(-)

diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
index 99cfe71..8d8458c 100644
--- a/python/pyspark/profiler.py
+++ b/python/pyspark/profiler.py
@@ -142,6 +142,7 @@ class PStatsParam(AccumulatorParam):
 
     @staticmethod
     def addInPlace(value1, value2):
+        print("ACCUM UPDATE PARAM")
         if value1 is None:
             return value2
         value1.add(value2)
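
For context: PStatsParam is the AccumulatorParam that ships profile data
back to the driver, so addInPlace above defines how per-task profiles
merge. A minimal standalone sketch of those semantics, independent of
Spark (None is the zero value; non-None values merge via pstats.Stats.add):

    import cProfile
    import pstats

    def collect_stats(fn):
        # Profile one call and capture its stats, as BasicProfiler does.
        pr = cProfile.Profile()
        pr.runcall(fn)
        st = pstats.Stats(pr)
        st.stream = None  # drop the unpicklable stream handle
        return st

    merged = None  # the accumulator's zero value
    for fn in (lambda: sum(range(10**5)), lambda: sorted(range(10**5))):
        st = collect_stats(fn)
        if merged is None:   # first update: adopt the incoming stats
            merged = st
        else:                # later updates: merge in place
            merged.add(st)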
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 752ccca..164588f 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -24,6 +24,7 @@ from typing import Callable, Any, TYPE_CHECKING, Optional, cast, Union
 from py4j.java_gateway import JavaObject
 
 from pyspark import SparkContext
+from pyspark.profiler import Profiler
 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType  # type: ignore[attr-defined]
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import (  # type: ignore[attr-defined]
@@ -44,9 +45,10 @@ __all__ = ["UDFRegistration"]
 def _wrap_function(
     sc: SparkContext,
     func: Callable[..., Any],
-    returnType: "DataTypeOrString"
+    returnType: "DataTypeOrString",
+    profiler: Optional[Profiler] = None
 ) -> JavaObject:
-    command = (func, returnType)
+    command = (func, returnType, profiler)
     pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
     return sc._jvm.PythonFunction(  # type: ignore[attr-defined]
         bytearray(pickled_command),
@@ -199,7 +201,14 @@ class UserDefinedFunction(object):
         spark = SparkSession.builder.getOrCreate()
         sc = spark.sparkContext
 
-        wrapped_func = _wrap_function(sc, self.func, self.returnType)
+        if sc.profiler_collector:
+            profiler = sc.profiler_collector.new_profiler(sc)
+            # TODO: better ID
+            sc.profiler_collector.add_profiler(0, profiler)
+        else:
+            profiler = None
+
+        wrapped_func = _wrap_function(sc, self.func, self.returnType, profiler)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = (
             sc._jvm.org.apache.spark.sql.execution.python  # type: ignore[attr-defined]
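
The new branch above only fires when Python profiling is enabled on the
SparkContext; otherwise profiler stays None and behavior is unchanged. A
hedged usage sketch of how this would be exercised, assuming the rest of
this WIP lands as-is (note the profiler id is hardcoded to 0 per the TODO):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf

    # spark.python.profile=true makes sc.profiler_collector non-None, so a
    # profiler is created and registered when the UDF is wrapped.
    spark = (SparkSession.builder
             .config("spark.python.profile", "true")
             .getOrCreate())

    to_upper = udf(lambda s: s.upper())
    df = spark.range(5).selectExpr("cast(id AS string) AS s")
    df.select(to_upper("s")).collect()

    # Accumulated UDF profiles would then show up here.
    spark.sparkContext.show_profiles()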
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index ad6c003..8f13822 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -247,8 +247,9 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     num_arg = read_int(infile)
     arg_offsets = [read_int(infile) for i in range(num_arg)]
     chained_func = None
+    profiler = None
     for i in range(read_int(infile)):
-        f, return_type = read_command(pickleSer, infile)
+        f, return_type, profiler = read_command(pickleSer, infile)
         if chained_func is None:
             chained_func = f
         else:
@@ -263,28 +264,29 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
 
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
+        return arg_offsets, profiler, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
-        return arg_offsets, wrap_pandas_iter_udf(func, return_type)
+        return arg_offsets, profiler, wrap_pandas_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
-        return arg_offsets, wrap_pandas_iter_udf(func, return_type)
+        return arg_offsets, profiler, wrap_pandas_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
-        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
+        return arg_offsets, profiler, wrap_grouped_map_pandas_udf(func, return_type, argspec)
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
         argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
-        return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
+        return arg_offsets, profiler, wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
-        return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
+        return arg_offsets, profiler, wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
-        return arg_offsets, wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index)
+        return arg_offsets, profiler, wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index)
     elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
-        return arg_offsets, wrap_udf(func, return_type)
+        return arg_offsets, profiler, wrap_udf(func, return_type)
     else:
         raise ValueError("Unknown eval type: {}".format(eval_type))
 
 
 def read_udfs(pickleSer, infile, eval_type):
+    profiler = None
     runner_conf = {}
 
     if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
@@ -335,7 +337,7 @@ def read_udfs(pickleSer, infile, eval_type):
         if is_map_iter:
             assert num_udfs == 1, "One MAP_ITER UDF expected here."
 
-        arg_offsets, udf = read_single_udf(
+        arg_offsets, profiler, udf = read_single_udf(
             pickleSer, infile, eval_type, runner_conf, udf_index=0)
 
         def func(_, iterator):
@@ -381,8 +383,7 @@ def read_udfs(pickleSer, infile, eval_type):
                         "the same with the input's; however, the length of 
output was %d and the "
                         "length of input was %d." % (num_output_rows, 
num_input_rows))
 
-        # profiling is not supported for UDF
-        return func, None, ser, ser
+        return func, profiler, ser, ser
 
     def extract_key_value_indexes(grouped_arg_offsets):
         """
@@ -420,7 +421,7 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # See FlatMapGroupsInPandasExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+        arg_offsets, profiler, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
         # Create function like this:
@@ -433,7 +434,7 @@ def read_udfs(pickleSer, infile, eval_type):
         # We assume there is only one UDF here because cogrouped map doesn't
         # support combining multiple UDFs.
         assert num_udfs == 1
-        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+        arg_offsets, profiler, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
 
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
@@ -446,7 +447,8 @@ def read_udfs(pickleSer, infile, eval_type):
     else:
         udfs = []
         for i in range(num_udfs):
-            udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i))
+            arg_offsets, profiler, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i)
+            udfs.append((arg_offsets, f))
 
         def mapper(a):
             result = tuple(f(*[a[o] for o in arg_offsets]) for (arg_offsets, f) in udfs)
@@ -459,8 +461,7 @@ def read_udfs(pickleSer, infile, eval_type):
 
     func = lambda _, it: map(mapper, it)
 
-    # profiling is not supported for UDF
-    return func, None, ser, ser
+    return func, profiler, ser, ser
 
 
 def main(infile, outfile):
@@ -599,7 +600,7 @@ def main(infile, outfile):
         _accumulatorRegistry.clear()
         eval_type = read_int(infile)
         if eval_type == PythonEvalType.NON_UDF:
-            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
         else:
             func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
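
With read_udfs now returning a real profiler instead of None, the existing
plumbing further down in main() should pick it up unchanged. Roughly (a
paraphrase of the current worker code, not part of this patch):

    def process():
        iterator = deserializer.load_stream(infile)
        serializer.dump_stream(func(split_index, iterator), outfile)

    if profiler:
        profiler.profile(process)  # cProfile + accumulator under the hood
    else:
        process()

Each profiled batch ends up as a pstats.Stats update on the accumulator,
which triggers PStatsParam.addInPlace on the driver (hence the debug print
added in profiler.py above).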
 
