This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 528ba0e0db7d [SPARK-49423][CONNECT][SQL] Consolidate Observation in sql/api
528ba0e0db7d is described below
commit 528ba0e0db7d6f72317063c04c2dd6ddadfcfaab
Author: Herman van Hovell <[email protected]>
AuthorDate: Fri Aug 30 08:01:22 2024 -0400
[SPARK-49423][CONNECT][SQL] Consolidate Observation in sql/api
### What changes were proposed in this pull request?
This PR moves Observation into sql/api. For classic Spark I moved the wiring into
the `observe` method itself, and the required listener is now part of the
`org.apache.spark.sql.internal` package. I have also taken the liberty of getting
rid of most of the homegrown locking, replacing it with a promise.
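To sketch the idea (illustrative only, not the patch itself; the class and method names below are made up): the old code guarded a mutable `metrics` field with `synchronized`/`wait`/`notifyAll`, while the promise-based version lets waiting threads simply block on, or compose with, a future:

```scala
import scala.concurrent.{Future, Promise}

// Minimal sketch of the wait/notify -> promise change (hypothetical class).
class MetricsHolder {
  private val promise = Promise[Map[String, Any]]()

  // Consumers block on (or compose with) this future instead of wait()/notifyAll().
  val future: Future[Map[String, Any]] = promise.future

  // Producer side: the first completion wins; later calls are no-ops.
  def complete(metrics: Map[String, Any]): Boolean = promise.trySuccess(metrics)
}
```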
### Why are the changes needed?
We are creating a shared interface for the classic and connect Scala
DataFrame API. This class is part of that API.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47921 from hvanhovell/SPARK-49423.
Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 23 +---
.../scala/org/apache/spark/sql/Observation.scala | 46 -------
.../scala/org/apache/spark/sql/SparkSession.scala | 11 +-
.../CheckConnectJvmClientCompatibility.scala | 2 -
project/MimaExcludes.scala | 4 +
python/pyspark/sql/observation.py | 6 +-
.../{ObservationBase.scala => Observation.scala} | 87 ++++++++-----
.../scala/org/apache/spark/sql/api/Dataset.scala | 27 +++-
.../spark/sql/connect/client/SparkResult.scala | 14 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 4 +-
.../main/scala/org/apache/spark/sql/Dataset.scala | 25 +---
.../scala/org/apache/spark/sql/Observation.scala | 143 ---------------------
.../scala/org/apache/spark/sql/SparkSession.scala | 2 +
.../spark/sql/internal/ObservationManager.scala | 69 ++++++++++
14 files changed, 173 insertions(+), 290 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index d05834c4fc6c..3b10978b7c8b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1201,28 +1201,7 @@ class Dataset[T] private[sql] (
}
}
- /**
- * Observe (named) metrics through an `org.apache.spark.sql.Observation` instance. This is
- * equivalent to calling `observe(String, Column, Column*)` but does not require to collect all
- * results before returning the metrics - the metrics are filled during iterating the results,
- * as soon as they are available. This method does not support streaming datasets.
- *
- * A user can retrieve the metrics by accessing `org.apache.spark.sql.Observation.get`.
- *
- * {{{
- * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
- * val observation = Observation("my_metrics")
- * val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
- * observed_ds.write.parquet("ds.parquet")
- * val metrics = observation.get
- * }}}
- *
- * @throws IllegalArgumentException
- * If this is a streaming Dataset (this.isStreaming == true)
- *
- * @group typedrel
- * @since 4.0.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = {
val df = observe(observation.name, expr, exprs: _*)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala
deleted file mode 100644
index 75629b6000f9..000000000000
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import java.util.UUID
-
-class Observation(name: String) extends ObservationBase(name) {
-
- /**
- * Create an Observation instance without providing a name. This generates a random name.
- */
- def this() = this(UUID.randomUUID().toString)
-}
-
-/**
- * (Scala-specific) Create instances of Observation via Scala `apply`.
- * @since 4.0.0
- */
-object Observation {
-
- /**
- * Observation constructor for creating an anonymous observation.
- */
- def apply(): Observation = new Observation()
-
- /**
- * Observation constructor for creating a named observation.
- */
- def apply(name: String): Observation = new Observation(name)
-
-}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 3837db00acc6..24d0a5ac7262 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -569,17 +569,14 @@ class SparkSession private[sql] (
private[sql] var releaseSessionOnClose = true
private[sql] def registerObservation(planId: Long, observation: Observation): Unit = {
- if (observationRegistry.putIfAbsent(planId, observation) != null) {
- throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
- }
+ observation.markRegistered()
+ observationRegistry.putIfAbsent(planId, observation)
}
- private[sql] def setMetricsAndUnregisterObservation(
- planId: Long,
- metrics: Map[String, Any]): Unit = {
+ private[sql] def setMetricsAndUnregisterObservation(planId: Long, metrics: Row): Unit = {
val observationOrNull = observationRegistry.remove(planId)
if (observationOrNull != null) {
- observationOrNull.setMetricsAndNotify(Some(metrics))
+ observationOrNull.setMetricsAndNotify(metrics)
}
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index af9168339dcf..10b31155376f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -207,8 +207,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.Dataset$" // private[sql]
),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"),
// TODO (SPARK-49096):
// Mima check might complain the following Dataset rules does not filter any problem.
// This is due to a potential bug in Mima that all methods in `class Dataset` are not being
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 2fc0725df5bc..fe4a08971509 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -120,6 +120,10 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.Window"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.Window$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.WindowSpec"),
+
+ // SPARK-49423: Consolidate Observation in sql/api
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"),
)
// Default exclude rules
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index 5f26b439b048..6ceb6bc90327 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -124,10 +124,8 @@ class Observation:
assert self._jvm is not None
cls = self._jvm.org.apache.spark.sql.Observation
self._jo = cls(self._name) if self._name is not None else cls()
- observed_df = self._jo.on(
- df._jdf,
- exprs[0]._jc,
- _to_seq(df._sc, [c._jc for c in exprs[1:]]),
+ observed_df = df._jdf.observe(
+ self._jo, exprs[0]._jc, _to_seq(df._sc, [c._jc for c in exprs[1:]])
)
return DataFrame(observed_df, df.sparkSession)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
similarity index 62%
rename from sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
index 4789ae8975d1..02f5a8de1e3f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -17,7 +17,15 @@
package org.apache.spark.sql
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.concurrent.{Future, Promise}
+import scala.concurrent.duration.{Duration, DurationInt}
import scala.jdk.CollectionConverters.MapHasAsJava
+import scala.util.Try
+
+import org.apache.spark.util.SparkThreadUtils
/**
* Helper class to simplify usage of `Dataset.observe(String, Column, Column*)`:
@@ -39,11 +47,22 @@ import scala.jdk.CollectionConverters.MapHasAsJava
* @param name name of the metric
* @since 3.3.0
*/
-abstract class ObservationBase(val name: String) {
+class Observation(val name: String) {
+ require(name.nonEmpty, "Name must not be empty")
- if (name.isEmpty) throw new IllegalArgumentException("Name must not be empty")
+ /**
+ * Create an Observation with a random name.
+ */
+ def this() = this(UUID.randomUUID().toString)
- @volatile protected var metrics: Option[Map[String, Any]] = None
+ private val isRegistered = new AtomicBoolean()
+
+ private val promise = Promise[Map[String, Any]]()
+
+ /**
+ * Future holding the (yet to be completed) observation.
+ */
+ val future: Future[Map[String, Any]] = promise.future
/**
* (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish
@@ -54,17 +73,7 @@ abstract class ObservationBase(val name: String) {
* @throws InterruptedException interrupted while waiting
*/
@throws[InterruptedException]
- def get: Map[String, _] = {
- synchronized {
- // we need to loop as wait might return without us calling notify
- // https://en.wikipedia.org/w/index.php?title=Spurious_wakeup&oldid=992601610
- while (this.metrics.isEmpty) {
- wait()
- }
- }
-
- this.metrics.get
- }
+ def get: Map[String, Any] = SparkThreadUtils.awaitResult(future, Duration.Inf)
/**
* (Java-specific) Get the observed metrics. This waits for the observed dataset to finish
@@ -75,9 +84,7 @@ abstract class ObservationBase(val name: String) {
* @throws InterruptedException interrupted while waiting
*/
@throws[InterruptedException]
- def getAsJava: java.util.Map[String, AnyRef] = {
- get.map { case (key, value) => (key, value.asInstanceOf[Object]) }.asJava
- }
+ def getAsJava: java.util.Map[String, Any] = get.asJava
/**
* Get the observed metrics. This returns the metrics if they are available, otherwise an empty.
@@ -85,12 +92,16 @@ abstract class ObservationBase(val name: String) {
* @return the observed metrics as a `Map[String, Any]`
*/
@throws[InterruptedException]
- private[sql] def getOrEmpty: Map[String, _] = {
- synchronized {
- if (metrics.isEmpty) {
- wait(100) // Wait for 100ms to see if metrics are available
- }
- metrics.getOrElse(Map.empty)
+ private[sql] def getOrEmpty: Map[String, Any] = {
+ Try(SparkThreadUtils.awaitResult(future, 100.millis)).getOrElse(Map.empty)
+ }
+
+ /**
+ * Mark this Observation as registered.
+ */
+ private[sql] def markRegistered(): Unit = {
+ if (!isRegistered.compareAndSet(false, true)) {
+ throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
}
}
@@ -99,15 +110,25 @@ abstract class ObservationBase(val name: String) {
*
* @return `true` if all waiting threads were notified, `false` if otherwise.
*/
- private[spark] def setMetricsAndNotify(metrics: Option[Map[String, Any]]): Boolean = {
- synchronized {
- this.metrics = metrics
- if(metrics.isDefined) {
- notifyAll()
- true
- } else {
- false
- }
- }
+ private[sql] def setMetricsAndNotify(metrics: Row): Boolean = {
+ val metricsMap = metrics.getValuesMap(metrics.schema.map(_.name))
+ promise.trySuccess(metricsMap)
}
}
+
+/**
+ * (Scala-specific) Create instances of Observation via Scala `apply`.
+ * @since 3.3.0
+ */
+object Observation {
+
+ /**
+ * Observation constructor for creating an anonymous observation.
+ */
+ def apply(): Observation = new Observation()
+
+ /**
+ * Observation constructor for creating a named observation.
+ */
+ def apply(name: String): Observation = new Observation(name)
+}
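For reference, the consolidated class keeps the usage pattern documented in its scaladoc; roughly (assuming a non-streaming Dataset `ds` and the usual `functions._` import):

```scala
import org.apache.spark.sql.Observation
import org.apache.spark.sql.functions._

// Observe row count (rows) and highest id (maxid) in the Dataset while writing it.
val observation = Observation("my_metrics")
val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
observed_ds.write.parquet("ds.parquet") // the first action completes the promise
val metrics = observation.get           // blocks until the metrics are available
```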
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
index 2b071a384e0a..a5e125733a29 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -16,7 +16,6 @@
*/
package org.apache.spark.sql.api
-import scala.annotation.varargs
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
@@ -24,7 +23,7 @@ import _root_.java.util
import org.apache.spark.annotation.{DeveloperApi, Stable}
import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapPartitionsFunction, ReduceFunction}
-import org.apache.spark.sql.{functions, AnalysisException, Column, Encoder, Row, TypedColumn}
+import org.apache.spark.sql.{functions, AnalysisException, Column, Encoder, Observation, Row, TypedColumn}
import org.apache.spark.sql.types.{Metadata, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.ArrayImplicits._
@@ -1518,9 +1517,31 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 3.0.0
*/
- @varargs
+ @scala.annotation.varargs
def observe(name: String, expr: Column, exprs: Column*): DS[T]
+ /**
+ * Observe (named) metrics through an `org.apache.spark.sql.Observation` instance. This method
+ * does not support streaming datasets.
+ *
+ * A user can retrieve the metrics by accessing `org.apache.spark.sql.Observation.get`.
+ *
+ * {{{
+ * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
+ * val observation = Observation("my_metrics")
+ * val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
+ * observed_ds.write.parquet("ds.parquet")
+ * val metrics = observation.get
+ * }}}
+ *
+ * @throws IllegalArgumentException If this is a streaming Dataset (this.isStreaming == true)
+ *
+ * @group typedrel
+ * @since 3.3.0
+ */
+ @scala.annotation.varargs
+ def observe(observation: Observation, expr: Column, exprs: Column*): DS[T]
+
/**
* Returns a new Dataset by taking the first `n` rows. The difference between this function
* and `head` is that `head` is an action and returns an array (by triggering query execution)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 0905ee76c3f3..3aad90e96f8c 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -42,7 +42,7 @@ private[sql] class SparkResult[T](
allocator: BufferAllocator,
encoder: AgnosticEncoder[T],
timeZoneId: String,
- setObservationMetricsOpt: Option[(Long, Map[String, Any]) => Unit] = None)
+ setObservationMetricsOpt: Option[(Long, Row) => Unit] = None)
extends AutoCloseable { self =>
case class StageInfo(
@@ -211,21 +211,21 @@ private[sql] class SparkResult[T](
metrics.asScala.map { metric =>
assert(metric.getKeysCount == metric.getValuesCount)
var schema = new StructType()
- val keys = mutable.ListBuffer.empty[String]
- val values = mutable.ListBuffer.empty[Any]
- (0 until metric.getKeysCount).map { i =>
+ val values = mutable.ArrayBuilder.make[Any]
+ values.sizeHint(metric.getKeysCount)
+ (0 until metric.getKeysCount).foreach { i =>
val key = metric.getKeys(i)
val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
- keys += key
values += value
}
+ val row = new GenericRowWithSchema(values.result(), schema)
// If the metrics is registered by an Observation object, attach them and unblock any
// blocked thread.
setObservationMetricsOpt.foreach { setObservationMetrics =>
- setObservationMetrics(metric.getPlanId, keys.zip(values).toMap)
+ setObservationMetrics(metric.getPlanId, row)
}
- metric.getName -> new GenericRowWithSchema(values.toArray, schema)
+ metric.getName -> row
}
}
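Note the handoff type change above: the metrics now travel as a schema-carrying `Row` instead of a pre-built `Map`, and `setMetricsAndNotify` converts back via `getValuesMap`. A small sketch of that round trip (the schema and literal values here are made up):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{LongType, StructType}

// Build a metrics row the way SparkResult now does (hypothetical values).
val schema = new StructType().add("rows", LongType).add("maxid", LongType)
val row: Row = new GenericRowWithSchema(Array[Any](10L, 99L), schema)

// setMetricsAndNotify performs the reverse conversion before completing the promise.
val asMap = row.getValuesMap[Any](row.schema.map(_.name)) // Map(rows -> 10, maxid -> 99)
```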
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 6e0b5a35fcd3..58e61badaf37 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1178,8 +1178,10 @@ class SparkConnectPlanner(
if (input.isStreaming || executeHolderOpt.isEmpty) {
CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
} else {
+ // TODO this might be too complex for no good reason. It might
+ // be easier to inspect the plan after it completes.
val observation = Observation(name)
- observation.register(session, planId)
+ session.observationManager.register(observation, planId)
executeHolderOpt.get.addObservation(name, observation)
CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a28dfbdbf66a..597decdbc740 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1005,30 +1005,11 @@ class Dataset[T] private[sql](
CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id)
}
- /**
- * Observe (named) metrics through an `org.apache.spark.sql.Observation` instance.
- * This is equivalent to calling `observe(String, Column, Column*)` but does not require
- * adding `org.apache.spark.sql.util.QueryExecutionListener` to the spark session.
- * This method does not support streaming datasets.
- *
- * A user can retrieve the metrics by accessing `org.apache.spark.sql.Observation.get`.
- *
- * {{{
- * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
- * val observation = Observation("my_metrics")
- * val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
- * observed_ds.write.parquet("ds.parquet")
- * val metrics = observation.get
- * }}}
- *
- * @throws IllegalArgumentException If this is a streaming Dataset (this.isStreaming == true)
- *
- * @group typedrel
- * @since 3.3.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = {
- observation.on(this, expr, exprs: _*)
+ sparkSession.observationManager.register(observation, this)
+ observe(observation.name, expr, exprs: _*)
}
/** @inheritdoc */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
deleted file mode 100644
index 30d5943c6092..000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
+++ /dev/null
@@ -1,143 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import java.util.UUID
-
-import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
-import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.util.QueryExecutionListener
-import org.apache.spark.util.ArrayImplicits._
-
-
-/**
- * Helper class to simplify usage of `Dataset.observe(String, Column, Column*)`:
- *
- * {{{
- * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
- * val observation = Observation("my metrics")
- * val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
- * observed_ds.write.parquet("ds.parquet")
- * val metrics = observation.get
- * }}}
- *
- * This collects the metrics while the first action is executed on the observed dataset. Subsequent
- * actions do not modify the metrics returned by [[get]]. Retrieval of the metric via [[get]]
- * blocks until the first action has finished and metrics become available.
- *
- * This class does not support streaming datasets.
- *
- * @param name name of the metric
- * @since 3.3.0
- */
-class Observation(name: String) extends ObservationBase(name) {
-
- /**
- * Create an Observation instance without providing a name. This generates a random name.
- */
- def this() = this(UUID.randomUUID().toString)
-
- private val listener: ObservationListener = ObservationListener(this)
-
- @volatile private var dataframeId: Option[(SparkSession, Long)] = None
-
- /**
- * Attach this observation to the given [[Dataset]] to observe aggregation expressions.
- *
- * @param ds dataset
- * @param expr first aggregation expression
- * @param exprs more aggregation expressions
- * @tparam T dataset type
- * @return observed dataset
- * @throws IllegalArgumentException If this is a streaming Dataset (ds.isStreaming == true)
- */
- private[spark] def on[T](ds: Dataset[T], expr: Column, exprs: Column*): Dataset[T] = {
- if (ds.isStreaming) {
- throw new IllegalArgumentException("Observation does not support streaming Datasets." +
- "This is because there will be multiple observed metrics as microbatches are constructed" +
- ". Please register a StreamingQueryListener and get the metric for each microbatch in " +
- "QueryProgressEvent.progress, or use query.lastProgress or query.recentProgress.")
- }
- register(ds.sparkSession, ds.id)
- ds.observe(name, expr, exprs: _*)
- }
-
- private[sql] def register(sparkSession: SparkSession, dataframeId: Long): Unit = {
- // makes this class thread-safe:
- // only the first thread entering this block can set sparkSession
- // all other threads will see the exception, as it is only allowed to do this once
- synchronized {
- if (this.dataframeId.isDefined) {
- throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
- }
- this.dataframeId = Some((sparkSession, dataframeId))
- }
-
- sparkSession.listenerManager.register(this.listener)
- }
-
- private def unregister(): Unit = {
- this.dataframeId.foreach(_._1.listenerManager.unregister(this.listener))
- }
-
- private[spark] def onFinish(qe: QueryExecution): Unit = {
- synchronized {
- if (this.metrics.isEmpty && qe.logical.exists {
- case CollectMetrics(name, _, _, dataframeId) =>
- name == this.name && dataframeId == this.dataframeId.get._2
- case _ => false
- }) {
- val row = qe.observedMetrics.get(name)
- val metrics = row.map(r => r.getValuesMap[Any](r.schema.fieldNames.toImmutableArraySeq))
- if (setMetricsAndNotify(metrics)) {
- unregister()
- }
- }
- }
- }
-
-}
-
-private[sql] case class ObservationListener(observation: Observation)
- extends QueryExecutionListener {
-
- override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
- observation.onFinish(qe)
-
- override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
- observation.onFinish(qe)
-
-}
-
-/**
- * (Scala-specific) Create instances of Observation via Scala `apply`.
- * @since 3.3.0
- */
-object Observation {
-
- /**
- * Observation constructor for creating an anonymous observation.
- */
- def apply(): Observation = new Observation()
-
- /**
- * Observation constructor for creating a named observation.
- */
- def apply(name: String): Observation = new Observation(name)
-
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index fa2d1b163322..55f67da68221 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -779,6 +779,8 @@ class SparkSession private(
*/
def named: NamedExpression = ExpressionUtils.toNamed(expr)
}
+
+ private[sql] lazy val observationManager = new ObservationManager(this)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/ObservationManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/ObservationManager.scala
new file mode 100644
index 000000000000..4fa1f0f4962a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/ObservationManager.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.internal
+
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.sql.{Dataset, Observation, SparkSession}
+import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
+
+/**
+ * This class keeps track of registered Observations that await query completion.
+ */
+private[sql] class ObservationManager(session: SparkSession) {
+ private val observations = new ConcurrentHashMap[(String, Long), Observation]
+ session.listenerManager.register(Listener)
+
+ def register(observation: Observation, ds: Dataset[_]): Unit = {
+ if (ds.isStreaming) {
+ throw new IllegalArgumentException("Observation does not support
streaming Datasets." +
+ "This is because there will be multiple observed metrics as
microbatches are constructed" +
+ ". Please register a StreamingQueryListener and get the metric for
each microbatch in " +
+ "QueryProgressEvent.progress, or use query.lastProgress or
query.recentProgress.")
+ }
+ register(observation, ds.id)
+ }
+
+ def register(observation: Observation, dataFrameId: Long): Unit = {
+ observation.markRegistered()
+ observations.putIfAbsent((observation.name, dataFrameId), observation)
+ }
+
+ private def tryComplete(qe: QueryExecution): Unit = {
+ val allMetrics = qe.observedMetrics
+ qe.logical.foreach {
+ case c: CollectMetrics =>
+ allMetrics.get(c.name).foreach { metrics =>
+ val observation = observations.remove((c.name, c.dataframeId))
+ if (observation != null) {
+ observation.setMetricsAndNotify(metrics)
+ }
+ }
+ case _ =>
+ }
+ }
+
+ private object Listener extends QueryExecutionListener {
+ override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
+ tryComplete(qe)
+
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
+ tryComplete(qe)
+ }
+}
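One consequence of the promise-based design: besides the blocking `get`, callers can now hook the public `future` directly. A minimal non-blocking sketch (assuming an `observation` attached to a Dataset as above, and using the global execution context for brevity):

```scala
import scala.concurrent.ExecutionContext.Implicits.global

// React asynchronously once the first action on the observed Dataset finishes.
observation.future.foreach { metrics =>
  println(s"observed metrics: $metrics")
}
```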