This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e966c387f62 [SPARK-34265][PYTHON][SQL] Instrument Python UDFs using 
SQL metrics
e966c387f62 is described below

commit e966c387f624da3ece6e507f95f00cf20b42f45e
Author: Luca Canali <luca.can...@cern.ch>
AuthorDate: Mon Oct 24 17:04:37 2022 +0800

    [SPARK-34265][PYTHON][SQL] Instrument Python UDFs using SQL metrics
    
    ### What changes are proposed in this pull request?
    
    This proposes to add SQLMetrics instrumentation for Python UDF execution, 
including Pandas UDF, and related operations such as MapInPandas and MapInArrow.
    The proposed metrics are:
    
    - data sent to Python workers
    - data returned from Python workers
    - number of output rows
    
    ### Why are the changes needed?
    This aims at improving monitoring and performance troubleshooting of Python 
UDFs.
    In particular it is intended as an aid to answer performance-related 
questions such as:
    why is the UDF slow?, how much work has been done so far?, etc.
    
    ### Does this PR introduce _any_ user-facing change?
    SQL metrics are made available in the WEB UI.
    See the following examples:
    
    
![image1](https://issues.apache.org/jira/secure/attachment/13038693/PandasUDF_ArrowEvalPython_Metrics.png)
    
    ### How was this patch tested?
    
    Manually tested + a Python unit test and a Scala unit test have been added.
    
    Example code used for testing:
    
    ```
    from pyspark.sql.functions import col, pandas_udf
    import time
    
    @pandas_udf("long")
    def test_pandas(col1):
      time.sleep(0.02)
      return col1 * col1
    
    spark.udf.register("test_pandas", test_pandas)
    spark.sql("select rand(42)*rand(51)*rand(12) col1 from range(10000000)").createOrReplaceTempView("t1")
    spark.sql("select max(test_pandas(col1)) from t1").collect()
    ```
    
    This is used to test with more data pushed to the Python workers:
    
    ```
    from pyspark.sql.functions import col, pandas_udf
    import time
    
    @pandas_udf("long")
    def test_pandas(col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14,col15,col16,col17):
      time.sleep(0.02)
      return col1
    
    spark.udf.register("test_pandas", test_pandas)
    spark.sql("select rand(42)*rand(51)*rand(12) col1 from range(10000000)").createOrReplaceTempView("t1")
    spark.sql("select max(test_pandas(col1,col1+1,col1+2,col1+3,col1+4,col1+5,col1+6,col1+7,col1+8,col1+9,col1+10,col1+11,col1+12,col1+13,col1+14,col1+15,col1+16)) from t1").collect()
    ```
    
    This (from the Spark doc) has been used to test with MapInPandas, where the 
number of output rows is different from the number of input rows:
    
    ```
    import pandas as pd
    
    from pyspark.sql.functions import pandas_udf, PandasUDFType
    
    df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
    
    def filter_func(iterator):
        for pdf in iterator:
            yield pdf[pdf.id == 1]
    
    df.mapInPandas(filter_func, schema=df.schema).show()
    ```
    
    This is for testing BatchEvalPython and metrics related to data transfer (bytes sent and received):
    ```
    from pyspark.sql.functions import udf
    
    @udf
    def test_udf(col1, col2):
         return col1 * col1
    
    spark.sql("select id, 
'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
 col2 from range(10)").select(test_udf("id", "col2")).collect()
    ```
    
    Closes #33559 from LucaCanali/pythonUDFKeySQLMetrics.
    
    Authored-by: Luca Canali <luca.can...@cern.ch>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 dev/sparktestsupport/modules.py                    |  1 +
 docs/web-ui.md                                     |  2 +
 python/pyspark/sql/tests/test_pandas_sqlmetrics.py | 68 ++++++++++++++++++++++
 .../execution/python/AggregateInPandasExec.scala   |  5 +-
 .../ApplyInPandasWithStatePythonRunner.scala       |  7 ++-
 .../sql/execution/python/ArrowEvalPythonExec.scala |  5 +-
 .../sql/execution/python/ArrowPythonRunner.scala   |  4 +-
 .../sql/execution/python/BatchEvalPythonExec.scala |  6 +-
 .../python/CoGroupedArrowPythonRunner.scala        |  8 ++-
 .../python/FlatMapCoGroupsInPandasExec.scala       |  6 +-
 .../python/FlatMapGroupsInPandasExec.scala         |  5 +-
 .../FlatMapGroupsInPandasWithStateExec.scala       |  6 +-
 .../sql/execution/python/MapInBatchExec.scala      |  5 +-
 .../sql/execution/python/PythonArrowInput.scala    |  6 ++
 .../sql/execution/python/PythonArrowOutput.scala   |  8 +++
 .../sql/execution/python/PythonSQLMetrics.scala    | 35 +++++++++++
 .../sql/execution/python/PythonUDFRunner.scala     | 10 +++-
 .../sql/execution/python/WindowInPandasExec.scala  |  5 +-
 .../execution/streaming/statefulOperators.scala    |  5 +-
 .../sql/execution/python/PythonUDFSuite.scala      | 19 ++++++
 20 files changed, 193 insertions(+), 23 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 2a427139148..a439b4cbbed 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -484,6 +484,7 @@ pyspark_sql = Module(
         "pyspark.sql.tests.pandas.test_pandas_udf_typehints",
         
"pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations",
         "pyspark.sql.tests.pandas.test_pandas_udf_window",
+        "pyspark.sql.tests.test_pandas_sqlmetrics",
         "pyspark.sql.tests.test_readwriter",
         "pyspark.sql.tests.test_serde",
         "pyspark.sql.tests.test_session",
diff --git a/docs/web-ui.md b/docs/web-ui.md
index d3356ec5a43..e228d7fe2a9 100644
--- a/docs/web-ui.md
+++ b/docs/web-ui.md
@@ -406,6 +406,8 @@ Here is the list of SQL metrics:
 <tr><td> <code>time to build hash map</code> </td><td> the time spent on 
building hash map </td><td> ShuffledHashJoin </td></tr>
 <tr><td> <code>task commit time</code> </td><td> the time spent on committing 
the output of a task after the writes succeed </td><td> any write operation on 
a file-based table </td></tr>
 <tr><td> <code>job commit time</code> </td><td> the time spent on committing 
the output of a job after the writes succeed </td><td> any write operation on a 
file-based table </td></tr>
+<tr><td> <code>data sent to Python workers</code> </td><td> the number of 
bytes of serialized data sent to the Python workers </td><td> ArrowEvalPython, 
AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, 
FlatMapCoGroupsInPandas, FlatMapGroupsInPandasWithState, MapInPandas, 
PythonMapInArrow, WindowInPandas </td></tr>
+<tr><td> <code>data returned from Python workers</code> </td><td> the number 
of bytes of serialized data received back from the Python workers </td><td> 
ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, 
FlatMapCoGroupsInPandas, FlatMapGroupsInPandasWithState, MapInPandas, 
PythonMapInArrow, WindowInPandas </td></tr>
 </table>
 
 ## Structured Streaming Tab
diff --git a/python/pyspark/sql/tests/test_pandas_sqlmetrics.py 
b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py
new file mode 100644
index 00000000000..d182bafd8b5
--- /dev/null
+++ b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py
@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from typing import cast
+
+from pyspark.sql.functions import pandas_udf
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+
+
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    cast(str, pandas_requirement_message or pyarrow_requirement_message),
+)
+class PandasSQLMetrics(ReusedSQLTestCase):
+    def test_pandas_sql_metrics_basic(self):
+        # SPARK-34265: Instrument Python UDFs using SQL metrics
+
+        python_sql_metrics = [
+            "data sent to Python workers",
+            "data returned from Python workers",
+            "number of output rows",
+        ]
+
+        @pandas_udf("long")
+        def test_pandas(col1):
+            return col1 * col1
+
+        self.spark.range(10).select(test_pandas("id")).collect()
+
+        statusStore = self.spark._jsparkSession.sharedState().statusStore()
+        lastExecId = statusStore.executionsList().last().executionId()
+        executionMetrics = 
statusStore.execution(lastExecId).get().metrics().mkString()
+
+        for metric in python_sql_metrics:
+            self.assertIn(metric, executionMetrics)
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_pandas_sqlmetrics import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index 2f85149ee8e..6a8b197742d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -46,7 +46,7 @@ case class AggregateInPandasExec(
     udfExpressions: Seq[PythonUDF],
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with PythonSQLMetrics {
 
   override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
 
@@ -163,7 +163,8 @@ case class AggregateInPandasExec(
         argOffsets,
         aggInputSchema,
         sessionLocalTimeZone,
-        pythonRunnerConf).compute(projectedRowIter, context.partitionId(), 
context)
+        pythonRunnerConf,
+        pythonMetrics).compute(projectedRowIter, context.partitionId(), 
context)
 
       val joinedAttributes =
         groupingExpressions.map(_.toAttribute) ++ 
udfExpressions.map(_.resultAttribute)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index bd8c72029dc..f3531668c8e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.api.python.PythonSQLUtils
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.metric.SQLMetric
 import 
org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType,
 OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
 import 
org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
 import org.apache.spark.sql.execution.streaming.GroupStateImpl
@@ -58,7 +59,8 @@ class ApplyInPandasWithStatePythonRunner(
     stateEncoder: ExpressionEncoder[Row],
     keySchema: StructType,
     outputSchema: StructType,
-    stateValueSchema: StructType)
+    stateValueSchema: StructType,
+    val pythonMetrics: Map[String, SQLMetric])
   extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
   with PythonArrowInput[InType]
   with PythonArrowOutput[OutType] {
@@ -116,6 +118,7 @@ class ApplyInPandasWithStatePythonRunner(
     val w = new ApplyInPandasWithStateWriter(root, writer, 
arrowMaxRecordsPerBatch)
 
     while (inputIterator.hasNext) {
+      val startData = dataOut.size()
       val (keyRow, groupState, dataIter) = inputIterator.next()
       assert(dataIter.hasNext, "should have at least one data row!")
       w.startNewGroup(keyRow, groupState)
@@ -126,6 +129,8 @@ class ApplyInPandasWithStatePythonRunner(
       }
 
       w.finalizeGroup()
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
     }
 
     w.finalizeData()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 096712cf935..b11dd4947af 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -61,7 +61,7 @@ private[spark] class BatchIterator[T](iter: Iterator[T], 
batchSize: Int)
  */
 case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: 
Seq[Attribute], child: SparkPlan,
     evalType: Int)
-  extends EvalPythonExec {
+  extends EvalPythonExec with PythonSQLMetrics {
 
   private val batchSize = conf.arrowMaxRecordsPerBatch
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
@@ -85,7 +85,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], 
resultAttrs: Seq[Attribute]
       argOffsets,
       schema,
       sessionLocalTimeZone,
-      pythonRunnerConf).compute(batchIter, context.partitionId(), context)
+      pythonRunnerConf,
+      pythonMetrics).compute(batchIter, context.partitionId(), context)
 
     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => 
batch.column(i).dataType())
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 8467feb91d1..dbafc444281 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.spark.api.python._
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -32,7 +33,8 @@ class ArrowPythonRunner(
     argOffsets: Array[Array[Int]],
     protected override val schema: StructType,
     protected override val timeZoneId: String,
-    protected override val workerConf: Map[String, String])
+    protected override val workerConf: Map[String, String],
+    val pythonMetrics: Map[String, SQLMetric])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, 
evalType, argOffsets)
   with BasicPythonArrowInput
   with BasicPythonArrowOutput {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 10f7966b93d..ca7ca2e2f80 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
  * A physical plan that evaluates a [[PythonUDF]]
  */
 case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: 
Seq[Attribute], child: SparkPlan)
-  extends EvalPythonExec {
+  extends EvalPythonExec with PythonSQLMetrics {
 
   protected override def evaluate(
       funcs: Seq[ChainedPythonFunctions],
@@ -77,7 +77,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], 
resultAttrs: Seq[Attribute]
     }.grouped(100).map(x => pickle.dumps(x.toArray))
 
     // Output iterator for results from Python.
-    val outputIterator = new PythonUDFRunner(funcs, 
PythonEvalType.SQL_BATCHED_UDF, argOffsets)
+    val outputIterator =
+      new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, 
pythonMetrics)
       .compute(inputIterator, context.partitionId(), context)
 
     val unpickle = new Unpickler
