This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7db9b2293fa [SPARK-45656][SQL] Fix observation when named observations with the same name on different datasets 7db9b2293fa is described below commit 7db9b2293fa778073274d235dd72212b75d94073 Author: Takuya UESHIN <ues...@databricks.com> AuthorDate: Wed Oct 25 16:59:26 2023 +0900 [SPARK-45656][SQL] Fix observation when named observations with the same name on different datasets ### What changes were proposed in this pull request? Fixes observations when named observations with the same name are used on different datasets. ### Why are the changes needed? Currently, if there are observations with the same name on different datasets, one of them will be overwritten by the other's execution. For example, ```py >>> observation1 = Observation("named") >>> df1 = spark.range(50) >>> observed_df1 = df1.observe(observation1, count(lit(1)).alias("cnt")) >>> >>> observation2 = Observation("named") >>> df2 = spark.range(100) >>> observed_df2 = df2.observe(observation2, count(lit(1)).alias("cnt")) >>> >>> observed_df1.collect() ... >>> observed_df2.collect() ... >>> observation1.get {'cnt': 50} >>> observation2.get {'cnt': 50} ``` `observation2` should return `{'cnt': 100}`. ### Does this PR introduce _any_ user-facing change? Yes, the observations with the same name will be available if they observe different datasets. ### How was this patch tested? Added the related tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43519 from ueshin/issues/SPARK-45656/observation. 
Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/tests/test_dataframe.py | 18 ++++++++++++++++++ .../main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../scala/org/apache/spark/sql/Observation.scala | 21 +++++++++++++-------- .../scala/org/apache/spark/sql/DatasetSuite.scala | 21 +++++++++++++++++++++ 4 files changed, 53 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 3c493a8ae3a..0a2e3a53946 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -1023,6 +1023,24 @@ class DataFrameTestsMixin: self.assertGreaterEqual(row.cnt, 0) self.assertGreaterEqual(row.sum, 0) + def test_observe_with_same_name_on_different_dataframe(self): + # SPARK-45656: named observations with the same name on different datasets + from pyspark.sql import Observation + + observation1 = Observation("named") + df1 = self.spark.range(50) + observed_df1 = df1.observe(observation1, count(lit(1)).alias("cnt")) + + observation2 = Observation("named") + df2 = self.spark.range(100) + observed_df2 = df2.observe(observation2, count(lit(1)).alias("cnt")) + + observed_df1.collect() + observed_df2.collect() + + self.assertEqual(observation1.get, dict(cnt=50)) + self.assertEqual(observation2.get, dict(cnt=100)) + def test_sample(self): with self.assertRaises(PySparkTypeError) as pe: self.spark.range(1).sample() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 5079cfcca9d..4f07133bb76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -201,7 +201,7 @@ class Dataset[T] private[sql]( } // A globally unique id of this Dataset. 
- private val id = Dataset.curId.getAndIncrement() + private[sql] val id = Dataset.curId.getAndIncrement() queryExecution.assertAnalyzed() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala index ba40336fc14..14c4983794b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala @@ -21,6 +21,7 @@ import java.util.UUID import scala.jdk.CollectionConverters.MapHasAsJava +import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener @@ -56,7 +57,7 @@ class Observation(val name: String) { private val listener: ObservationListener = ObservationListener(this) - @volatile private var sparkSession: Option[SparkSession] = None + @volatile private var ds: Option[Dataset[_]] = None @volatile private var metrics: Option[Map[String, Any]] = None @@ -74,7 +75,7 @@ class Observation(val name: String) { if (ds.isStreaming) { throw new IllegalArgumentException("Observation does not support streaming Datasets") } - register(ds.sparkSession) + register(ds) ds.observe(name, expr, exprs: _*) } @@ -112,27 +113,31 @@ class Observation(val name: String) { get.map { case (key, value) => (key, value.asInstanceOf[Object])}.asJava } - private def register(sparkSession: SparkSession): Unit = { + private def register(ds: Dataset[_]): Unit = { // makes this class thread-safe: // only the first thread entering this block can set sparkSession // all other threads will see the exception, as it is only allowed to do this once synchronized { - if (this.sparkSession.isDefined) { + if (this.ds.isDefined) { throw new IllegalArgumentException("An Observation can be used with a Dataset only once") } - this.sparkSession = Some(sparkSession) + this.ds = Some(ds) } - sparkSession.listenerManager.register(this.listener) + 
ds.sparkSession.listenerManager.register(this.listener) } private def unregister(): Unit = { - this.sparkSession.foreach(_.listenerManager.unregister(this.listener)) + this.ds.foreach(_.sparkSession.listenerManager.unregister(this.listener)) } private[spark] def onFinish(qe: QueryExecution): Unit = { synchronized { - if (this.metrics.isEmpty) { + if (this.metrics.isEmpty && qe.logical.exists { + case CollectMetrics(name, _, _, dataframeId) => + name == this.name && dataframeId == ds.get.id + case _ => false + }) { val row = qe.observedMetrics.get(name) this.metrics = row.map(r => r.getValuesMap[Any](r.schema.fieldNames)) if (metrics.isDefined) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 51fa3cd5916..6b00799cabd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1024,6 +1024,27 @@ class DatasetSuite extends QueryTest assert(namedObservation.get === expected) } + test("SPARK-45656: named observations with the same name on different datasets") { + val namedObservation1 = Observation("named") + val df1 = spark.range(50) + val observed_df1 = df1.observe( + namedObservation1, count(lit(1)).as("count")) + + val namedObservation2 = Observation("named") + val df2 = spark.range(100) + val observed_df2 = df2.observe( + namedObservation2, count(lit(1)).as("count")) + + observed_df1.collect() + observed_df2.collect() + + val expected1 = Map("count" -> 50) + val expected2 = Map("count" -> 100) + + assert(namedObservation1.get === expected1) + assert(namedObservation2.get === expected2) + } + test("sample with replacement") { val n = 100 val data = sparkContext.parallelize(1 to n, 2).toDS() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: 
commits-h...@spark.apache.org