This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 528ba0e0db7d [SPARK-49423][CONNECT][SQL] Consolidate Observation in sql/api
528ba0e0db7d is described below

commit 528ba0e0db7d6f72317063c04c2dd6ddadfcfaab
Author: Herman van Hovell <[email protected]>
AuthorDate: Fri Aug 30 08:01:22 2024 -0400

    [SPARK-49423][CONNECT][SQL] Consolidate Observation in sql/api
    
    ### What changes were proposed in this pull request?
    This PR moves Observation into sql/api. For Classic, I moved the wiring into 
the `observe` method itself, and the required listener is now part of the 
`org.apache.spark.sql.internal` package (in `ObservationManager`). I have also 
taken the liberty of getting rid of most of the homegrown locking, replacing it 
with a promise.
    
    ### Why are the changes needed?
    We are creating a shared interface for the Classic and Connect Scala 
DataFrame APIs. This class is part of that interface.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #47921 from hvanhovell/SPARK-49423.
    
    Authored-by: Herman van Hovell <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  23 +---
 .../scala/org/apache/spark/sql/Observation.scala   |  46 -------
 .../scala/org/apache/spark/sql/SparkSession.scala  |  11 +-
 .../CheckConnectJvmClientCompatibility.scala       |   2 -
 project/MimaExcludes.scala                         |   4 +
 python/pyspark/sql/observation.py                  |   6 +-
 .../{ObservationBase.scala => Observation.scala}   |  87 ++++++++-----
 .../scala/org/apache/spark/sql/api/Dataset.scala   |  27 +++-
 .../spark/sql/connect/client/SparkResult.scala     |  14 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |   4 +-
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  25 +---
 .../scala/org/apache/spark/sql/Observation.scala   | 143 ---------------------
 .../scala/org/apache/spark/sql/SparkSession.scala  |   2 +
 .../spark/sql/internal/ObservationManager.scala    |  69 ++++++++++
 14 files changed, 173 insertions(+), 290 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index d05834c4fc6c..3b10978b7c8b 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1201,28 +1201,7 @@ class Dataset[T] private[sql] (
     }
   }
 
-  /**
-   * Observe (named) metrics through an `org.apache.spark.sql.Observation` 
instance. This is
-   * equivalent to calling `observe(String, Column, Column*)` but does not 
require to collect all
-   * results before returning the metrics - the metrics are filled during 
iterating the results,
-   * as soon as they are available. This method does not support streaming 
datasets.
-   *
-   * A user can retrieve the metrics by accessing 
`org.apache.spark.sql.Observation.get`.
-   *
-   * {{{
-   *   // Observe row count (rows) and highest id (maxid) in the Dataset while 
writing it
-   *   val observation = Observation("my_metrics")
-   *   val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), 
max($"id").as("maxid"))
-   *   observed_ds.write.parquet("ds.parquet")
-   *   val metrics = observation.get
-   * }}}
-   *
-   * @throws IllegalArgumentException
-   *   If this is a streaming Dataset (this.isStreaming == true)
-   *
-   * @group typedrel
-   * @since 4.0.0
-   */
+  /** @inheritdoc */
   @scala.annotation.varargs
   def observe(observation: Observation, expr: Column, exprs: Column*): 
