This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5cbad6aef2fd [SPARK-46424][PYTHON][SQL][FOLLOW-UP] Refactor and clean up Python metric implementation
5cbad6aef2fd is described below

commit 5cbad6aef2fd1214014d64ef602f9a726a019d99
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Tue Dec 19 13:58:55 2023 +0900

    [SPARK-46424][PYTHON][SQL][FOLLOW-UP] Refactor and clean up Python metric implementation
    
    ### What changes were proposed in this pull request?
    
    This PR is a follow-up of https://github.com/apache/spark/pull/44375 that refactors and cleans up the code, as pointed out in review comments.
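
    As an illustrative sketch (the class and metric names below are made up, not the actual Spark identifiers), the refactor replaces mutable "initialize after construction" classes with plain constructor parameters:

    ```scala
    // Before: fields are vars populated via initialize(), and the accessors assert on null.
    class MutableMetric {
      private var initName: String = _
      private var initValue: Long = -1L
      def initialize(n: String, v: Long): Unit = {
        initName = n
        initValue = v
      }
      def name(): String = { assert(initName != null); initName }
      def value(): Long = initValue
    }

    // After: values are fixed at construction time, so no setters or asserts are needed.
    class ImmutableMetric(val name: String, val value: Long)

    object RefactorSketch extends App {
      // The old style needed construction plus an initialize() call; the new style is one expression.
      val m = new ImmutableMetric("bytesSent", 1024L)
      println(s"${m.name} = ${m.value}")  // prints: bytesSent = 1024
    }
    ```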
    
    ### Why are the changes needed?
    
    For better readability.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing test cases should cover this. I also manually tested as described in https://github.com/apache/spark/pull/44375
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44406 from HyukjinKwon/SPARK-46424-followup.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../python/UserDefinedPythonDataSource.scala       | 80 +++++++---------------
 1 file changed, 26 insertions(+), 54 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index f6a473adf08d..d31b3135d65e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -139,17 +139,12 @@ class PythonPartitionReaderFactory(
           .map(_ -> new SQLMetric("sum", -1)).toMap
       }
 
-      private val outputIter = {
-        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
-          pickledReadFunc,
-          outputSchema,
-          metrics,
-          jobArtifactUUID)
-
-        val part = partition.asInstanceOf[PythonInputPartition]
-        evaluatorFactory.createEvaluator().eval(
-          part.index, Iterator.single(InternalRow(part.pickedPartition)))
-      }
+      private val outputIter = source.createPartitionReadIteratorInPython(
+        partition.asInstanceOf[PythonInputPartition],
+        pickledReadFunc,
+        outputSchema,
+        metrics,
+        jobArtifactUUID)
 
       override def next(): Boolean = outputIter.hasNext
 
@@ -164,41 +159,20 @@ class PythonPartitionReaderFactory(
   }
 }
 
-class PythonCustomMetric extends CustomMetric {
-  private var initName: String = _
-  private var initDescription: String = _
-  def initialize(n: String, d: String): Unit = {
-    initName = n
-    initDescription = d
-  }
-  override def name(): String = {
-    assert(initName != null)
-    initName
-  }
-  override def description(): String = {
-    assert(initDescription != null)
-    initDescription
-  }
+class PythonCustomMetric(
+    override val name: String,
+    override val description: String) extends CustomMetric {
+  // To allow the aggregation to be called. See `SQLAppStatusListener.aggregateMetrics`.
+  def this() = this(null, null)
+
   override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
     SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
   }
 }
 
-class PythonCustomTaskMetric extends CustomTaskMetric {
-  private var initName: String = _
-  private var initValue: Long = -1L
-  def initialize(n: String, v: Long): Unit = {
-    initName = n
-    initValue = v
-  }
-  override def name(): String = {
-    assert(initName != null)
-    initName
-  }
-  override def value(): Long = {
-    initValue
-  }
-}
+class PythonCustomTaskMetric(
+    override val name: String,
+    override val value: Long) extends CustomTaskMetric
 
 /**
  * A user-defined Python data source. This is used by the Python API.
@@ -240,11 +214,12 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
   /**
    * (Executor-side) Create an iterator that reads the input partitions.
    */
-  def createMapInBatchEvaluatorFactory(
+  def createPartitionReadIteratorInPython(
+      partition: PythonInputPartition,
       pickledReadFunc: Array[Byte],
       outputSchema: StructType,
       metrics: Map[String, SQLMetric],
-      jobArtifactUUID: Option[String]): MapInBatchEvaluatorFactory = {
+      jobArtifactUUID: Option[String]): Iterator[InternalRow] = {
     val readerFunc = createPythonFunction(pickledReadFunc)
 
     val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
@@ -260,7 +235,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
     val conf = SQLConf.get
 
     val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-    new MapInBatchEvaluatorFactory(
+    val evaluatorFactory = new MapInBatchEvaluatorFactory(
       toAttributes(outputSchema),
       Seq(ChainedPythonFunctions(Seq(pythonUDF.func))),
       inputSchema,
@@ -271,24 +246,21 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       pythonRunnerConf,
       metrics,
       jobArtifactUUID)
+
+    val part = partition
+    evaluatorFactory.createEvaluator().eval(
+      part.index, Iterator.single(InternalRow(part.pickedPartition)))
   }
 
   def createPythonMetrics(): Array[CustomMetric] = {
     // Do not add other metrics such as number of rows,
     // that is already included via DSv2.
-    PythonSQLMetrics.pythonSizeMetricsDesc.map { case (k, v) =>
-      val m = new PythonCustomMetric()
-      m.initialize(k, v)
-      m
-    }.toArray
+    PythonSQLMetrics.pythonSizeMetricsDesc
+      .map { case (k, v) => new PythonCustomMetric(k, v)}.toArray
   }
 
   def createPythonTaskMetrics(taskMetrics: Map[String, Long]): Array[CustomTaskMetric] = {
-    taskMetrics.map { case (k, v) =>
-      val m = new PythonCustomTaskMetric()
-      m.initialize(k, v)
-      m
-    }.toArray
+    taskMetrics.map { case (k, v) => new PythonCustomTaskMetric(k, v)}.toArray
   }
 
   private def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
