This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e966c387f62 [SPARK-34265][PYTHON][SQL] Instrument Python UDFs using SQL metrics
e966c387f62 is described below
commit e966c387f624da3ece6e507f95f00cf20b42f45e
Author: Luca Canali <[email protected]>
AuthorDate: Mon Oct 24 17:04:37 2022 +0800
[SPARK-34265][PYTHON][SQL] Instrument Python UDFs using SQL metrics
### What changes were proposed in this pull request?
This proposes to add SQLMetrics instrumentation for Python UDF execution,
including Pandas UDFs and related operations such as MapInPandas and MapInArrow.
The proposed metrics are:
- data sent to Python workers
- data returned from Python workers
- number of output rows
### Why are the changes needed?
This aims at improving monitoring and performance troubleshooting of Python
UDFs.
In particular, it is intended as an aid in answering performance-related
questions such as: why is the UDF slow? How much work has been done so far?
### Does this PR introduce _any_ user-facing change?
SQL metrics are made available in the Web UI.
See the following examples:

### How was this patch tested?
Manually tested + a Python unit test and a Scala unit test have been added.
Example code used for testing:
```
from pyspark.sql.functions import col, pandas_udf
import time

@pandas_udf("long")
def test_pandas(col1):
    time.sleep(0.02)
    return col1 * col1

spark.udf.register("test_pandas", test_pandas)
spark.sql("select rand(42)*rand(51)*rand(12) col1 from range(10000000)").createOrReplaceTempView("t1")
spark.sql("select max(test_pandas(col1)) from t1").collect()
```
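After running the query above, the new metrics can also be checked programmatically through the SQL status store. This is a minimal sketch modeled on the Python unit test added by this patch; it goes through the internal `_jsparkSession` accessor, so it is meant for testing and exploration rather than application code:
```
statusStore = spark._jsparkSession.sharedState().statusStore()
lastExecId = statusStore.executionsList().last().executionId()
executionMetrics = statusStore.execution(lastExecId).get().metrics().mkString()

for metric in ["data sent to Python workers",
               "data returned from Python workers",
               "number of output rows"]:
    print(metric, "->", metric in executionMetrics)
```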
This was used to test with more data pushed to the Python workers:
```
from pyspark.sql.functions import col, pandas_udf
import time

@pandas_udf("long")
def test_pandas(col1, col2, col3, col4, col5, col6, col7, col8, col9,
                col10, col11, col12, col13, col14, col15, col16, col17):
    time.sleep(0.02)
    return col1

spark.udf.register("test_pandas", test_pandas)
spark.sql("select rand(42)*rand(51)*rand(12) col1 from range(10000000)").createOrReplaceTempView("t1")
spark.sql("select max(test_pandas(col1,col1+1,col1+2,col1+3,col1+4,col1+5,col1+6,col1+7,col1+8,col1+9,col1+10,col1+11,col1+12,col1+13,col1+14,col1+15,col1+16)) from t1").collect()
```
This (from the Spark doc) has been used to test with MapInPandas, where the
number of output rows is different from the number of input rows:
```
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
def filter_func(iterator):
    for pdf in iterator:
        yield pdf[pdf.id == 1]
df.mapInPandas(filter_func, schema=df.schema).show()
```
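A similar check can be done for the MapInArrow path, which is also instrumented by this patch. The following is only a sketch (it assumes Spark 3.3+, where DataFrame.mapInArrow is available, and that pyarrow is installed), and it reuses the df defined above:
```
import pyarrow as pa

def filter_func_arrow(iterator):
    # mapInArrow receives an iterator of pyarrow.RecordBatch objects
    for batch in iterator:
        pdf = batch.to_pandas()
        yield pa.RecordBatch.from_pandas(pdf[pdf.id == 1])

df.mapInArrow(filter_func_arrow, schema=df.schema).show()
```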
This was used to test BatchEvalPython and the metrics related to data transfer
(bytes sent and received):
```
from pyspark.sql.functions import udf
@udf
def test_udf(col1, col2):
    return col1 * col1

spark.sql("select id, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa' col2 from range(10)").select(test_udf("id", "col2")).collect()
```
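Besides the SQL tab of the Web UI, per-execution SQL metrics can also be retrieved over the REST API described in the monitoring documentation. The snippet below is only an illustrative sketch, not part of this patch: it assumes a local application with the default UI port 4040 and simply dumps the raw JSON, in which the new Python metrics should appear among the plan node metrics:
```
import urllib.request

app_id = spark.sparkContext.applicationId
url = "http://localhost:4040/api/v1/applications/%s/sql?details=true" % app_id
print(urllib.request.urlopen(url).read().decode())
```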
Closes #33559 from LucaCanali/pythonUDFKeySQLMetrics.
Authored-by: Luca Canali <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
dev/sparktestsupport/modules.py | 1 +
docs/web-ui.md | 2 +
python/pyspark/sql/tests/test_pandas_sqlmetrics.py | 68 ++++++++++++++++++++++
.../execution/python/AggregateInPandasExec.scala | 5 +-
.../ApplyInPandasWithStatePythonRunner.scala | 7 ++-
.../sql/execution/python/ArrowEvalPythonExec.scala | 5 +-
.../sql/execution/python/ArrowPythonRunner.scala | 4 +-
.../sql/execution/python/BatchEvalPythonExec.scala | 6 +-
.../python/CoGroupedArrowPythonRunner.scala | 8 ++-
.../python/FlatMapCoGroupsInPandasExec.scala | 6 +-
.../python/FlatMapGroupsInPandasExec.scala | 5 +-
.../FlatMapGroupsInPandasWithStateExec.scala | 6 +-
.../sql/execution/python/MapInBatchExec.scala | 5 +-
.../sql/execution/python/PythonArrowInput.scala | 6 ++
.../sql/execution/python/PythonArrowOutput.scala | 8 +++
.../sql/execution/python/PythonSQLMetrics.scala | 35 +++++++++++
.../sql/execution/python/PythonUDFRunner.scala | 10 +++-
.../sql/execution/python/WindowInPandasExec.scala | 5 +-
.../execution/streaming/statefulOperators.scala | 5 +-
.../sql/execution/python/PythonUDFSuite.scala | 19 ++++++
20 files changed, 193 insertions(+), 23 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 2a427139148..a439b4cbbed 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -484,6 +484,7 @@ pyspark_sql = Module(
"pyspark.sql.tests.pandas.test_pandas_udf_typehints",
"pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations",
"pyspark.sql.tests.pandas.test_pandas_udf_window",
+ "pyspark.sql.tests.test_pandas_sqlmetrics",
"pyspark.sql.tests.test_readwriter",
"pyspark.sql.tests.test_serde",
"pyspark.sql.tests.test_session",
diff --git a/docs/web-ui.md b/docs/web-ui.md
index d3356ec5a43..e228d7fe2a9 100644
--- a/docs/web-ui.md
+++ b/docs/web-ui.md
@@ -406,6 +406,8 @@ Here is the list of SQL metrics:
<tr><td> <code>time to build hash map</code> </td><td> the time spent on building hash map </td><td> ShuffledHashJoin </td></tr>
<tr><td> <code>task commit time</code> </td><td> the time spent on committing the output of a task after the writes succeed </td><td> any write operation on a file-based table </td></tr>
<tr><td> <code>job commit time</code> </td><td> the time spent on committing the output of a job after the writes succeed </td><td> any write operation on a file-based table </td></tr>
+<tr><td> <code>data sent to Python workers</code> </td><td> the number of bytes of serialized data sent to the Python workers </td><td> ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapsCoGroupsInPandas, FlatMapsCoGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowsInPandas </td></tr>
+<tr><td> <code>data returned from Python workers</code> </td><td> the number of bytes of serialized data received back from the Python workers </td><td> ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapsCoGroupsInPandas, FlatMapsCoGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowsInPandas </td></tr>
</table>
## Structured Streaming Tab
diff --git a/python/pyspark/sql/tests/test_pandas_sqlmetrics.py
b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py
new file mode 100644
index 00000000000..d182bafd8b5
--- /dev/null
+++ b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py
@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from typing import cast
+
+from pyspark.sql.functions import pandas_udf
+from pyspark.testing.sqlutils import (
+ ReusedSQLTestCase,
+ have_pandas,
+ have_pyarrow,
+ pandas_requirement_message,
+ pyarrow_requirement_message,
+)
+
+
+@unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ cast(str, pandas_requirement_message or pyarrow_requirement_message),
+)
+class PandasSQLMetrics(ReusedSQLTestCase):
+ def test_pandas_sql_metrics_basic(self):
+ # SPARK-34265: Instrument Python UDFs using SQL metrics
+
+ python_sql_metrics = [
+ "data sent to Python workers",
+ "data returned from Python workers",
+ "number of output rows",
+ ]
+
+ @pandas_udf("long")
+ def test_pandas(col1):
+ return col1 * col1
+
+ self.spark.range(10).select(test_pandas("id")).collect()
+
+ statusStore = self.spark._jsparkSession.sharedState().statusStore()
+ lastExecId = statusStore.executionsList().last().executionId()
+ executionMetrics = statusStore.execution(lastExecId).get().metrics().mkString()
+
+ for metric in python_sql_metrics:
+ self.assertIn(metric, executionMetrics)
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_pandas_sqlmetrics import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index 2f85149ee8e..6a8b197742d 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -46,7 +46,7 @@ case class AggregateInPandasExec(
udfExpressions: Seq[PythonUDF],
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
- extends UnaryExecNode {
+ extends UnaryExecNode with PythonSQLMetrics {
override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -163,7 +163,8 @@ case class AggregateInPandasExec(
argOffsets,
aggInputSchema,
sessionLocalTimeZone,
- pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context)
+ pythonRunnerConf,
+ pythonMetrics).compute(projectedRowIter, context.partitionId(), context)
val joinedAttributes =
groupingExpressions.map(_.toAttribute) ++
udfExpressions.map(_.resultAttribute)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index bd8c72029dc..f3531668c8e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.metric.SQLMetric
import
org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType,
OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
import
org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
import org.apache.spark.sql.execution.streaming.GroupStateImpl
@@ -58,7 +59,8 @@ class ApplyInPandasWithStatePythonRunner(
stateEncoder: ExpressionEncoder[Row],
keySchema: StructType,
outputSchema: StructType,
- stateValueSchema: StructType)
+ stateValueSchema: StructType,
+ val pythonMetrics: Map[String, SQLMetric])
extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
with PythonArrowInput[InType]
with PythonArrowOutput[OutType] {
@@ -116,6 +118,7 @@ class ApplyInPandasWithStatePythonRunner(
val w = new ApplyInPandasWithStateWriter(root, writer,
arrowMaxRecordsPerBatch)
while (inputIterator.hasNext) {
+ val startData = dataOut.size()
val (keyRow, groupState, dataIter) = inputIterator.next()
assert(dataIter.hasNext, "should have at least one data row!")
w.startNewGroup(keyRow, groupState)
@@ -126,6 +129,8 @@ class ApplyInPandasWithStatePythonRunner(
}
w.finalizeGroup()
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
}
w.finalizeData()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 096712cf935..b11dd4947af 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -61,7 +61,7 @@ private[spark] class BatchIterator[T](iter: Iterator[T],
batchSize: Int)
*/
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs:
Seq[Attribute], child: SparkPlan,
evalType: Int)
- extends EvalPythonExec {
+ extends EvalPythonExec with PythonSQLMetrics {
private val batchSize = conf.arrowMaxRecordsPerBatch
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
@@ -85,7 +85,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF],
resultAttrs: Seq[Attribute]
argOffsets,
schema,
sessionLocalTimeZone,
- pythonRunnerConf).compute(batchIter, context.partitionId(), context)
+ pythonRunnerConf,
+ pythonMetrics).compute(batchIter, context.partitionId(), context)
columnarBatchIter.flatMap { batch =>
val actualDataTypes = (0 until batch.numCols()).map(i =>
batch.column(i).dataType())
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 8467feb91d1..dbafc444281 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
import org.apache.spark.api.python._
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -32,7 +33,8 @@ class ArrowPythonRunner(
argOffsets: Array[Array[Int]],
protected override val schema: StructType,
protected override val timeZoneId: String,
- protected override val workerConf: Map[String, String])
+ protected override val workerConf: Map[String, String],
+ val pythonMetrics: Map[String, SQLMetric])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs,
evalType, argOffsets)
with BasicPythonArrowInput
with BasicPythonArrowOutput {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 10f7966b93d..ca7ca2e2f80 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
* A physical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs:
Seq[Attribute], child: SparkPlan)
- extends EvalPythonExec {
+ extends EvalPythonExec with PythonSQLMetrics {
protected override def evaluate(
funcs: Seq[ChainedPythonFunctions],
@@ -77,7 +77,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF],
resultAttrs: Seq[Attribute]
}.grouped(100).map(x => pickle.dumps(x.toArray))
// Output iterator for results from Python.
- val outputIterator = new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
+ val outputIterator =
+ new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pythonMetrics)
.compute(inputIterator, context.partitionId(), context)
val unpickle = new Unpickler
@@ -94,6 +95,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF],
resultAttrs: Seq[Attribute]
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
+ pythonMetrics("pythonNumRowsReceived") += 1
if (udfs.length == 1) {
// fast path for single UDF
mutableRow(0) = fromJava(result)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 2661896ecec..1df9f37188a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions,
PythonRDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
@@ -45,7 +46,8 @@ class CoGroupedArrowPythonRunner(
leftSchema: StructType,
rightSchema: StructType,
timeZoneId: String,
- conf: Map[String, String])
+ conf: Map[String, String],
+ val pythonMetrics: Map[String, SQLMetric])
extends BasePythonRunner[
(Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs,
evalType, argOffsets)
with BasicPythonArrowOutput {
@@ -77,10 +79,14 @@ class CoGroupedArrowPythonRunner(
// For each we first send the number of dataframes in each group then
send
// first df, then send second df. End of data is marked by sending 0.
while (inputIterator.hasNext) {
+ val startData = dataOut.size()
dataOut.writeInt(2)
val (nextLeft, nextRight) = inputIterator.next()
writeGroup(nextLeft, leftSchema, dataOut, "left")
writeGroup(nextRight, rightSchema, dataOut, "right")
+
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
}
dataOut.writeInt(0)
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index b39787b12a4..629df51e18a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -54,7 +54,7 @@ case class FlatMapCoGroupsInPandasExec(
output: Seq[Attribute],
left: SparkPlan,
right: SparkPlan)
- extends SparkPlan with BinaryExecNode {
+ extends SparkPlan with BinaryExecNode with PythonSQLMetrics {
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
@@ -77,7 +77,6 @@ case class FlatMapCoGroupsInPandasExec(
}
override protected def doExecute(): RDD[InternalRow] = {
-
val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup)
val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output,
rightGroup)
@@ -97,7 +96,8 @@ case class FlatMapCoGroupsInPandasExec(
StructType.fromAttributes(leftDedup),
StructType.fromAttributes(rightDedup),
sessionLocalTimeZone,
- pythonRunnerConf)
+ pythonRunnerConf,
+ pythonMetrics)
executePython(data, output, runner)
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index f0e815e966e..271ccdb6b27 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -50,7 +50,7 @@ case class FlatMapGroupsInPandasExec(
func: Expression,
output: Seq[Attribute],
child: SparkPlan)
- extends SparkPlan with UnaryExecNode {
+ extends SparkPlan with UnaryExecNode with PythonSQLMetrics {
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
@@ -89,7 +89,8 @@ case class FlatMapGroupsInPandasExec(
Array(argOffsets),
StructType.fromAttributes(dedupAttributes),
sessionLocalTimeZone,
- pythonRunnerConf)
+ pythonRunnerConf,
+ pythonMetrics)
executePython(data, output, runner)
}}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index 09123344c2e..3b096f07241 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -62,7 +62,8 @@ case class FlatMapGroupsInPandasWithStateExec(
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
eventTimeWatermark: Option[Long],
- child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {
+ child: SparkPlan)
+ extends UnaryExecNode with PythonSQLMetrics with FlatMapGroupsWithStateExecBase {
// TODO(SPARK-40444): Add the support of initial state.
override protected val initialStateDeserializer: Expression = null
@@ -166,7 +167,8 @@ case class FlatMapGroupsInPandasWithStateExec(
stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
groupingAttributes.toStructType,
outAttributes.toStructType,
- stateType)
+ stateType,
+ pythonMetrics)
val context = TaskContext.get()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index d25c1383540..450891c6948 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector,
ColumnarBatch}
* This is somewhat similar with [[FlatMapGroupsInPandasExec]] and
* `org.apache.spark.sql.catalyst.plans.logical.MapPartitionsInRWithArrow`
*/
-trait MapInBatchExec extends UnaryExecNode {
+trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
protected val func: Expression
protected val pythonEvalType: Int
@@ -75,7 +75,8 @@ trait MapInBatchExec extends UnaryExecNode {
argOffsets,
StructType(StructField("struct", outputTypes) :: Nil),
sessionLocalTimeZone,
- pythonRunnerConf).compute(batchIter, context.partitionId(), context)
+ pythonRunnerConf,
+ pythonMetrics).compute(batchIter, context.partitionId(), context)
val unsafeProj = UnsafeProjection.create(output, output)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index bf66791183e..5a0541d11cb 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, PythonRDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils
@@ -41,6 +42,8 @@ private[python] trait PythonArrowInput[IN] { self:
BasePythonRunner[IN, _] =>
protected val timeZoneId: String
+ protected def pythonMetrics: Map[String, SQLMetric]
+
protected def writeIteratorToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
@@ -115,6 +118,7 @@ private[python] trait BasicPythonArrowInput extends
PythonArrowInput[Iterator[In
val arrowWriter = ArrowWriter.create(root)
while (inputIterator.hasNext) {
+ val startData = dataOut.size()
val nextBatch = inputIterator.next()
while (nextBatch.hasNext) {
@@ -124,6 +128,8 @@ private[python] trait BasicPythonArrowInput extends
PythonArrowInput[Iterator[In
arrowWriter.finish()
writer.writeBatch()
arrowWriter.reset()
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
}
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index 339f114539c..c12c690f776 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -27,6 +27,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths}
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch,
ColumnVector}
@@ -37,6 +38,8 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector,
ColumnarBatch, Column
*/
private[python] trait PythonArrowOutput[OUT <: AnyRef] { self:
BasePythonRunner[_, OUT] =>
+ protected def pythonMetrics: Map[String, SQLMetric]
+
protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
protected def deserializeColumnarBatch(batch: ColumnarBatch, schema:
StructType): OUT
@@ -82,10 +85,15 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] {
self: BasePythonRunner[
}
try {
if (reader != null && batchLoaded) {
+ val bytesReadStart = reader.bytesRead()
batchLoaded = reader.loadNextBatch()
if (batchLoaded) {
val batch = new ColumnarBatch(vectors)
+ val rowCount = root.getRowCount
batch.setNumRows(root.getRowCount)
+ val bytesReadEnd = reader.bytesRead()
+ pythonMetrics("pythonNumRowsReceived") += rowCount
+ pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart
deserializeColumnarBatch(batch, schema)
} else {
reader.close(false)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
new file mode 100644
index 00000000000..a748c1bc100
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+private[sql] trait PythonSQLMetrics { self: SparkPlan =>
+
+ val pythonMetrics = Map(
+ "pythonDataSent" -> SQLMetrics.createSizeMetric(sparkContext,
+ "data sent to Python workers"),
+ "pythonDataReceived" -> SQLMetrics.createSizeMetric(sparkContext,
+ "data returned from Python workers"),
+ "pythonNumRowsReceived" -> SQLMetrics.createMetric(sparkContext,
+ "number of output rows")
+ )
+
+ override lazy val metrics = pythonMetrics
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index d1109d251c2..09e06b55df3 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import org.apache.spark._
import org.apache.spark.api.python._
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
/**
@@ -31,7 +32,8 @@ import org.apache.spark.sql.internal.SQLConf
class PythonUDFRunner(
funcs: Seq[ChainedPythonFunctions],
evalType: Int,
- argOffsets: Array[Array[Int]])
+ argOffsets: Array[Array[Int]],
+ pythonMetrics: Map[String, SQLMetric])
extends BasePythonRunner[Array[Byte], Array[Byte]](
funcs, evalType, argOffsets) {
@@ -50,8 +52,13 @@ class PythonUDFRunner(
}
protected override def writeIteratorToStream(dataOut: DataOutputStream):
Unit = {
+ val startData = dataOut.size()
+
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
}
}
}
@@ -77,6 +84,7 @@ class PythonUDFRunner(
case length if length > 0 =>
val obj = new Array[Byte](length)
stream.readFully(obj)
+ pythonMetrics("pythonDataReceived") += length
obj
case 0 => Array.emptyByteArray
case SpecialLengths.TIMING_DATA =>
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index ccb1ed92525..dcaffed89cc 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -84,7 +84,7 @@ case class WindowInPandasExec(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan)
- extends WindowExecBase {
+ extends WindowExecBase with PythonSQLMetrics {
/**
* Helper functions and data structures for window bounds
@@ -375,7 +375,8 @@ case class WindowInPandasExec(
argOffsets,
pythonInputSchema,
sessionLocalTimeZone,
- pythonRunnerConf).compute(pythonInput, context.partitionId(), context)
+ pythonRunnerConf,
+ pythonMetrics).compute(pythonInput, context.partitionId(), context)
val joined = new JoinedRow
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 2b8fc651561..b540f9f0093 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -34,6 +34,7 @@ import
org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.python.PythonSQLMetrics
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
import org.apache.spark.sql.types._
@@ -93,7 +94,7 @@ trait StateStoreReader extends StatefulOperator {
}
/** An operator that writes to a StateStore. */
-trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
+trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: SparkPlan =>
override lazy val metrics = statefulOperatorCustomMetrics ++ Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output
rows"),
@@ -109,7 +110,7 @@ trait StateStoreWriter extends StatefulOperator { self:
SparkPlan =>
"numShufflePartitions" -> SQLMetrics.createMetric(sparkContext, "number of
shuffle partitions"),
"numStateStoreInstances" -> SQLMetrics.createMetric(sparkContext,
"number of state store instances")
- ) ++ stateStoreCustomMetrics
+ ) ++ stateStoreCustomMetrics ++ pythonMetrics
/**
* Get the progress made by this stateful operator after execution. This
should be called in
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
index 70784c20a8e..7850b2d79b0 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
@@ -84,4 +84,23 @@ class PythonUDFSuite extends QueryTest with
SharedSparkSession {
checkAnswer(actual, expected)
}
+
+ test("SPARK-34265: Instrument Python UDF execution using SQL Metrics") {
+
+ val pythonSQLMetrics = List(
+ "data sent to Python workers",
+ "data returned from Python workers",
+ "number of output rows")
+
+ val df = base.groupBy(pythonTestUDF(base("a") + 1))
+ .agg(pythonTestUDF(pythonTestUDF(base("a") + 1)))
+ df.count()
+
+ val statusStore = spark.sharedState.statusStore
+ val lastExecId = statusStore.executionsList.last.executionId
+ val executionMetrics = statusStore.execution(lastExecId).get.metrics.mkString
+ for (metric <- pythonSQLMetrics) {
+ assert(executionMetrics.contains(metric))
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]