This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d0e2c06e874a [SPARK-50372][CONNECT][SQL] Make all DF execution path collect observed metrics
d0e2c06e874a is described below

commit d0e2c06e874a2a61fb95450bbb9085be2ba2c167
Author: Paddy Xu <[email protected]>
AuthorDate: Tue Dec 3 08:21:53 2024 +0900

    [SPARK-50372][CONNECT][SQL] Make all DF execution path collect observed metrics
    
    ### What changes were proposed in this pull request?
    
    This PR fixes an issue where some DataFrame execution paths would not process `ObservedMetrics`. The fix injects lazy processing logic into the result iterator.
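
    A minimal sketch of the idea (illustrative names such as `Response`, `rawResponses` and `handleMetrics`, not the PR's exact types): the shared execute helper wraps the response iterator in a lazy `map`, so observed metrics are handled whenever the iterator is consumed, regardless of which execution path drains it.

        trait Response { def observedMetrics: Seq[(Long, Map[String, Any])] }

        def executeInternal(rawResponses: Iterator[Response])(
            handleMetrics: (Long, Map[String, Any]) => Unit): Iterator[Response] =
          rawResponses.map { response =>
            // map() is lazy: metrics are processed only as the caller drains the
            // iterator, e.g. via collect(), toSeq, foreach or save().
            response.observedMetrics.foreach { case (planId, m) => handleMetrics(planId, m) }
            response
          }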
    
    The following private execution APIs are affected by this issue:
    
    - `SparkSession.execute(proto.Relation.Builder)`
    - `SparkSession.execute(proto.Command)`
    - `SparkSession.execute(proto.Plan)`
    
    The following user-facing API is affected by this issue:
    - `DataFrame.write.format("...").mode("...").save()`
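
    For example, based on the new E2E test in this PR (and assuming the usual `org.apache.spark.sql.functions._` imports on a Spark Connect session), the save() path previously never fulfilled a registered `Observation`:

        val ob = new Observation("ob")
        val observed = spark.range(100).observe(ob, min("id"), max("id"))
        observed.write.format("noop").mode("append").save()
        // Before this fix, ob.get could hang here because the save() path never
        // processed the returned ObservedMetrics; with the fix it completes.
        val metrics = ob.get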
    
    This PR also fixes a server-side issue where two observed metrics could be assigned the same plan ID when they appear in the same plan (e.g., one observation is used as the input of another). The fix traverses the plan to find all observations together with their correct plan IDs.
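
    A schematic of that traversal (a simplified model for illustration, not the protobuf-walking code added to `ExecuteHolder` below): each `CollectMetrics` name is recorded under the plan ID of its own node instead of the root plan's ID.

        sealed trait Plan { def planId: Long }
        case class CollectMetricsNode(name: String, planId: Long, input: Plan) extends Plan
        case class LeafNode(planId: Long) extends Plan

        def observationAndPlanIds(plan: Plan): Map[String, Long] = plan match {
          case CollectMetricsNode(name, id, input) =>
            // record this observation under its own plan ID, then keep walking
            observationAndPlanIds(input) + (name -> id)
          case _: LeafNode => Map.empty
        }

        // df.observe(ob1, ...).observe(ob2, ...) roughly corresponds to
        //   CollectMetricsNode("ob2", 2, CollectMetricsNode("ob1", 1, LeafNode(0)))
        // which yields Map("ob1" -> 1, "ob2" -> 2), one distinct ID per observation.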
    
    Another bug was discovered as a byproduct of introducing a new test case. Copying the PR comment from SparkConnectPlanner.scala here:
    
    > This fixes a bug where the input of a `CollectMetrics` can be processed two times, once in Line 1190 and once here/below.
    >
    > When the `input` contains another `CollectMetrics`, transforming it twice will cause two `Observation` objects (in the input) to be initialised and registered two times to the system. Since only one of them will be fulfilled when the query finishes, the one we'll be looking at may not have any data.
    >
    > This issue is highlighted in the test case `Observation.get is blocked until the query is finished ...`, where we specifically execute `observedObservedDf`, which is a `CollectMetrics` that has another `CollectMetrics` as its input.
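
    The scenario that exposes it, condensed from the updated E2E test (same columns and expected values as the test):

        val df = spark.range(99).withColumn("extra", col("id") - 1)
        val ob1 = new Observation("ob1")
        val ob2 = new Observation("ob2")
        val observedDf = df.observe(ob1, min("id"), avg("id"), max("id"))
        val observedObservedDf = observedDf.observe(ob2, min("extra"), avg("extra"), max("extra"))
        observedObservedDf.collect()
        // With the old planner code, ob1 could be registered twice and the copy
        // being waited on was never fulfilled, so ob1.get could block forever.
        ob1.get  // Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98)
        ob2.get  // Map("min(extra)" -> -1, "avg(extra)" -> 48, "max(extra)" -> 97)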
    
    ### Why are the changes needed?
    
    To fix a bug.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this bug is user-facing.
    
    ### How was this patch tested?
    
    New tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48920 from xupefei/observation-notify-fix.
    
    Authored-by: Paddy Xu <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  | 45 +++++++++++------
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  | 59 +++++++++++++++-------
 .../sql/connect/client/CloseableIterator.scala     | 10 ++++
 .../spark/sql/connect/client/SparkResult.scala     | 41 ++++++++-------
 .../connect/execution/ExecuteThreadRunner.scala    |  2 +-
 .../execution/SparkConnectPlanExecution.scala      | 14 ++---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  4 +-
 .../spark/sql/connect/service/ExecuteHolder.scala  | 26 ++++++++++
 8 files changed, 137 insertions(+), 64 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index b74d0c2ff224..3183a155c16a 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -34,6 +34,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.ExecutePlanResponse
+import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalog.Catalog
@@ -371,13 +372,8 @@ class SparkSession private[sql] (
   private[sql] def timeZoneId: String = conf.get(SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY)
 
   private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
-    val value = client.execute(plan)
-    new SparkResult(
-      value,
-      allocator,
-      encoder,
-      timeZoneId,
-      Some(setMetricsAndUnregisterObservation))
+    val value = executeInternal(plan)
+    new SparkResult(value, allocator, encoder, timeZoneId)
   }
 
   private[sql] def execute(f: proto.Relation.Builder => Unit): Unit = {
@@ -386,7 +382,7 @@ class SparkSession private[sql] (
     builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
     val plan = proto.Plan.newBuilder().setRoot(builder).build()
     // .foreach forces that the iterator is consumed and closed
-    client.execute(plan).foreach(_ => ())
+    executeInternal(plan).foreach(_ => ())
   }
 
   @Since("4.0.0")
@@ -395,11 +391,26 @@ class SparkSession private[sql] (
     val plan = proto.Plan.newBuilder().setCommand(command).build()
     // .toSeq forces that the iterator is consumed and closed. On top, ignore all
     // progress messages.
-    client.execute(plan).filter(!_.hasExecutionProgress).toSeq
+    executeInternal(plan).filter(!_.hasExecutionProgress).toSeq
   }
 
-  private[sql] def execute(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] =
-    client.execute(plan)
+  /**
+   * The real `execute` method that calls into `SparkConnectClient`.
+   *
+   * Here we inject a lazy map to process registered observed metrics, so consumers of the
+   * returned iterator does not need to worry about it.
+   *
+   * Please make sure all `execute` methods call this method.
+   */
+  private[sql] def executeInternal(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] = {
+    client
+      .execute(plan)
+      .map { response =>
+        // Note, this map() is lazy.
+        processRegisteredObservedMetrics(response.getObservedMetricsList)
+        response
+      }
+  }
 
   private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = {
     val command = proto.Command.newBuilder().setRegisterFunction(udf).build()
@@ -541,10 +552,14 @@ class SparkSession private[sql] (
     observationRegistry.putIfAbsent(planId, observation)
   }
 
-  private[sql] def setMetricsAndUnregisterObservation(planId: Long, metrics: Row): Unit = {
-    val observationOrNull = observationRegistry.remove(planId)
-    if (observationOrNull != null) {
-      observationOrNull.setMetricsAndNotify(metrics)
+  private def processRegisteredObservedMetrics(metrics: java.util.List[ObservedMetrics]): Unit = {
+    metrics.asScala.map { metric =>
+      // Here we only process metrics that belong to a registered Observation object.
+      // All metrics, whether registered or not, will be collected by `SparkResult`.
+      val observationOrNull = observationRegistry.remove(metric.getPlanId)
+      if (observationOrNull != null) {
+        observationOrNull.setMetricsAndNotify(SparkResult.transformObservedMetrics(metric))
+      }
     }
   }
 
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 0371981b728d..92b5808f4d62 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -1536,28 +1536,49 @@ class ClientE2ETestSuite
     val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema))
     val ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), ob2Schema))
 
+    val obMetrics = observedDf.collectResult().getObservedMetrics
     assert(df.collectResult().getObservedMetrics === Map.empty)
     assert(observedDf.collectResult().getObservedMetrics === ob1Metrics)
-    assert(observedObservedDf.collectResult().getObservedMetrics === ob1Metrics ++ ob2Metrics)
-  }
-
-  test("Observation.get is blocked until the query is finished") {
-    val df = spark.range(99).withColumn("extra", col("id") - 1)
-    val observation = new Observation("ob1")
-    val observedDf = df.observe(observation, min("id"), avg("id"), max("id"))
-
-    // Start a new thread to get the observation
-    val future = Future(observation.get)(ExecutionContext.global)
-    // make sure the thread is blocked right now
-    val e = intercept[java.util.concurrent.TimeoutException] {
-      SparkThreadUtils.awaitResult(future, 2.seconds)
+    assert(obMetrics.map(_._2.schema) === Seq(ob1Schema))
+
+    val obObMetrics = observedObservedDf.collectResult().getObservedMetrics
+    assert(obObMetrics === ob1Metrics ++ ob2Metrics)
+    assert(obObMetrics.map(_._2.schema).exists(_.equals(ob1Schema)))
+    assert(obObMetrics.map(_._2.schema).exists(_.equals(ob2Schema)))
+  }
+
+  for (collectFunc <- Seq(
+      ("collect", (df: DataFrame) => df.collect()),
+      ("collectAsList", (df: DataFrame) => df.collectAsList()),
+      ("collectResult", (df: DataFrame) => df.collectResult().length),
+      ("write", (df: DataFrame) => 
df.write.format("noop").mode("append").save())))
+    test(
+      "Observation.get is blocked until the query is finished, " +
+        s"collect using method ${collectFunc._1}") {
+      val df = spark.range(99).withColumn("extra", col("id") - 1)
+      val ob1 = new Observation("ob1")
+      val ob2 = new Observation("ob2")
+      val observedDf = df.observe(ob1, min("id"), avg("id"), max("id"))
+      val observedObservedDf = observedDf.observe(ob2, min("extra"), avg("extra"), max("extra"))
+      // Start new threads to get observations
+      val future1 = Future(ob1.get)(ExecutionContext.global)
+      val future2 = Future(ob2.get)(ExecutionContext.global)
+      // make sure the threads are blocked right now
+      val e1 = intercept[java.util.concurrent.TimeoutException] {
+        SparkThreadUtils.awaitResult(future1, 2.seconds)
+      }
+      assert(e1.getMessage.contains("timed out after"))
+      val e2 = intercept[java.util.concurrent.TimeoutException] {
+        SparkThreadUtils.awaitResult(future2, 2.seconds)
+      }
+      assert(e2.getMessage.contains("timed out after"))
+      collectFunc._2(observedObservedDf)
+      // make sure the threads are unblocked after the query is finished
+      val metrics1 = SparkThreadUtils.awaitResult(future1, 5.seconds)
+      assert(metrics1 === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98))
+      val metrics2 = SparkThreadUtils.awaitResult(future2, 5.seconds)
+      assert(metrics2 === Map("min(extra)" -> -1, "avg(extra)" -> 48, "max(extra)" -> 97))
     }
-    assert(e.getMessage.contains("Future timed out"))
-    observedDf.collect()
-    // make sure the thread is unblocked after the query is finished
-    val metrics = SparkThreadUtils.awaitResult(future, 2.seconds)
-    assert(metrics === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98))
-  }
 
   test("SPARK-48852: trim function on a string column returns correct 
results") {
     val session: SparkSession = spark
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
index 4ec6828d885a..9de585503a50 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
@@ -25,6 +25,16 @@ private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable {
 
     override def close() = self.close()
   }
+
+  override def map[B](f: E => B): CloseableIterator[B] = {
+    new CloseableIterator[B] {
+      override def next(): B = f(self.next())
+
+      override def hasNext: Boolean = self.hasNext
+
+      override def close(): Unit = self.close()
+    }
+  }
 }
 
 private[sql] abstract class WrappedCloseableIterator[E] extends CloseableIterator[E] {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 3aad90e96f8c..959779b357c2 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -41,8 +41,7 @@ private[sql] class SparkResult[T](
     responses: CloseableIterator[proto.ExecutePlanResponse],
     allocator: BufferAllocator,
     encoder: AgnosticEncoder[T],
-    timeZoneId: String,
-    setObservationMetricsOpt: Option[(Long, Row) => Unit] = None)
+    timeZoneId: String)
     extends AutoCloseable { self =>
 
   case class StageInfo(
@@ -122,7 +121,8 @@ private[sql] class SparkResult[T](
     while (!stop && responses.hasNext) {
       val response = responses.next()
 
-      // Collect metrics for this response
+      // Collect **all** metrics for this response, whether or not registered to an Observation
+      // object.
       observedMetrics ++= processObservedMetrics(response.getObservedMetricsList)
 
       // Save and validate operationId
@@ -209,23 +209,7 @@ private[sql] class SparkResult[T](
   private def processObservedMetrics(
       metrics: java.util.List[ObservedMetrics]): Iterable[(String, Row)] = {
     metrics.asScala.map { metric =>
-      assert(metric.getKeysCount == metric.getValuesCount)
-      var schema = new StructType()
-      val values = mutable.ArrayBuilder.make[Any]
-      values.sizeHint(metric.getKeysCount)
-      (0 until metric.getKeysCount).foreach { i =>
-        val key = metric.getKeys(i)
-        val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
-        schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
-        values += value
-      }
-      val row = new GenericRowWithSchema(values.result(), schema)
-      // If the metrics is registered by an Observation object, attach them and unblock any
-      // blocked thread.
-      setObservationMetricsOpt.foreach { setObservationMetrics =>
-        setObservationMetrics(metric.getPlanId, row)
-      }
-      metric.getName -> row
+      metric.getName -> SparkResult.transformObservedMetrics(metric)
     }
   }
 
@@ -387,8 +371,23 @@ private[sql] class SparkResult[T](
   }
 }
 
-private object SparkResult {
+private[sql] object SparkResult {
   private val cleaner: Cleaner = Cleaner.create()
+
+  /** Return value is a Seq of pairs, to preserve the order of values. */
+  private[sql] def transformObservedMetrics(metric: ObservedMetrics): Row = {
+    assert(metric.getKeysCount == metric.getValuesCount)
+    var schema = new StructType()
+    val values = mutable.ArrayBuilder.make[Any]
+    values.sizeHint(metric.getKeysCount)
+    (0 until metric.getKeysCount).foreach { i =>
+      val key = metric.getKeys(i)
+      val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
+      schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
+      values += value
+    }
+    new GenericRowWithSchema(values.result(), schema)
+  }
 }
 
 private[client] class SparkResultCloseable(
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index d27f390a23f9..05e3395a5316 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -245,7 +245,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
             .createObservedMetricsResponse(
               executeHolder.sessionHolder.sessionId,
               executeHolder.sessionHolder.serverSessionId,
-              executeHolder.request.getPlan.getRoot.getCommon.getPlanId,
+              executeHolder.allObservationAndPlanIds,
               observedMetrics ++ accumulatedInPython))
       }
 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index c0fd00b2eeaa..5e3499573e9d 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -77,8 +77,10 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
     processAsArrowBatches(dataframe, responseObserver, executeHolder)
     responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
-    createObservedMetricsResponse(request.getSessionId, dataframe).foreach(
-      responseObserver.onNext)
+    createObservedMetricsResponse(
+      request.getSessionId,
+      executeHolder.allObservationAndPlanIds,
+      dataframe).foreach(responseObserver.onNext)
   }
 
   type Batch = (Array[Byte], Long)
@@ -255,6 +257,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
 
   private def createObservedMetricsResponse(
       sessionId: String,
+      observationAndPlanIds: Map[String, Long],
       dataframe: DataFrame): Option[ExecutePlanResponse] = {
     val observedMetrics = dataframe.queryExecution.observedMetrics.collect {
       case (name, row) if !executeHolder.observations.contains(name) =>
@@ -264,13 +267,12 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
         name -> values
     }
     if (observedMetrics.nonEmpty) {
-      val planId = executeHolder.request.getPlan.getRoot.getCommon.getPlanId
       Some(
         SparkConnectPlanExecution
           .createObservedMetricsResponse(
             sessionId,
             sessionHolder.serverSessionId,
-            planId,
+            observationAndPlanIds,
             observedMetrics))
     } else None
   }
@@ -280,17 +282,17 @@ object SparkConnectPlanExecution {
   def createObservedMetricsResponse(
       sessionId: String,
       serverSessionId: String,
-      planId: Long,
+      observationAndPlanIds: Map[String, Long],
       metrics: Map[String, Seq[(Option[String], Any)]]): ExecutePlanResponse = {
     val observedMetrics = metrics.map { case (name, values) =>
       val metrics = ExecutePlanResponse.ObservedMetrics
         .newBuilder()
         .setName(name)
-        .setPlanId(planId)
       values.foreach { case (key, value) =>
         metrics.addValues(toLiteralProto(value))
         key.foreach(metrics.addKeys)
       }
+      observationAndPlanIds.get(name).foreach(metrics.setPlanId)
       metrics.build()
     }
     // Prepare a response with the observed metrics.
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 979fd83612e7..ee030a52b221 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1190,14 +1190,14 @@ class SparkConnectPlanner(
     val input = transformRelation(rel.getInput)
 
     if (input.isStreaming || executeHolderOpt.isEmpty) {
-      CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
+      CollectMetrics(name, metrics.map(_.named), input, planId)
     } else {
       // TODO this might be too complex for no good reason. It might
       //  be easier to inspect the plan after it completes.
       val observation = Observation(name)
       session.observationManager.register(observation, planId)
       executeHolderOpt.get.addObservation(name, observation)
-      CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
+      CollectMetrics(name, metrics.map(_.named), input, planId)
     }
   }
 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index 821ddb2c85d5..94638151f7f1 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -22,6 +22,8 @@ import java.util.concurrent.atomic.AtomicBoolean
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
+import com.google.protobuf.GeneratedMessage
+
 import org.apache.spark.SparkEnv
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
@@ -81,6 +83,10 @@ private[connect] class ExecuteHolder(
 
   val observations: mutable.Map[String, Observation] = mutable.Map.empty
 
+  lazy val allObservationAndPlanIds: Map[String, Long] = {
+    ExecuteHolder.collectAllObservationAndPlanIds(request.getPlan).toMap
+  }
+
   private val runner: ExecuteThreadRunner = new ExecuteThreadRunner(this)
 
   /** System.currentTimeMillis when this ExecuteHolder was created. */
@@ -289,6 +295,26 @@ private[connect] class ExecuteHolder(
   def operationId: String = key.operationId
 }
 
+private object ExecuteHolder {
+  private def collectAllObservationAndPlanIds(
+      planOrMessage: GeneratedMessage,
+      collected: mutable.Map[String, Long] = mutable.Map.empty): mutable.Map[String, Long] = {
+    planOrMessage match {
+      case relation: proto.Relation if relation.hasCollectMetrics =>
+        collected += relation.getCollectMetrics.getName -> relation.getCommon.getPlanId
+        collectAllObservationAndPlanIds(relation.getCollectMetrics.getInput, collected)
+      case _ =>
+        planOrMessage.getAllFields.values().asScala.foreach {
+          case message: GeneratedMessage =>
+            collectAllObservationAndPlanIds(message, collected)
+          case _ =>
+          // not a message (probably a primitive type), do nothing
+        }
+    }
+    collected
+  }
+}
+
 /** Used to identify ExecuteHolder jobTag among SparkContext.SPARK_JOB_TAGS. */
 object ExecuteJobTag {
   private val prefix = "SparkConnect_OperationTag"


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
