This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7db9b2293fa [SPARK-45656][SQL] Fix observation when named observations with the same name on different datasets
7db9b2293fa is described below
commit 7db9b2293fa778073274d235dd72212b75d94073
Author: Takuya UESHIN <[email protected]>
AuthorDate: Wed Oct 25 16:59:26 2023 +0900
[SPARK-45656][SQL] Fix observation when named observations with the same name on different datasets
### What changes were proposed in this pull request?
Fixes `Observation` so that named observations with the same name on different datasets are tracked independently.
### Why are the changes needed?
Currently, if there are observations with the same name on different datasets, the result of one is overwritten by the other query's execution.
For example,
```py
>>> observation1 = Observation("named")
>>> df1 = spark.range(50)
>>> observed_df1 = df1.observe(observation1, count(lit(1)).alias("cnt"))
>>>
>>> observation2 = Observation("named")
>>> df2 = spark.range(100)
>>> observed_df2 = df2.observe(observation2, count(lit(1)).alias("cnt"))
>>>
>>> observed_df1.collect()
...
>>> observed_df2.collect()
...
>>> observation1.get
{'cnt': 50}
>>> observation2.get
{'cnt': 50}
```
`observation2` should return `{'cnt': 100}`.
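The mechanism of the fix, in brief: the logical plan of an observed Dataset carries a `CollectMetrics` node recording both the observation name and the unique id of the Dataset it was attached to, and the listener now accepts a finished query's metrics only when both match. Below is a minimal, self-contained sketch of that matching rule; the `CollectMetrics` case class here is a simplified stand-in for illustration, not Spark's actual logical plan node (which also carries the metric expressions and a child plan):
```scala
// Simplified stand-in for Spark's CollectMetrics logical plan node
// (hypothetical model for illustration only).
case class CollectMetrics(name: String, dataframeId: Long)

// An observation accepts a finished query's metrics only if the plan
// contains a CollectMetrics node matching BOTH its name and the id of
// the Dataset it was registered on.
def matchesObservation(plan: Seq[CollectMetrics], obsName: String, dsId: Long): Boolean =
  plan.exists(m => m.name == obsName && m.dataframeId == dsId)

// Before this fix only the name was compared, so two observations both
// named "named" on different Datasets accepted whichever query finished first.
assert(matchesObservation(Seq(CollectMetrics("named", 1L)), "named", 1L))
assert(!matchesObservation(Seq(CollectMetrics("named", 1L)), "named", 2L))
```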
### Does this PR introduce _any_ user-facing change?
Yes. Observations with the same name now return their own results when they observe different datasets.
### How was this patch tested?
Added the related tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43519 from ueshin/issues/SPARK-45656/observation.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/tests/test_dataframe.py | 18 ++++++++++++++++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 2 +-
.../scala/org/apache/spark/sql/Observation.scala | 21 +++++++++++++--------
.../scala/org/apache/spark/sql/DatasetSuite.scala | 21 +++++++++++++++++++++
4 files changed, 53 insertions(+), 9 deletions(-)
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 3c493a8ae3a..0a2e3a53946 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1023,6 +1023,24 @@ class DataFrameTestsMixin:
self.assertGreaterEqual(row.cnt, 0)
self.assertGreaterEqual(row.sum, 0)
+ def test_observe_with_same_name_on_different_dataframe(self):
+ # SPARK-45656: named observations with the same name on different datasets
+ from pyspark.sql import Observation
+
+ observation1 = Observation("named")
+ df1 = self.spark.range(50)
+ observed_df1 = df1.observe(observation1, count(lit(1)).alias("cnt"))
+
+ observation2 = Observation("named")
+ df2 = self.spark.range(100)
+ observed_df2 = df2.observe(observation2, count(lit(1)).alias("cnt"))
+
+ observed_df1.collect()
+ observed_df2.collect()
+
+ self.assertEqual(observation1.get, dict(cnt=50))
+ self.assertEqual(observation2.get, dict(cnt=100))
+
def test_sample(self):
with self.assertRaises(PySparkTypeError) as pe:
self.spark.range(1).sample()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5079cfcca9d..4f07133bb76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -201,7 +201,7 @@ class Dataset[T] private[sql](
}
// A globally unique id of this Dataset.
- private val id = Dataset.curId.getAndIncrement()
+ private[sql] val id = Dataset.curId.getAndIncrement()
queryExecution.assertAnalyzed()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
index ba40336fc14..14c4983794b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -21,6 +21,7 @@ import java.util.UUID
import scala.jdk.CollectionConverters.MapHasAsJava
+import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener
@@ -56,7 +57,7 @@ class Observation(val name: String) {
private val listener: ObservationListener = ObservationListener(this)
- @volatile private var sparkSession: Option[SparkSession] = None
+ @volatile private var ds: Option[Dataset[_]] = None
@volatile private var metrics: Option[Map[String, Any]] = None
@@ -74,7 +75,7 @@ class Observation(val name: String) {
if (ds.isStreaming) {
throw new IllegalArgumentException("Observation does not support
streaming Datasets")
}
- register(ds.sparkSession)
+ register(ds)
ds.observe(name, expr, exprs: _*)
}
@@ -112,27 +113,31 @@ class Observation(val name: String) {
get.map { case (key, value) => (key, value.asInstanceOf[Object])}.asJava
}
- private def register(sparkSession: SparkSession): Unit = {
+ private def register(ds: Dataset[_]): Unit = {
// makes this class thread-safe:
// only the first thread entering this block can set sparkSession
// all other threads will see the exception, as it is only allowed to do this once
synchronized {
- if (this.sparkSession.isDefined) {
+ if (this.ds.isDefined) {
throw new IllegalArgumentException("An Observation can be used with a
Dataset only once")
}
- this.sparkSession = Some(sparkSession)
+ this.ds = Some(ds)
}
- sparkSession.listenerManager.register(this.listener)
+ ds.sparkSession.listenerManager.register(this.listener)
}
private def unregister(): Unit = {
- this.sparkSession.foreach(_.listenerManager.unregister(this.listener))
+ this.ds.foreach(_.sparkSession.listenerManager.unregister(this.listener))
}
private[spark] def onFinish(qe: QueryExecution): Unit = {
synchronized {
- if (this.metrics.isEmpty) {
+ if (this.metrics.isEmpty && qe.logical.exists {
+ case CollectMetrics(name, _, _, dataframeId) =>
+ name == this.name && dataframeId == ds.get.id
+ case _ => false
+ }) {
val row = qe.observedMetrics.get(name)
this.metrics = row.map(r => r.getValuesMap[Any](r.schema.fieldNames))
if (metrics.isDefined) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 51fa3cd5916..6b00799cabd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1024,6 +1024,27 @@ class DatasetSuite extends QueryTest
assert(namedObservation.get === expected)
}
+ test("SPARK-45656: named observations with the same name on different
datasets") {
+ val namedObservation1 = Observation("named")
+ val df1 = spark.range(50)
+ val observed_df1 = df1.observe(
+ namedObservation1, count(lit(1)).as("count"))
+
+ val namedObservation2 = Observation("named")
+ val df2 = spark.range(100)
+ val observed_df2 = df2.observe(
+ namedObservation2, count(lit(1)).as("count"))
+
+ observed_df1.collect()
+ observed_df2.collect()
+
+ val expected1 = Map("count" -> 50)
+ val expected2 = Map("count" -> 100)
+
+ assert(namedObservation1.get === expected1)
+ assert(namedObservation2.get === expected2)
+ }
+
test("sample with replacement") {
val n = 100
val data = sparkContext.parallelize(1 to n, 2).toDS()