Dataset[T] = {
     val df = observe(observation.name, expr, exprs: _*)
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala
deleted file mode 100644
index 75629b6000f9..000000000000
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import java.util.UUID
-
-class Observation(name: String) extends ObservationBase(name) {
-
-  /**
-   * Create an Observation instance without providing a name. This generates a 
random name.
-   */
-  def this() = this(UUID.randomUUID().toString)
-}
-
-/**
- * (Scala-specific) Create instances of Observation via Scala `apply`.
- * @since 4.0.0
- */
-object Observation {
-
-  /**
-   * Observation constructor for creating an anonymous observation.
-   */
-  def apply(): Observation = new Observation()
-
-  /**
-   * Observation constructor for creating a named observation.
-   */
-  def apply(name: String): Observation = new Observation(name)
-
-}
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 3837db00acc6..24d0a5ac7262 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -569,17 +569,14 @@ class SparkSession private[sql] (
   private[sql] var releaseSessionOnClose = true
 
   private[sql] def registerObservation(planId: Long, observation: 
Observation): Unit = {
-    if (observationRegistry.putIfAbsent(planId, observation) != null) {
-      throw new IllegalArgumentException("An Observation can be used with a 
Dataset only once")
-    }
+    observation.markRegistered()
+    observationRegistry.putIfAbsent(planId, observation)
   }
 
-  private[sql] def setMetricsAndUnregisterObservation(
-      planId: Long,
-      metrics: Map[String, Any]): Unit = {
+  private[sql] def setMetricsAndUnregisterObservation(planId: Long, metrics: 
Row): Unit = {
     val observationOrNull = observationRegistry.remove(planId)
     if (observationOrNull != null) {
-      observationOrNull.setMetricsAndNotify(Some(metrics))
+      observationOrNull.setMetricsAndNotify(metrics)
     }
   }
 
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index af9168339dcf..10b31155376f 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -207,8 +207,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[MissingClassProblem](
         "org.apache.spark.sql.Dataset$" // private[sql]
       ),
-      
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"),
-      
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"),
       // TODO (SPARK-49096):
       // Mima check might complain the following Dataset rules does not filter 
any problem.
       // This is due to a potential bug in Mima that all methods in `class 
Dataset` are not being
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 2fc0725df5bc..fe4a08971509 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -120,6 +120,10 @@ object MimaExcludes {
     
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.Window"),
     
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.Window$"),
     
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.WindowSpec"),
+
+    // SPARK-49423: Consolidate Observation in sql/api
+    
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation"),
+    
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"),
   )
 
   // Default exclude rules
diff --git a/python/pyspark/sql/observation.py 
b/python/pyspark/sql/observation.py
index 5f26b439b048..6ceb6bc90327 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -124,10 +124,8 @@ class Observation:
         assert self._jvm is not None
         cls = self._jvm.org.apache.spark.sql.Observation
         self._jo = cls(self._name) if self._name is not None else cls()
-        observed_df = self._jo.on(
-            df._jdf,
-            exprs[0]._jc,
-            _to_seq(df._sc, [c._jc for c in exprs[1:]]),
+        observed_df = df._jdf.observe(
+            self._jo, exprs[0]._jc, _to_seq(df._sc, [c._jc for c in exprs[1:]])
         )
         return DataFrame(observed_df, df.sparkSession)
 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
similarity index 62%
rename from sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
index 4789ae8975d1..02f5a8de1e3f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -17,7 +17,15 @@
 
 package org.apache.spark.sql
 
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.concurrent.{Future, Promise}
+import scala.concurrent.duration.{Duration, DurationInt}
 import scala.jdk.CollectionConverters.MapHasAsJava
+import scala.util.Try
+
+import org.apache.spark.util.SparkThreadUtils
 
 /**
  * Helper class to simplify usage of `Dataset.observe(String, Column, 
Column*)`:
@@ -39,11 +47,22 @@ import scala.jdk.CollectionConverters.MapHasAsJava
  * @param name name of the metric
  * @since 3.3.0
  */
-abstract class ObservationBase(val name: String) {
+class Observation(val name: String) {
+  require(name.nonEmpty, "Name must not be empty")
 
-  if (name.isEmpty) throw new IllegalArgumentException("Name must not be 
empty")
+  /**
+   * Create an Observation with a random name.
+   */
+  def this() = this(UUID.randomUUID().toString)
 
-  @volatile protected var metrics: Option[Map[String, Any]] = None
+  private val isRegistered = new AtomicBoolean()
+
+  private val promise = Promise[Map[String, Any]]()
+
+  /**
+   * Future holding the (yet to be completed) observation.
+   */
+  val future: Future[Map[String, Any]] = promise.future
 
   /**
    * (Scala-specific) Get the observed metrics. This waits for the observed 
dataset to finish
@@ -54,17 +73,7 @@ abstract class ObservationBase(val name: String) {
    * @throws InterruptedException interrupted while waiting
    */
   @throws[InterruptedException]
-  def get: Map[String, _] = {
-    synchronized {
-      // we need to loop as wait might return without us calling notify
-      // 
https://en.wikipedia.org/w/index.php?title=Spurious_wakeup&oldid=992601610
-      while (this.metrics.isEmpty) {
-        wait()
-      }
-    }
-
-    this.metrics.get
-  }
+  def get: Map[String, Any] = SparkThreadUtils.awaitResult(future, 
Duration.Inf)
 
   /**
    * (Java-specific) Get the observed metrics. This waits for the observed 
dataset to finish
@@ -75,9 +84,7 @@ abstract class ObservationBase(val name: String) {
    * @throws InterruptedException interrupted while waiting
    */
   @throws[InterruptedException]
-  def getAsJava: java.util.Map[String, AnyRef] = {
-    get.map { case (key, value) => (key, value.asInstanceOf[Object]) }.asJava
-  }
+  def getAsJava: java.util.Map[String, Any] = get.asJava
 
   /**
    * Get the observed metrics. This returns the metrics if they are available, 
otherwise an empty.
@@ -85,12 +92,16 @@ abstract class ObservationBase(val name: String) {
    * @return the observed metrics as a `Map[String, Any]`
    */
   @throws[InterruptedException]
-  private[sql] def getOrEmpty: Map[String, _] = {
-    synchronized {
-      if (metrics.isEmpty) {
-        wait(100) // Wait for 100ms to see if metrics are available
-      }
-      metrics.getOrElse(Map.empty)
+  private[sql] def getOrEmpty: Map[String, Any] = {
+    Try(SparkThreadUtils.awaitResult(future, 100.millis)).getOrElse(Map.empty)
+  }
+
+  /**
+   * Mark this Observation as registered.
+   */
+  private[sql] def markRegistered(): Unit = {
+    if (!isRegistered.compareAndSet(false, true)) {
+      throw new IllegalArgumentException("An Observation can be used with a 
Dataset only once")
     }
   }
 
@@ -99,15 +110,25 @@ abstract class ObservationBase(val name: String) {
    *
    * @return `true` if all waiting threads were notified, `false` if otherwise.
    */
-  private[spark] def setMetricsAndNotify(metrics: Option[Map[String, Any]]): 
Boolean = {
-    synchronized {
-      this.metrics = metrics
-      if(metrics.isDefined) {
-        notifyAll()
-        true
-      } else {
-        false
-      }
-    }
+  private[sql] def setMetricsAndNotify(metrics: Row): Boolean = {
+    val metricsMap = metrics.getValuesMap(metrics.schema.map(_.name))
+    promise.trySuccess(metricsMap)
   }
 }
+
+/**
+ * (Scala-specific) Create instances of Observation via Scala `apply`.
+ * @since 3.3.0
+ */
+object Observation {
+
+  /**
+   * Observation constructor for creating an anonymous observation.
+   */
+  def apply(): Observation = new Observation()
+
+  /**
+   * Observation constructor for creating a named observation.
+   */
+  def apply(name: String): Observation = new Observation(name)
+}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
index 2b071a384e0a..a5e125733a29 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -16,7 +16,6 @@
  */
 package org.apache.spark.sql.api
 
-import scala.annotation.varargs
 import scala.jdk.CollectionConverters._
 import scala.reflect.runtime.universe.TypeTag
 
@@ -24,7 +23,7 @@ import _root_.java.util
 
 import org.apache.spark.annotation.{DeveloperApi, Stable}
 import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction, 
ForeachFunction, ForeachPartitionFunction, MapFunction, MapPartitionsFunction, 
ReduceFunction}
-import org.apache.spark.sql.{functions, AnalysisException, Column, Encoder, 
Row, TypedColumn}
+import org.apache.spark.sql.{functions, AnalysisException, Column, Encoder, 
Observation, Row, TypedColumn}
 import org.apache.spark.sql.types.{Metadata, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.ArrayImplicits._
@@ -1518,9 +1517,31 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] 
extends Serializable {
   * @group typedrel
   * @since 3.0.0
   */
-  @varargs
+  @scala.annotation.varargs
   def observe(name: String, expr: Column, exprs: Column*): DS[T]
 
+  /**
+   * Observe (named) metrics through an `org.apache.spark.sql.Observation` 
instance. This method
+   * does not support streaming datasets.
+   *
+   * A user can retrieve the metrics by accessing 
`org.apache.spark.sql.Observation.get`.
+   *
+   * {{{
+   *   // Observe row count (rows) and highest id (maxid) in the Dataset while 
writing it
+   *   val observation = Observation("my_metrics")
+   *   val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), 
max($"id").as("maxid"))
+   *   observed_ds.write.parquet("ds.parquet")
+   *   val metrics = observation.get
+   * }}}
+   *
+   * @throws IllegalArgumentException If this is a streaming Dataset 
(this.isStreaming == true)
+   *
+   * @group typedrel
+   * @since 3.3.0
+   */
+  @scala.annotation.varargs
+  def observe(observation: Observation, expr: Column, exprs: Column*): DS[T]
+
   /**
    * Returns a new Dataset by taking the first `n` rows. The difference 
between this function
    * and `head` is that `head` is an action and returns an array (by 
triggering query execution)
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 0905ee76c3f3..3aad90e96f8c 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -42,7 +42,7 @@ private[sql] class SparkResult[T](
     allocator: BufferAllocator,
     encoder: AgnosticEncoder[T],
     timeZoneId: String,
-    setObservationMetricsOpt: Option[(Long, Map[String, Any]) => Unit] = None)
+    setObservationMetricsOpt: Option[(Long, Row) => Unit] = None)
     extends AutoCloseable { self =>
 
   case class StageInfo(
@@ -211,21 +211,21 @@ private[sql] class SparkResult[T](
     metrics.asScala.map { metric =>
       assert(metric.getKeysCount == metric.getValuesCount)
       var schema = new StructType()
-      val keys = mutable.ListBuffer.empty[String]
-      val values = mutable.ListBuffer.empty[Any]
-      (0 until metric.getKeysCount).map { i =>
+      val values = mutable.ArrayBuilder.make[Any]
+      values.sizeHint(metric.getKeysCount)
+      (0 until metric.getKeysCount).foreach { i =>
         val key = metric.getKeys(i)
         val value = 
LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
         schema = schema.add(key, 
LiteralValueProtoConverter.toDataType(value.getClass))
-        keys += key
         values += value
       }
+      val row = new GenericRowWithSchema(values.result(), schema)
       // If the metrics is registered by an Observation object, attach them 
and unblock any
       // blocked thread.
       setObservationMetricsOpt.foreach { setObservationMetrics =>
-        setObservationMetrics(metric.getPlanId, keys.zip(values).toMap)
+        setObservationMetrics(metric.getPlanId, row)
       }
-      metric.getName -> new GenericRowWithSchema(values.toArray, schema)
+      metric.getName -> row
     }
   }
 
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 6e0b5a35fcd3..58e61badaf37 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1178,8 +1178,10 @@ class SparkConnectPlanner(
     if (input.isStreaming || executeHolderOpt.isEmpty) {
       CollectMetrics(name, metrics.map(_.named), 
transformRelation(rel.getInput), planId)
     } else {
+      // TODO this might be too complex for no good reason. It might
+      //  be easier to inspect the plan after it completes.
       val observation = Observation(name)
-      observation.register(session, planId)
+      session.observationManager.register(observation, planId)
       executeHolderOpt.get.addObservation(name, observation)
       CollectMetrics(name, metrics.map(_.named), 
transformRelation(rel.getInput), planId)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a28dfbdbf66a..597decdbc740 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1005,30 +1005,11 @@ class Dataset[T] private[sql](
     CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id)
   }
 
-  /**
-   * Observe (named) metrics through an `org.apache.spark.sql.Observation` 
instance.
-   * This is equivalent to calling `observe(String, Column, Column*)` but does 
not require
-   * adding `org.apache.spark.sql.util.QueryExecutionListener` to the spark 
session.
-   * This method does not support streaming datasets.
-   *
-   * A user can retrieve the metrics by accessing 
`org.apache.spark.sql.Observation.get`.
-   *
-   * {{{
-   *   // Observe row count (rows) and highest id (maxid) in the Dataset while 
writing it
-   *   val observation = Observation("my_metrics")
-   *   val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), 
max($"id").as("maxid"))
-   *   observed_ds.write.parquet("ds.parquet")
-   *   val metrics = observation.get
-   * }}}
-   *
-   * @throws IllegalArgumentException If this is a streaming Dataset 
(this.isStreaming == true)
-   *
-   * @group typedrel
-   * @since 3.3.0
-   */
+  /** @inheritdoc */
   @scala.annotation.varargs
   def observe(observation: Observation, expr: Column, exprs: Column*): 
Dataset[T] = {
-    observation.on(this, expr, exprs: _*)
+    sparkSession.observationManager.register(observation, this)
+    observe(observation.name, expr, exprs: _*)
   }
 
   /** @inheritdoc */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
deleted file mode 100644
index 30d5943c6092..000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
+++ /dev/null
@@ -1,143 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import java.util.UUID
-
-import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
-import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.util.QueryExecutionListener
-import org.apache.spark.util.ArrayImplicits._
-
-
-/**
- * Helper class to simplify usage of `Dataset.observe(String, Column, 
Column*)`:
- *
- * {{{
- *   // Observe row count (rows) and highest id (maxid) in the Dataset while 
writing it
- *   val observation = Observation("my metrics")
- *   val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), 
max($"id").as("maxid"))
- *   observed_ds.write.parquet("ds.parquet")
- *   val metrics = observation.get
- * }}}
- *
- * This collects the metrics while the first action is executed on the 
observed dataset. Subsequent
- * actions do not modify the metrics returned by [[get]]. Retrieval of the 
metric via [[get]]
- * blocks until the first action has finished and metrics become available.
- *
- * This class does not support streaming datasets.
- *
- * @param name name of the metric
- * @since 3.3.0
- */
-class Observation(name: String) extends ObservationBase(name) {
-
-  /**
-   * Create an Observation instance without providing a name. This generates a 
random name.
-   */
-  def this() = this(UUID.randomUUID().toString)
-
-  private val listener: ObservationListener = ObservationListener(this)
-
-  @volatile private var dataframeId: Option[(SparkSession, Long)] = None
-
-  /**
-   * Attach this observation to the given [[Dataset]] to observe aggregation 
expressions.
-   *
-   * @param ds dataset
-   * @param expr first aggregation expression
-   * @param exprs more aggregation expressions
-   * @tparam T dataset type
-   * @return observed dataset
-   * @throws IllegalArgumentException If this is a streaming Dataset 
(ds.isStreaming == true)
-   */
-  private[spark] def on[T](ds: Dataset[T], expr: Column, exprs: Column*): 
Dataset[T] = {
-    if (ds.isStreaming) {
-      throw new IllegalArgumentException("Observation does not support 
streaming Datasets." +
-        "This is because there will be multiple observed metrics as 
microbatches are constructed" +
-        ". Please register a StreamingQueryListener and get the metric for 
each microbatch in " +
-        "QueryProgressEvent.progress, or use query.lastProgress or 
query.recentProgress.")
-    }
-    register(ds.sparkSession, ds.id)
-    ds.observe(name, expr, exprs: _*)
-  }
-
-  private[sql] def register(sparkSession: SparkSession, dataframeId: Long): 
Unit = {
-    // makes this class thread-safe:
-    // only the first thread entering this block can set sparkSession
-    // all other threads will see the exception, as it is only allowed to do 
this once
-    synchronized {
-      if (this.dataframeId.isDefined) {
-        throw new IllegalArgumentException("An Observation can be used with a 
Dataset only once")
-      }
-      this.dataframeId = Some((sparkSession, dataframeId))
-    }
-
-    sparkSession.listenerManager.register(this.listener)
-  }
-
-  private def unregister(): Unit = {
-    this.dataframeId.foreach(_._1.listenerManager.unregister(this.listener))
-  }
-
-  private[spark] def onFinish(qe: QueryExecution): Unit = {
-    synchronized {
-      if (this.metrics.isEmpty && qe.logical.exists {
-        case CollectMetrics(name, _, _, dataframeId) =>
-          name == this.name && dataframeId == this.dataframeId.get._2
-        case _ => false
-      }) {
-        val row = qe.observedMetrics.get(name)
-        val metrics = row.map(r => 
r.getValuesMap[Any](r.schema.fieldNames.toImmutableArraySeq))
-        if (setMetricsAndNotify(metrics)) {
-          unregister()
-        }
-      }
-    }
-  }
-
-}
-
-private[sql] case class ObservationListener(observation: Observation)
-  extends QueryExecutionListener {
-
-  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: 
Long): Unit =
-    observation.onFinish(qe)
-
-  override def onFailure(funcName: String, qe: QueryExecution, exception: 
Exception): Unit =
-    observation.onFinish(qe)
-
-}
-
-/**
- * (Scala-specific) Create instances of Observation via Scala `apply`.
- * @since 3.3.0
- */
-object Observation {
-
-  /**
-   * Observation constructor for creating an anonymous observation.
-   */
-  def apply(): Observation = new Observation()
-
-  /**
-   * Observation constructor for creating a named observation.
-   */
-  def apply(name: String): Observation = new Observation(name)
-
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index fa2d1b163322..55f67da68221 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -779,6 +779,8 @@ class SparkSession private(
      */
     def named: NamedExpression = ExpressionUtils.toNamed(expr)
   }
+
+  private[sql] lazy val observationManager = new ObservationManager(this)
 }
 
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/internal/ObservationManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/ObservationManager.scala
new file mode 100644
index 000000000000..4fa1f0f4962a
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/ObservationManager.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.internal
+
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.sql.{Dataset, Observation, SparkSession}
+import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
+
+/**
+ * This class keeps track of registered Observations that await query 
completion.
+ */
+private[sql] class ObservationManager(session: SparkSession) {
+  private val observations = new ConcurrentHashMap[(String, Long), Observation]
+  session.listenerManager.register(Listener)
+
+  def register(observation: Observation, ds: Dataset[_]): Unit = {
+    if (ds.isStreaming) {
+      throw new IllegalArgumentException("Observation does not support 
streaming Datasets." +
+        "This is because there will be multiple observed metrics as 
microbatches are constructed" +
+        ". Please register a StreamingQueryListener and get the metric for 
each microbatch in " +
+        "QueryProgressEvent.progress, or use query.lastProgress or 
query.recentProgress.")
+    }
+    register(observation, ds.id)
+  }
+
+  def register(observation: Observation, dataFrameId: Long): Unit = {
+    observation.markRegistered()
+    observations.putIfAbsent((observation.name, dataFrameId), observation)
+  }
+
+  private def tryComplete(qe: QueryExecution): Unit = {
+    val allMetrics = qe.observedMetrics
+    qe.logical.foreach {
+      case c: CollectMetrics =>
+        allMetrics.get(c.name).foreach { metrics =>
+          val observation = observations.remove((c.name, c.dataframeId))
+          if (observation != null) {
+            observation.setMetricsAndNotify(metrics)
+          }
+        }
+      case _ =>
+    }
+  }
+
+  private object Listener extends QueryExecutionListener {
+    override def onSuccess(funcName: String, qe: QueryExecution, durationNs: 
Long): Unit =
+      tryComplete(qe)
+
+    override def onFailure(funcName: String, qe: QueryExecution, exception: 
Exception): Unit =
+      tryComplete(qe)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to