This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 75479fdecf82 [SPARK-55150][CONNECT][SQL] Improve observation error 
handling
75479fdecf82 is described below

commit 75479fdecf82d71fb0b729ff330ecca6936d7834
Author: Yihong He <[email protected]>
AuthorDate: Thu Jan 29 10:40:11 2026 +0800

    [SPARK-55150][CONNECT][SQL] Improve observation error handling
    
    ### What changes were proposed in this pull request?
    
    This change improves error handling when collecting observed metrics
    (`ObservationManager.tryComplete` in classic mode, and observed-metric
    processing in `SparkSession` in connect mode):
    1. Changed `Observation.setMetricsAndNotify` to accept a `Try[Row]` and use
    `promise.tryComplete`, so that both success and failure complete the promise.
    2. Wrapped metric collection in `Try` in both classic and connect modes.
    3. Added tests for error handling in both modes
    
    ### Why are the changes needed?
    
    When errors occur during observed-metric collection (e.g., division by
    zero), they should not fail the query execution itself. They also should not
    leave `Observation.get` waiting forever for a result that never arrives.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    **Classic mode - Before (Spark 4.1)**
    `observation.get` hangs. The metrics-collection error is thrown inside the
    listener's `tryComplete` call before the observation is completed, so the
    observation's promise is never fulfilled.
    
    **Classic mode - After**
    `observation.get` now throws a `SparkException` whose cause is the
    underlying metrics-collection error.
    
    **Connect mode - Before (Spark 4.1)**
    For a Spark Connect query, an error occurs in
    `createObservedMetricsResponse` during plan execution, causing the plan
    execution itself to fail.
    
    For the Spark Connect command, `observation.get` returns an empty result if 
errors occur during metric collection.
    
    **Connect mode - After**
    `observation.get` returns an empty result when errors occur during metric
    collection. Throwing an exception instead (as in classic mode) would also be
    possible, but it requires modifying the protobuf messages so that the server
    can pass the error to the client.
    
    ### How was this patch tested?
    
    `build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite -- -z 
SPARK-55150"`
    `build/sbt "sql/testOnly *DatasetSuite -- -z SPARK-55150"`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes
    
    Closes #53935 from heyihong/SPARK-55150.
    
    Authored-by: Yihong He <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../main/scala/org/apache/spark/sql/Observation.scala |  5 +++--
 .../apache/spark/sql/connect/ClientE2ETestSuite.scala | 11 +++++++++++
 .../org/apache/spark/sql/connect/SparkSession.scala   |  3 ++-
 .../apache/spark/sql/classic/ObservationManager.scala |  8 ++++++--
 .../scala/org/apache/spark/sql/DatasetSuite.scala     | 19 ++++++++++++++++++-
 5 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
index 3e0a2515d7fe..41b90b934a28 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 import scala.concurrent.{Future, Promise}
 import scala.concurrent.duration.Duration
 import scala.jdk.CollectionConverters.MapHasAsJava
+import scala.util.Try
 
 import org.apache.spark.util.SparkThreadUtils
 
