This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 5cbad6aef2fd  [SPARK-46424][PYTHON][SQL][FOLLOW-UP] Refactor and clean up Python metric implementation
5cbad6aef2fd is described below

commit 5cbad6aef2fd1214014d64ef602f9a726a019d99
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Tue Dec 19 13:58:55 2023 +0900

    [SPARK-46424][PYTHON][SQL][FOLLOW-UP] Refactor and clean up Python metric implementation

    ### What changes were proposed in this pull request?

    This PR is a follow-up of https://github.com/apache/spark/pull/44375 that refactors and cleans up the code, as pointed out in review comments.

    ### Why are the changes needed?

    For better readability.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Existing test cases should cover this. I also manually tested as described in https://github.com/apache/spark/pull/44375.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #44406 from HyukjinKwon/SPARK-46424-followup.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../python/UserDefinedPythonDataSource.scala | 80 +++++++---------------
 1 file changed, 26 insertions(+), 54 deletions(-)
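For readers skimming the patch below, the essence of the refactor is replacing two-phase initialize() mutation with plain constructor parameters, keeping a no-arg auxiliary constructor only so the class can still be instantiated (presumably reflectively) by SQLAppStatusListener.aggregateMetrics. The following is a minimal, self-contained Scala sketch of that before/after pattern, not the actual Spark code: the CustomMetric/CustomTaskMetric traits here are simplified stand-ins for Spark's DSv2 metric interfaces, and the metric name and description are hypothetical.

// Simplified stand-ins for Spark's CustomMetric / CustomTaskMetric interfaces.
trait CustomMetric {
  def name(): String
  def description(): String
  def aggregateTaskMetrics(taskMetrics: Array[Long]): String
}

trait CustomTaskMetric {
  def name(): String
  def value(): Long
}

// Before: a no-arg class mutated after construction; every accessor must
// assert that initialize() has already run.
class MutableMetric extends CustomMetric {
  private var initName: String = _
  private var initDescription: String = _
  def initialize(n: String, d: String): Unit = {
    initName = n
    initDescription = d
  }
  override def name(): String = { assert(initName != null); initName }
  override def description(): String = { assert(initDescription != null); initDescription }
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    taskMetrics.sum.toString  // placeholder aggregation for the sketch
}

// After: constructor parameters make the instance valid from construction.
// The no-arg auxiliary constructor remains only for the reflective caller
// the commit cites (SQLAppStatusListener.aggregateMetrics).
class ImmutableMetric(
    metricName: String,
    metricDescription: String) extends CustomMetric {
  def this() = this(null, null)
  override def name(): String = metricName
  override def description(): String = metricDescription
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    taskMetrics.sum.toString  // placeholder aggregation for the sketch
}

object MetricSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical metric descriptions, standing in for
    // PythonSQLMetrics.pythonSizeMetricsDesc in the real code.
    val descs = Map("pythonDataSent" -> "bytes sent to the Python worker")

    // Before-style construction needs three statements per metric...
    val before: Array[CustomMetric] = descs.map { case (k, v) =>
      val m = new MutableMetric()
      m.initialize(k, v)
      m
    }.toArray

    // ...while after-style construction collapses to one expression.
    val after: Array[CustomMetric] =
      descs.map { case (k, v) => new ImmutableMetric(k, v) }.toArray

    (before ++ after).foreach(m => println(s"${m.name()}: ${m.description()}"))
  }
}

The after-style also removes the window in which name() or description() could be observed before initialize() ran, which is why the asserts disappear from the patch.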
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index f6a473adf08d..d31b3135d65e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -139,17 +139,12 @@ class PythonPartitionReaderFactory(
       .map(_ -> new SQLMetric("sum", -1)).toMap
   }
 
-  private val outputIter = {
-    val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
-      pickledReadFunc,
-      outputSchema,
-      metrics,
-      jobArtifactUUID)
-
-    val part = partition.asInstanceOf[PythonInputPartition]
-    evaluatorFactory.createEvaluator().eval(
-      part.index, Iterator.single(InternalRow(part.pickedPartition)))
-  }
+  private val outputIter = source.createPartitionReadIteratorInPython(
+    partition.asInstanceOf[PythonInputPartition],
+    pickledReadFunc,
+    outputSchema,
+    metrics,
+    jobArtifactUUID)
 
   override def next(): Boolean = outputIter.hasNext
 
@@ -164,41 +159,20 @@ class PythonPartitionReaderFactory(
   }
 }
 
-class PythonCustomMetric extends CustomMetric {
-  private var initName: String = _
-  private var initDescription: String = _
-  def initialize(n: String, d: String): Unit = {
-    initName = n
-    initDescription = d
-  }
-  override def name(): String = {
-    assert(initName != null)
-    initName
-  }
-  override def description(): String = {
-    assert(initDescription != null)
-    initDescription
-  }
+class PythonCustomMetric(
+    override val name: String,
+    override val description: String) extends CustomMetric {
+  // To allow the aggregation to be called. See `SQLAppStatusListener.aggregateMetrics`.
+  def this() = this(null, null)
+
   override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
     SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
   }
 }
 
-class PythonCustomTaskMetric extends CustomTaskMetric {
-  private var initName: String = _
-  private var initValue: Long = -1L
-  def initialize(n: String, v: Long): Unit = {
-    initName = n
-    initValue = v
-  }
-  override def name(): String = {
-    assert(initName != null)
-    initName
-  }
-  override def value(): Long = {
-    initValue
-  }
-}
+class PythonCustomTaskMetric(
+    override val name: String,
+    override val value: Long) extends CustomTaskMetric
 
 /**
  * A user-defined Python data source. This is used by the Python API.
@@ -240,11 +214,12 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
   /**
    * (Executor-side) Create an iterator that reads the input partitions.
    */
-  def createMapInBatchEvaluatorFactory(
+  def createPartitionReadIteratorInPython(
+      partition: PythonInputPartition,
       pickledReadFunc: Array[Byte],
       outputSchema: StructType,
       metrics: Map[String, SQLMetric],
-      jobArtifactUUID: Option[String]): MapInBatchEvaluatorFactory = {
+      jobArtifactUUID: Option[String]): Iterator[InternalRow] = {
     val readerFunc = createPythonFunction(pickledReadFunc)
 
     val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
@@ -260,7 +235,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
     val conf = SQLConf.get
     val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
 
-    new MapInBatchEvaluatorFactory(
+    val evaluatorFactory = new MapInBatchEvaluatorFactory(
       toAttributes(outputSchema),
       Seq(ChainedPythonFunctions(Seq(pythonUDF.func))),
       inputSchema,
@@ -271,24 +246,21 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       pythonRunnerConf,
       metrics,
       jobArtifactUUID)
+
+    val part = partition
+    evaluatorFactory.createEvaluator().eval(
+      part.index, Iterator.single(InternalRow(part.pickedPartition)))
   }
 
   def createPythonMetrics(): Array[CustomMetric] = {
     // Do not add other metrics such as number of rows,
     // that is already included via DSv2.
-    PythonSQLMetrics.pythonSizeMetricsDesc.map { case (k, v) =>
-      val m = new PythonCustomMetric()
-      m.initialize(k, v)
-      m
-    }.toArray
+    PythonSQLMetrics.pythonSizeMetricsDesc
+      .map { case (k, v) => new PythonCustomMetric(k, v)}.toArray
   }
 
   def createPythonTaskMetrics(taskMetrics: Map[String, Long]): Array[CustomTaskMetric] = {
-    taskMetrics.map { case (k, v) =>
-      val m = new PythonCustomTaskMetric()
-      m.initialize(k, v)
-      m
-    }.toArray
+    taskMetrics.map { case (k, v) => new PythonCustomTaskMetric(k, v)}.toArray
   }
 
   private def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction = {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org