@@ -94,6 +95,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], 
resultAttrs: Seq[Attribute]
       val unpickledBatch = unpickle.loads(pickedResult)
       unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
     }.map { result =>
+      pythonMetrics("pythonNumRowsReceived") += 1
       if (udfs.length == 1) {
         // fast path for single UDF
         mutableRow(0) = fromJava(result)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 2661896ecec..1df9f37188a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, 
PythonRDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
@@ -45,7 +46,8 @@ class CoGroupedArrowPythonRunner(
     leftSchema: StructType,
     rightSchema: StructType,
     timeZoneId: String,
-    conf: Map[String, String])
+    conf: Map[String, String],
+    val pythonMetrics: Map[String, SQLMetric])
   extends BasePythonRunner[
     (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, 
evalType, argOffsets)
   with BasicPythonArrowOutput {
@@ -77,10 +79,14 @@ class CoGroupedArrowPythonRunner(
         // For each we first send the number of dataframes in each group then 
send
         // first df, then send second df.  End of data is marked by sending 0.
         while (inputIterator.hasNext) {
+          val startData = dataOut.size()
           dataOut.writeInt(2)
           val (nextLeft, nextRight) = inputIterator.next()
           writeGroup(nextLeft, leftSchema, dataOut, "left")
           writeGroup(nextRight, rightSchema, dataOut, "right")
+
+          val deltaData = dataOut.size() - startData
+          pythonMetrics("pythonDataSent") += deltaData
         }
         dataOut.writeInt(0)
       }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index b39787b12a4..629df51e18a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -54,7 +54,7 @@ case class FlatMapCoGroupsInPandasExec(
     output: Seq[Attribute],
     left: SparkPlan,
     right: SparkPlan)
-  extends SparkPlan with BinaryExecNode {
+  extends SparkPlan with BinaryExecNode with PythonSQLMetrics {
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
@@ -77,7 +77,6 @@ case class FlatMapCoGroupsInPandasExec(
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
-
     val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup)
     val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, 
rightGroup)
 
@@ -97,7 +96,8 @@ case class FlatMapCoGroupsInPandasExec(
           StructType.fromAttributes(leftDedup),
           StructType.fromAttributes(rightDedup),
           sessionLocalTimeZone,
-          pythonRunnerConf)
+          pythonRunnerConf,
+          pythonMetrics)
 
         executePython(data, output, runner)
       }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index f0e815e966e..271ccdb6b27 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -50,7 +50,7 @@ case class FlatMapGroupsInPandasExec(
     func: Expression,
     output: Seq[Attribute],
     child: SparkPlan)
-  extends SparkPlan with UnaryExecNode {
+  extends SparkPlan with UnaryExecNode with PythonSQLMetrics {
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
@@ -89,7 +89,8 @@ case class FlatMapGroupsInPandasExec(
         Array(argOffsets),
         StructType.fromAttributes(dedupAttributes),
         sessionLocalTimeZone,
-        pythonRunnerConf)
+        pythonRunnerConf,
+        pythonMetrics)
 
       executePython(data, output, runner)
     }}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index 09123344c2e..3b096f07241 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -62,7 +62,8 @@ case class FlatMapGroupsInPandasWithStateExec(
     timeoutConf: GroupStateTimeout,
     batchTimestampMs: Option[Long],
     eventTimeWatermark: Option[Long],
-    child: SparkPlan) extends UnaryExecNode with 
FlatMapGroupsWithStateExecBase {
+    child: SparkPlan)
+  extends UnaryExecNode with PythonSQLMetrics with 
FlatMapGroupsWithStateExecBase {
 
   // TODO(SPARK-40444): Add the support of initial state.
   override protected val initialStateDeserializer: Expression = null
@@ -166,7 +167,8 @@ case class FlatMapGroupsInPandasWithStateExec(
         stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
         groupingAttributes.toStructType,
         outAttributes.toStructType,
-        stateType)
+        stateType,
+        pythonMetrics)
 
       val context = TaskContext.get()
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index d25c1383540..450891c6948 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, 
ColumnarBatch}
  * This is somewhat similar with [[FlatMapGroupsInPandasExec]] and
  * `org.apache.spark.sql.catalyst.plans.logical.MapPartitionsInRWithArrow`
  */