@@ -118,8 +119,8 @@ class Observation(val name: String) {
    * @return
    *   `true` if all waiting threads were notified, `false` if otherwise.
    */
-  private[sql] def setMetricsAndNotify(metrics: Row): Boolean = {
-    promise.trySuccess(metrics)
+  private[sql] def setMetricsAndNotify(metrics: Try[Row]): Boolean = {
+    promise.tryComplete(metrics)
   }
 
   /**
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index 8161e327569a..a352c67b8721 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -1616,6 +1616,17 @@ class ClientE2ETestSuite
       assert(metrics2 === Map("min(extra)" -> -1, "avg(extra)" -> 48, 
"max(extra)" -> 97))
     }
 
+  test("SPARK-55150: observation errors leads to empty result in connect 
mode") {
+    val observation = Observation("test_observation")
+    val observed_df = spark
+      .range(10)
+      .observe(observation, sum("id").as("sum_id"), (sum("id") / 
lit(0)).as("sum_id_div_by_zero"))
+
+    observed_df.collect()
+
+    assert(observation.get.isEmpty)
+  }
+
   test("SPARK-48852: trim function on a string column returns correct 
results") {
     val session: SparkSession = spark
     import session.implicits._
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index 0ca34328c062..be49f96a3958 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -747,7 +747,8 @@ class SparkSession private[sql] (
       // All metrics, whether registered or not, will be collected by 
`SparkResult`.
       val observationOrNull = observationRegistry.remove(metric.getPlanId)
       if (observationOrNull != null) {
-        
observationOrNull.setMetricsAndNotify(SparkResult.transformObservedMetrics(metric))
+        val metricsResult = Try(SparkResult.transformObservedMetrics(metric))
+        observationOrNull.setMetricsAndNotify(metricsResult)
       }
     }
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala
index 972f572218ce..764c25809ec9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala
@@ -18,6 +18,8 @@ package org.apache.spark.sql.classic
 
 import java.util.concurrent.ConcurrentHashMap
 
+import scala.util.Try
+
 import org.apache.spark.sql.{Observation, Row}
 import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
 import org.apache.spark.sql.catalyst.trees.TreePattern
@@ -54,7 +56,8 @@ private[sql] class ObservationManager(session: SparkSession) {
   private[sql] def tryComplete(qe: QueryExecution): Unit = {
     // Use lazy val to defer collecting the observed metrics until it is 
needed so that tryComplete
     // can finish faster (e.g., when the logical plan doesn't contain 
CollectMetrics).
-    lazy val lazyObservedMetrics = qe.observedMetrics
+    // Wrap in Try to capture potential failures when collecting metrics.
+    lazy val lazyObservedMetrics = Try(qe.observedMetrics)
     qe.logical.foreachWithSubqueriesAndPruning(
       _.containsPattern(TreePattern.COLLECT_METRICS)) {
       case c: CollectMetrics =>
@@ -63,7 +66,8 @@ private[sql] class ObservationManager(session: SparkSession) {
           // If the key exists but no metrics were collected, it means for 
some reason the
           // metrics could not be collected. This can happen e.g., if the 
CollectMetricsExec
           // was optimized away.
-          
observation.setMetricsAndNotify(lazyObservedMetrics.getOrElse(c.name, 
Row.empty))
+          val metricsResult = lazyObservedMetrics.map(_.getOrElse(c.name, 
Row.empty))
+          observation.setMetricsAndNotify(metricsResult)
         }
       case _ =>
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index f359fb98be3a..1f9a30b1cc44 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -31,7 +31,7 @@ import org.scalatest.Assertions._
 import org.scalatest.exceptions.TestFailedException
 import org.scalatest.prop.TableDrivenPropertyChecks._
 
-import org.apache.spark.{SparkConf, SparkRuntimeException, 
SparkUnsupportedOperationException, TaskContext}
+import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException, 
SparkUnsupportedOperationException, TaskContext}
 import org.apache.spark.TestUtils.withListener
 import org.apache.spark.internal.config.MAX_RESULT_SIZE
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
@@ -1159,6 +1159,23 @@ class DatasetSuite extends QueryTest
     assert(namedObservation2.get === expected2)
   }
 
+  test("SPARK-55150: observation errors are threw in Obseravtion.get in 
classic mode") {
+    val observation = Observation("test_observation")
+    val observed_df = spark.range(10).observe(
+      observation,
+      sum($"id").as("sum_id"),
+      (sum($"id") / lit(0)).as("sum_id_div_by_zero")
+    )
+
+    observed_df.collect()
+
+    val exception = intercept[SparkException] {
+      observation.get
+    }
+
+    assert(exception.getCause.getMessage.contains("DIVIDE_BY_ZERO"))
+  }
+
   test("sample with replacement") {
     val n = 100
     val data = sparkContext.parallelize(1 to n, 2).toDS()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to