This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5cbad6aef2fd [SPARK-46424][PYTHON][SQL][FOLLOW-UP] Refactor and clean up Python metric implementation
5cbad6aef2fd is described below
commit 5cbad6aef2fd1214014d64ef602f9a726a019d99
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Dec 19 13:58:55 2023 +0900
[SPARK-46424][PYTHON][SQL][FOLLOW-UP] Refactor and clean up Python metric implementation
### What changes were proposed in this pull request?
This PR is a follow-up of https://github.com/apache/spark/pull/44375 that refactors and cleans up the code (as pointed out in review comments).
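Concretely, the main cleanup replaces the mutable, `initialize`-based metric classes with plain constructor parameters. A minimal sketch of the resulting shape, assuming the standard connector/metric imports (the no-arg auxiliary constructor exists only so `SQLAppStatusListener.aggregateMetrics` can instantiate the class reflectively):

```scala
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.execution.metric.SQLMetrics

class PythonCustomMetric(
    override val name: String,
    override val description: String) extends CustomMetric {
  // No-arg constructor so SQLAppStatusListener.aggregateMetrics can create an instance reflectively.
  def this() = this(null, null)

  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
}

// Task-side values are now carried as immutable (name, value) pairs.
class PythonCustomTaskMetric(
    override val name: String,
    override val value: Long) extends CustomTaskMetric
```

The same cleanup removes the `createMapInBatchEvaluatorFactory` indirection: constructing the evaluator factory and evaluating a single partition are folded into one helper, `createPartitionReadIteratorInPython`, as shown in the diff below.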
### Why are the changes needed?
For better readability.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing test cases should cover this. I also manually tested it as described in https://github.com/apache/spark/pull/44375.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44406 from HyukjinKwon/SPARK-46424-followup.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../python/UserDefinedPythonDataSource.scala | 80 +++++++---------------
1 file changed, 26 insertions(+), 54 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index f6a473adf08d..d31b3135d65e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -139,17 +139,12 @@ class PythonPartitionReaderFactory(
.map(_ -> new SQLMetric("sum", -1)).toMap
}
- private val outputIter = {
- val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
- pickledReadFunc,
- outputSchema,
- metrics,
- jobArtifactUUID)
-
- val part = partition.asInstanceOf[PythonInputPartition]
- evaluatorFactory.createEvaluator().eval(
- part.index, Iterator.single(InternalRow(part.pickedPartition)))
- }
+ private val outputIter = source.createPartitionReadIteratorInPython(
+ partition.asInstanceOf[PythonInputPartition],
+ pickledReadFunc,
+ outputSchema,
+ metrics,
+ jobArtifactUUID)
override def next(): Boolean = outputIter.hasNext
@@ -164,41 +159,20 @@ class PythonPartitionReaderFactory(
}
}
-class PythonCustomMetric extends CustomMetric {
- private var initName: String = _
- private var initDescription: String = _
- def initialize(n: String, d: String): Unit = {
- initName = n
- initDescription = d
- }
- override def name(): String = {
- assert(initName != null)
- initName
- }
- override def description(): String = {
- assert(initDescription != null)
- initDescription
- }
+class PythonCustomMetric(
+ override val name: String,
+ override val description: String) extends CustomMetric {
+ // To allow the aggregation to be called. See `SQLAppStatusListener.aggregateMetrics`.
+ def this() = this(null, null)
+
override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
}
}
-class PythonCustomTaskMetric extends CustomTaskMetric {
- private var initName: String = _
- private var initValue: Long = -1L
- def initialize(n: String, v: Long): Unit = {
- initName = n
- initValue = v
- }
- override def name(): String = {
- assert(initName != null)
- initName
- }
- override def value(): Long = {
- initValue
- }
-}
+class PythonCustomTaskMetric(
+ override val name: String,
+ override val value: Long) extends CustomTaskMetric
/**
* A user-defined Python data source. This is used by the Python API.
@@ -240,11 +214,12 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
/**
* (Executor-side) Create an iterator that reads the input partitions.
*/
- def createMapInBatchEvaluatorFactory(
+ def createPartitionReadIteratorInPython(
+ partition: PythonInputPartition,
pickledReadFunc: Array[Byte],
outputSchema: StructType,
metrics: Map[String, SQLMetric],
- jobArtifactUUID: Option[String]): MapInBatchEvaluatorFactory = {
+ jobArtifactUUID: Option[String]): Iterator[InternalRow] = {
val readerFunc = createPythonFunction(pickledReadFunc)
val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
@@ -260,7 +235,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
val conf = SQLConf.get
val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
- new MapInBatchEvaluatorFactory(
+ val evaluatorFactory = new MapInBatchEvaluatorFactory(
toAttributes(outputSchema),
Seq(ChainedPythonFunctions(Seq(pythonUDF.func))),
inputSchema,
@@ -271,24 +246,21 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
pythonRunnerConf,
metrics,
jobArtifactUUID)
+
+ val part = partition
+ evaluatorFactory.createEvaluator().eval(
+ part.index, Iterator.single(InternalRow(part.pickedPartition)))
}
def createPythonMetrics(): Array[CustomMetric] = {
// Do not add other metrics such as number of rows,
// that is already included via DSv2.
- PythonSQLMetrics.pythonSizeMetricsDesc.map { case (k, v) =>
- val m = new PythonCustomMetric()
- m.initialize(k, v)
- m
- }.toArray
+ PythonSQLMetrics.pythonSizeMetricsDesc
+ .map { case (k, v) => new PythonCustomMetric(k, v)}.toArray
}
def createPythonTaskMetrics(taskMetrics: Map[String, Long]): Array[CustomTaskMetric] = {
- taskMetrics.map { case (k, v) =>
- val m = new PythonCustomTaskMetric()
- m.initialize(k, v)
- m
- }.toArray
+ taskMetrics.map { case (k, v) => new PythonCustomTaskMetric(k, v)}.toArray
}
private def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction = {
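Usage note (not part of this diff): on the executor side, a DSv2 `PartitionReader` reports per-task metric values through `currentMetricsValues()`. A hedged sketch of how the constructor-based task metrics above could be surfaced from the reader's `metrics: Map[String, SQLMetric]` field shown in the first hunk; the exact wiring in the surrounding file may differ:

```scala
// Sketch only: assumes the reader holds `metrics: Map[String, SQLMetric]` and a
// reference to the `UserDefinedPythonDataSource` (`source`), as in the hunks above.
override def currentMetricsValues(): Array[CustomTaskMetric] = {
  source.createPythonTaskMetrics(metrics.map { case (k, m) => k -> m.value })
}
```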
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]