This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7004f9ed94be [SPARK-46424][PYTHON][SQL] Support Python metrics in
Python Data Source
7004f9ed94be is described below
commit 7004f9ed94be29e1fa5ddc398a3786c50562742a
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Sat Dec 16 08:14:11 2023 -0800
[SPARK-46424][PYTHON][SQL] Support Python metrics in Python Data Source
### What changes were proposed in this pull request?
This PR proposes to support Python metrics in Python Data Source so that the
metrics are reported the same way as in other Python execution paths and APIs.
### Why are the changes needed?
The same metrics (https://github.com/apache/spark/pull/33559) should be shown when
reading from a Python Data Source. This is the last missing piece compared to other
Python execution paths and APIs.
### Does this PR introduce _any_ user-facing change?
Python Data Source has not been released yet, so there is no end-user-facing change.
It shows some new metrics in the UI.
Example:
```python
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition

class TestDataSourceReader(DataSourceReader):
    def __init__(self, options):
        self.options = options

    def partitions(self):
        return [InputPartition(i) for i in range(3)]

    def read(self, partition):
        yield partition.value, str(partition.value)

class TestDataSource(DataSource):
    @classmethod
    def name(cls):
        return "test"

    def schema(self):
        return "x INT, y STRING"

    def reader(self, schema) -> "DataSourceReader":
        return TestDataSourceReader(self.options)

spark.dataSource.register(TestDataSource)
sql("CREATE TABLE tblA USING test")
sql("SELECT * from tblA").show()
```
<img width="515" alt="Screenshot 2023-12-15 at 5 54 55 PM"
src="https://github.com/apache/spark/assets/6477701/5b98af8c-798e-4b9f-9fde-5549ad8b3c65">
This is the same as for other Python nodes, UDFs, etc.
### How was this patch tested?
Unit tests were added, and the change was manually tested via the UI.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44375 from HyukjinKwon/SPARK-46424.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../ApplyInPandasWithStatePythonRunner.scala | 6 +-
.../execution/python/ArrowEvalPythonUDTFExec.scala | 2 +-
.../sql/execution/python/ArrowPythonRunner.scala | 6 +-
.../execution/python/ArrowPythonUDTFRunner.scala | 2 +-
.../python/CoGroupedArrowPythonRunner.scala | 6 +-
.../python/FlatMapGroupsInBatchExec.scala | 2 +-
.../python/MapInBatchEvaluatorFactory.scala | 2 +-
.../sql/execution/python/MapInBatchExec.scala | 2 +-
.../sql/execution/python/PythonArrowInput.scala | 4 +-
.../sql/execution/python/PythonArrowOutput.scala | 6 +-
.../sql/execution/python/PythonSQLMetrics.scala | 33 ++++---
.../python/UserDefinedPythonDataSource.scala | 102 ++++++++++++++++++---
.../execution/python/PythonDataSourceSuite.scala | 59 +++++++++++-
13 files changed, 186 insertions(+), 46 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index 936ab866f5bf..8795374b2a72 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -61,14 +61,12 @@ class ApplyInPandasWithStatePythonRunner(
keySchema: StructType,
outputSchema: StructType,
stateValueSchema: StructType,
- pyMetrics: Map[String, SQLMetric],
+ override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String])
extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets,
jobArtifactUUID)
with PythonArrowInput[InType]
with PythonArrowOutput[OutType] {
- override val pythonMetrics: Option[Map[String, SQLMetric]] = Some(pyMetrics)
-
override val pythonExec: String =
SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
funcs.head.funcs.head.pythonExec)
@@ -151,7 +149,7 @@ class ApplyInPandasWithStatePythonRunner(
pandasWriter.finalizeGroup()
val deltaData = dataOut.size() - startData
- pythonMetrics.foreach(_("pythonDataSent") += deltaData)
+ pythonMetrics("pythonDataSent") += deltaData
true
} else {
pandasWriter.finalizeData()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 2503deae7d5a..9e210bf5241b 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -70,7 +70,7 @@ case class ArrowEvalPythonUDTFExec(
sessionLocalTimeZone,
largeVarTypes,
pythonRunnerConf,
- Some(pythonMetrics),
+ pythonMetrics,
jobArtifactUUID).compute(batchIter, context.partitionId(), context)
columnarBatchIter.map { batch =>
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 5dcb79cc2b91..33933b64bbaf 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -35,7 +35,7 @@ abstract class BaseArrowPythonRunner(
_timeZoneId: String,
protected override val largeVarTypes: Boolean,
protected override val workerConf: Map[String, String],
- override val pythonMetrics: Option[Map[String, SQLMetric]],
+ override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
funcs, evalType, argOffsets, jobArtifactUUID)
@@ -74,7 +74,7 @@ class ArrowPythonRunner(
_timeZoneId: String,
largeVarTypes: Boolean,
workerConf: Map[String, String],
- pythonMetrics: Option[Map[String, SQLMetric]],
+ pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String])
extends BaseArrowPythonRunner(
funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
workerConf,
@@ -100,7 +100,7 @@ class ArrowPythonWithNamedArgumentRunner(
jobArtifactUUID: Option[String])
extends BaseArrowPythonRunner(
funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId,
largeVarTypes, workerConf,
- Some(pythonMetrics), jobArtifactUUID) {
+ pythonMetrics, jobArtifactUUID) {
override protected def writeUDF(dataOut: DataOutputStream): Unit =
PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index df2e89128124..f52b01b6646a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -39,7 +39,7 @@ class ArrowPythonUDTFRunner(
protected override val timeZoneId: String,
protected override val largeVarTypes: Boolean,
protected override val workerConf: Map[String, String],
- override val pythonMetrics: Option[Map[String, SQLMetric]],
+ override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType,
Array(argMetas.map(_.offset)),
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 70bd1ce82e2e..7e1c8c2ffc13 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -46,15 +46,13 @@ class CoGroupedArrowPythonRunner(
rightSchema: StructType,
timeZoneId: String,
conf: Map[String, String],
- pyMetrics: Map[String, SQLMetric],
+ override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String])
extends BasePythonRunner[
(Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](
funcs, evalType, argOffsets, jobArtifactUUID)
with BasicPythonArrowOutput {
- override val pythonMetrics: Option[Map[String, SQLMetric]] = Some(pyMetrics)
-
override val pythonExec: String =
SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
funcs.head.funcs.head.pythonExec)
@@ -95,7 +93,7 @@ class CoGroupedArrowPythonRunner(
writeGroup(nextRight, rightSchema, dataOut, "right")
val deltaData = dataOut.size() - startData
- pythonMetrics.foreach(_("pythonDataSent") += deltaData)
+ pythonMetrics("pythonDataSent") += deltaData
true
} else {
dataOut.writeInt(0)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
index 5550ddf72a14..facf7bc49c5a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala
@@ -88,7 +88,7 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with
UnaryExecNode with PythonS
sessionLocalTimeZone,
largeVarTypes,
pythonRunnerConf,
- Some(pythonMetrics),
+ pythonMetrics,
jobArtifactUUID)
executePython(data, output, runner)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 00990ee46ea5..29dc6e0aa541 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -36,7 +36,7 @@ class MapInBatchEvaluatorFactory(
sessionLocalTimeZone: String,
largeVarTypes: Boolean,
pythonRunnerConf: Map[String, String],
- pythonMetrics: Option[Map[String, SQLMetric]],
+ val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String])
extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 6db6c96b426a..8db389f02667 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -57,7 +57,7 @@ trait MapInBatchExec extends UnaryExecNode with
PythonSQLMetrics {
conf.sessionLocalTimeZone,
conf.arrowUseLargeVarTypes,
pythonRunnerConf,
- Some(pythonMetrics),
+ pythonMetrics,
jobArtifactUUID)
if (isBarrier) {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 6d0f31f35ff7..1e075cab9224 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -46,7 +46,7 @@ private[python] trait PythonArrowInput[IN] { self:
BasePythonRunner[IN, _] =>
protected val largeVarTypes: Boolean
- protected def pythonMetrics: Option[Map[String, SQLMetric]]
+ protected def pythonMetrics: Map[String, SQLMetric]
protected def writeNextInputToArrowStream(
root: VectorSchemaRoot,
@@ -132,7 +132,7 @@ private[python] trait BasicPythonArrowInput extends
PythonArrowInput[Iterator[In
writer.writeBatch()
arrowWriter.reset()
val deltaData = dataOut.size() - startData
- pythonMetrics.foreach(_("pythonDataSent") += deltaData)
+ pythonMetrics("pythonDataSent") += deltaData
true
} else {
super[PythonArrowInput].close()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index 82e8e7aa4f64..90922d89ad10 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector,
ColumnarBatch, Column
*/
private[python] trait PythonArrowOutput[OUT <: AnyRef] { self:
BasePythonRunner[_, OUT] =>
- protected def pythonMetrics: Option[Map[String, SQLMetric]]
+ protected def pythonMetrics: Map[String, SQLMetric]
protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
@@ -91,8 +91,8 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] {
self: BasePythonRunner[
val rowCount = root.getRowCount
batch.setNumRows(root.getRowCount)
val bytesReadEnd = reader.bytesRead()
- pythonMetrics.foreach(_("pythonNumRowsReceived") += rowCount)
- pythonMetrics.foreach(_("pythonDataReceived") += bytesReadEnd -
bytesReadStart)
+ pythonMetrics("pythonNumRowsReceived") += rowCount
+ pythonMetrics("pythonDataReceived") += bytesReadEnd -
bytesReadStart
deserializeColumnarBatch(batch, schema)
} else {
reader.close(false)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
index a748c1bc1008..4df6d821c014 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala
@@ -18,18 +18,29 @@
package org.apache.spark.sql.execution.python
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-private[sql] trait PythonSQLMetrics { self: SparkPlan =>
+trait PythonSQLMetrics { self: SparkPlan =>
+ protected val pythonMetrics: Map[String, SQLMetric] = {
+ PythonSQLMetrics.pythonSizeMetricsDesc.map { case (k, v) =>
+ k -> SQLMetrics.createSizeMetric(sparkContext, v)
+ } ++ PythonSQLMetrics.pythonOtherMetricsDesc.map { case (k, v) =>
+ k -> SQLMetrics.createMetric(sparkContext, v)
+ }
+ }
- val pythonMetrics = Map(
- "pythonDataSent" -> SQLMetrics.createSizeMetric(sparkContext,
- "data sent to Python workers"),
- "pythonDataReceived" -> SQLMetrics.createSizeMetric(sparkContext,
- "data returned from Python workers"),
- "pythonNumRowsReceived" -> SQLMetrics.createMetric(sparkContext,
- "number of output rows")
- )
+ override lazy val metrics: Map[String, SQLMetric] = pythonMetrics
+}
+
+object PythonSQLMetrics {
+ val pythonSizeMetricsDesc: Map[String, String] = {
+ Map(
+ "pythonDataSent" -> "data sent to Python workers",
+ "pythonDataReceived" -> "data returned from Python workers"
+ )
+ }
- override lazy val metrics = pythonMetrics
+ val pythonOtherMetricsDesc: Map[String, String] = {
+ Map("pythonNumRowsReceived" -> "number of output rows")
+ }
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 047a133a322a..f6a473adf08d 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -34,8 +34,10 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table,
TableCapability, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ
import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.read.{Batch, InputPartition,
PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BinaryType, DataType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -101,8 +103,10 @@ class PythonTableProvider extends TableProvider {
new PythonPartitionReaderFactory(
source, readerFunc, outputSchema, jobArtifactUUID)
}
-
override def description: String = "(Python)"
+
+ override def supportedCustomMetrics(): Array[CustomMetric] =
+ source.createPythonMetrics()
}
}
@@ -124,21 +128,78 @@ class PythonPartitionReaderFactory(
override def createReader(partition: InputPartition):
PartitionReader[InternalRow] = {
new PartitionReader[InternalRow] {
- private val outputIter = source.createPartitionReadIteratorInPython(
- partition.asInstanceOf[PythonInputPartition],
- pickledReadFunc,
- outputSchema,
- jobArtifactUUID)
+ // Dummy SQLMetrics. The result is manually reported via DSv2 interface
+ // via passing the value to `CustomTaskMetric`. Note that
`pythonOtherMetricsDesc`
+ // is not used when it is reported. It is to reuse existing Python
runner.
+ // See also `UserDefinedPythonDataSource.createPythonMetrics`.
+ private[this] val metrics: Map[String, SQLMetric] = {
+ PythonSQLMetrics.pythonSizeMetricsDesc.keys
+ .map(_ -> new SQLMetric("size", -1)).toMap ++
+ PythonSQLMetrics.pythonOtherMetricsDesc.keys
+ .map(_ -> new SQLMetric("sum", -1)).toMap
+ }
+
+ private val outputIter = {
+ val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+ pickledReadFunc,
+ outputSchema,
+ metrics,
+ jobArtifactUUID)
+
+ val part = partition.asInstanceOf[PythonInputPartition]
+ evaluatorFactory.createEvaluator().eval(
+ part.index, Iterator.single(InternalRow(part.pickedPartition)))
+ }
override def next(): Boolean = outputIter.hasNext
override def get(): InternalRow = outputIter.next()
override def close(): Unit = {}
+
+ override def currentMetricsValues(): Array[CustomTaskMetric] = {
+ source.createPythonTaskMetrics(metrics.map { case (k, v) => k ->
v.value})
+ }
}
}
}
+class PythonCustomMetric extends CustomMetric {
+ private var initName: String = _
+ private var initDescription: String = _
+ def initialize(n: String, d: String): Unit = {
+ initName = n
+ initDescription = d
+ }
+ override def name(): String = {
+ assert(initName != null)
+ initName
+ }
+ override def description(): String = {
+ assert(initDescription != null)
+ initDescription
+ }
+ override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
+ SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
+ }
+}
+
+class PythonCustomTaskMetric extends CustomTaskMetric {
+ private var initName: String = _
+ private var initValue: Long = -1L
+ def initialize(n: String, v: Long): Unit = {
+ initName = n
+ initValue = v
+ }
+ override def name(): String = {
+ assert(initName != null)
+ initName
+ }
+ override def value(): Long = {
+ initValue
+ }
+}
+
/**
* A user-defined Python data source. This is used by the Python API.
* Defines the interation between Python and JVM.
@@ -179,11 +240,11 @@ case class UserDefinedPythonDataSource(dataSourceCls:
PythonFunction) {
/**
* (Executor-side) Create an iterator that reads the input partitions.
*/
- def createPartitionReadIteratorInPython(
- partition: PythonInputPartition,
+ def createMapInBatchEvaluatorFactory(
pickledReadFunc: Array[Byte],
outputSchema: StructType,
- jobArtifactUUID: Option[String]): Iterator[InternalRow] = {
+ metrics: Map[String, SQLMetric],
+ jobArtifactUUID: Option[String]): MapInBatchEvaluatorFactory = {
val readerFunc = createPythonFunction(pickledReadFunc)
val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
@@ -199,7 +260,7 @@ case class UserDefinedPythonDataSource(dataSourceCls:
PythonFunction) {
val conf = SQLConf.get
val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
- val evaluatorFactory = new MapInBatchEvaluatorFactory(
+ new MapInBatchEvaluatorFactory(
toAttributes(outputSchema),
Seq(ChainedPythonFunctions(Seq(pythonUDF.func))),
inputSchema,
@@ -208,11 +269,26 @@ case class UserDefinedPythonDataSource(dataSourceCls:
PythonFunction) {
conf.sessionLocalTimeZone,
conf.arrowUseLargeVarTypes,
pythonRunnerConf,
- None,
+ metrics,
jobArtifactUUID)
+ }
+
+ def createPythonMetrics(): Array[CustomMetric] = {
+ // Do not add other metrics such as number of rows,
+ // that is already included via DSv2.
+ PythonSQLMetrics.pythonSizeMetricsDesc.map { case (k, v) =>
+ val m = new PythonCustomMetric()
+ m.initialize(k, v)
+ m
+ }.toArray
+ }
- evaluatorFactory.createEvaluator().eval(
- partition.index, Iterator.single(InternalRow(partition.pickedPartition)))
+ def createPythonTaskMetrics(taskMetrics: Map[String, Long]):
Array[CustomTaskMetric] = {
+ taskMetrics.map { case (k, v) =>
+ val m = new PythonCustomTaskMetric()
+ m.initialize(k, v)
+ m
+ }.toArray
}
private def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction =
{
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 53a54abf8392..e8a46449ac20 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.python
import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils,
QueryTest, Row}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec,
DataSourceV2ScanRelation}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
@@ -396,4 +396,61 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
assert(err.getMessage.contains("PYTHON_DATA_SOURCE_CREATE_ERROR"))
}
}
+
+ test("SPARK-46424: Support Python metrics") {
+ assume(shouldTestPandasUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource, DataSourceReader
+ |class SimpleDataSourceReader(DataSourceReader):
+ | def partitions(self):
+ | return []
+ |
+ | def read(self, partition):
+ | if partition is None:
+ | yield ("success", )
+ | else:
+ | yield ("failed", )
+ |
+ |class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return "status STRING"
+ |
+ | def reader(self, schema):
+ | return SimpleDataSourceReader()
+ |""".stripMargin
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName,
dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ val df = spark.read.format(dataSourceName).load()
+
+ val statusStore = spark.sharedState.statusStore
+ val oldCount = statusStore.executionsList().size
+
+ df.collect()
+
+ // Wait until the new execution is started and being tracked.
+ while (statusStore.executionsCount() < oldCount) {
+ Thread.sleep(100)
+ }
+
+ // Wait for listener to finish computing the metrics for the execution.
+ while (statusStore.executionsList().isEmpty ||
+ statusStore.executionsList().last.metricValues == null) {
+ Thread.sleep(100)
+ }
+
+ val executedPlan = df.queryExecution.executedPlan.collectFirst {
+ case p: BatchScanExec => p
+ }
+ assert(executedPlan.isDefined)
+
+ val execId = statusStore.executionsList().last.executionId
+ val metrics = statusStore.executionMetrics(execId)
+ val pythonDataSent = executedPlan.get.metrics("pythonDataSent")
+ val pythonDataReceived = executedPlan.get.metrics("pythonDataReceived")
+ assert(metrics.contains(pythonDataSent.id))
+ assert(metrics(pythonDataSent.id).asInstanceOf[String].endsWith("B"))
+ assert(metrics.contains(pythonDataReceived.id))
+ assert(metrics(pythonDataReceived.id).asInstanceOf[String].endsWith("B"))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]