-trait MapInBatchExec extends UnaryExecNode {
+trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
   protected val func: Expression
   protected val pythonEvalType: Int
 
@@ -75,7 +75,8 @@ trait MapInBatchExec extends UnaryExecNode {
         argOffsets,
         StructType(StructField("struct", outputTypes) :: Nil),
         sessionLocalTimeZone,
-        pythonRunnerConf).compute(batchIter, context.partitionId(), context)
+        pythonRunnerConf,
+        pythonMetrics).compute(batchIter, context.partitionId(), context)
 
       val unsafeProj = UnsafeProjection.create(output, output)
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index bf66791183e..5a0541d11cb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.api.python.{BasePythonRunner, PythonRDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.util.Utils
@@ -41,6 +42,8 @@ private[python] trait PythonArrowInput[IN] { self: 
BasePythonRunner[IN, _] =>
 
   protected val timeZoneId: String
 
+  protected def pythonMetrics: Map[String, SQLMetric]
+
   protected def writeIteratorToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
@@ -115,6 +118,7 @@ private[python] trait BasicPythonArrowInput extends 
PythonArrowInput[Iterator[In
     val arrowWriter = ArrowWriter.create(root)
 
     while (inputIterator.hasNext) {
+      val startData = dataOut.size()
       val nextBatch = inputIterator.next()
 
       while (nextBatch.hasNext) {
@@ -124,6 +128,8 @@ private[python] trait BasicPythonArrowInput extends 
PythonArrowInput[Iterator[In
       arrowWriter.finish()
       writer.writeBatch()
       arrowWriter.reset()
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
     }
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index 339f114539c..c12c690f776 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -27,6 +27,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
 
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths}
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, 
ColumnVector}
@@ -37,6 +38,8 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, 
ColumnarBatch, Column
  */
 private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: 
BasePythonRunner[_, OUT] =>
 
+  protected def pythonMetrics: Map[String, SQLMetric]
+
   protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
 
   protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: 
StructType): OUT
@@ -82,10 +85,15 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { 
self: BasePythonRunner[
         }
         try {
           if (reader != null && batchLoaded) {
+            val bytesReadStart = reader.bytesRead()
             batchLoaded = reader.loadNextBatch()
             if (batchLoaded) {
               val batch = new ColumnarBatch(vectors)
+              val rowCount = root.getRowCount
               batch.setNumRows(root.getRowCount)
+              val bytesReadEnd = reader.bytesRead()
+              pythonMetrics("pythonNumRowsReceived") += rowCount
+              pythonMetrics("pythonDataReceived") += bytesReadEnd - 
bytesReadStart
               deserializeColumnarBatch(batch, schema)
             } else {
               reader.close(false)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
new file mode 100644
index 00000000000..a748c1bc100
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+private[sql] trait PythonSQLMetrics { self: SparkPlan =>
+
+  val pythonMetrics = Map(
+    "pythonDataSent" -> SQLMetrics.createSizeMetric(sparkContext,
+      "data sent to Python workers"),
+    "pythonDataReceived" -> SQLMetrics.createSizeMetric(sparkContext,
+      "data returned from Python workers"),
+    "pythonNumRowsReceived" -> SQLMetrics.createMetric(sparkContext,
+      "number of output rows")
+  )
+
+  override lazy val metrics = pythonMetrics
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index d1109d251c2..09e06b55df3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 
 import org.apache.spark._
 import org.apache.spark.api.python._
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -31,7 +32,8 @@ import org.apache.spark.sql.internal.SQLConf
 class PythonUDFRunner(
     funcs: Seq[ChainedPythonFunctions],
     evalType: Int,
-    argOffsets: Array[Array[Int]])
+    argOffsets: Array[Array[Int]],
+    pythonMetrics: Map[String, SQLMetric])
   extends BasePythonRunner[Array[Byte], Array[Byte]](
     funcs, evalType, argOffsets) {
 
@@ -50,8 +52,13 @@ class PythonUDFRunner(
       }
 
       protected override def writeIteratorToStream(dataOut: DataOutputStream): 
Unit = {
+        val startData = dataOut.size()
+
         PythonRDD.writeIteratorToStream(inputIterator, dataOut)
         dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+
+        val deltaData = dataOut.size() - startData
+        pythonMetrics("pythonDataSent") += deltaData
       }
     }
   }
@@ -77,6 +84,7 @@ class PythonUDFRunner(
             case length if length > 0 =>
               val obj = new Array[Byte](length)
               stream.readFully(obj)
+              pythonMetrics("pythonDataReceived") += length
               obj
             case 0 => Array.emptyByteArray
             case SpecialLengths.TIMING_DATA =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index ccb1ed92525..dcaffed89cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -84,7 +84,7 @@ case class WindowInPandasExec(
     partitionSpec: Seq[Expression],
     orderSpec: Seq[SortOrder],
     child: SparkPlan)
-  extends WindowExecBase {
+  extends WindowExecBase with PythonSQLMetrics {
 
   /**
    * Helper functions and data structures for window bounds
@@ -375,7 +375,8 @@ case class WindowInPandasExec(
         argOffsets,
         pythonInputSchema,
         sessionLocalTimeZone,
-        pythonRunnerConf).compute(pythonInput, context.partitionId(), context)
+        pythonRunnerConf,
+        pythonMetrics).compute(pythonInput, context.partitionId(), context)
 
       val joined = new JoinedRow
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 2b8fc651561..b540f9f0093 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.python.PythonSQLMetrics
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
 import org.apache.spark.sql.types._
@@ -93,7 +94,7 @@ trait StateStoreReader extends StatefulOperator {
 }
 
 /** An operator that writes to a StateStore. */
-trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
+trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: SparkPlan =>
 
   override lazy val metrics = statefulOperatorCustomMetrics ++ Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
@@ -109,7 +110,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
     "numShufflePartitions" -> SQLMetrics.createMetric(sparkContext, "number of shuffle partitions"),
     "numStateStoreInstances" -> SQLMetrics.createMetric(sparkContext,
       "number of state store instances")
-  ) ++ stateStoreCustomMetrics
+  ) ++ stateStoreCustomMetrics ++ pythonMetrics
 
   /**
    * Get the progress made by this stateful operator after execution. This should be called in
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
index 70784c20a8e..7850b2d79b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
@@ -84,4 +84,23 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession {
 
     checkAnswer(actual, expected)
   }
+
+  test("SPARK-34265: Instrument Python UDF execution using SQL Metrics") {
+
+    val pythonSQLMetrics = List(
+      "data sent to Python workers",
+      "data returned from Python workers",
+      "number of output rows")
+
+    val df = base.groupBy(pythonTestUDF(base("a") + 1))
+      .agg(pythonTestUDF(pythonTestUDF(base("a") + 1)))
+    df.count()
+
+    val statusStore = spark.sharedState.statusStore
+    val lastExecId = statusStore.executionsList.last.executionId
+    val executionMetrics = statusStore.execution(lastExecId).get.metrics.mkString
+    for (metric <- pythonSQLMetrics) {
+      assert(executionMetrics.contains(metric))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to