This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 692dc66  [SPARK-35695][SQL] Collect observed metrics from cached and adaptive execution sub-trees
692dc66 is described below

commit 692dc66c4a3660665c1f156df6eeb9ce6f86195e
Author: Tanel Kiis <tanel.k...@gmail.com>
AuthorDate: Fri Jun 11 21:03:08 2021 +0800

    [SPARK-35695][SQL] Collect observed metrics from cached and adaptive execution sub-trees
    
    ### What changes were proposed in this pull request?
    
    Collect observed metrics from cached and adaptive execution sub-trees.
    
    ### Why are the changes needed?
    
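    Currently, persisting/caching a Dataset prevents every observed metric in the cached sub-tree from reaching `QueryExecutionListener`s. Adaptive query execution can hide observed metrics from `QueryExecutionListener`s in the same way. The cause is that `CollectMetricsExec.collect` does not descend into the cached plan of an `InMemoryTableScanExec`, nor into the plans wrapped by `AdaptiveSparkPlanExec` and `QueryStageExec` nodes.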
    
    ### Does this PR introduce _any_ user-facing change?
    
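    Yes, as a bug fix: observed metrics defined with `Dataset.observe` inside cached or adaptively executed sub-trees are now reported to `QueryExecutionListener`s.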
    
    ### How was this patch tested?
    
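    New unit tests in `DataFrameCallbackSuite` covering observed metrics through persisted Datasets and with adaptive query execution enabled.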
    
    Closes #32862 from tanelk/SPARK-35695_collect_metrics_persist.
    
    Lead-authored-by: Tanel Kiis <tanel.k...@gmail.com>
    Co-authored-by: tanel.k...@gmail.com <tanel.k...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/execution/CollectMetricsExec.scala   |  12 ++-
 .../spark/sql/util/DataFrameCallbackSuite.scala    | 110 +++++++++++++--------
 2 files changed, 79 insertions(+), 43 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
index 933dabe..89aeb09 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, 
SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
QueryStageExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -93,8 +95,14 @@ object CollectMetricsExec {
    */
   def collect(plan: SparkPlan): Map[String, Row] = {
     val metrics = plan.collectWithSubqueries {
-      case collector: CollectMetricsExec => collector.name -> 
collector.collectedMetrics
+      case collector: CollectMetricsExec => Map(collector.name -> 
collector.collectedMetrics)
+      case tableScan: InMemoryTableScanExec =>
+        CollectMetricsExec.collect(tableScan.relation.cachedPlan)
+      case adaptivePlan: AdaptiveSparkPlanExec =>
+        CollectMetricsExec.collect(adaptivePlan.executedPlan)
+      case queryStageExec: QueryStageExec =>
+        CollectMetricsExec.collect(queryStageExec.plan)
     }
-    metrics.toMap
+    metrics.reduceOption(_ ++ _).getOrElse(Map.empty)
   }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 7a18d6e..01efd98 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.util
 
+import java.lang.{Long => JLong}
+
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
@@ -28,6 +30,7 @@ import 
org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import 
org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, 
LeafRunnableCommand}
 import 
org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StringType
 
@@ -235,18 +238,58 @@ class DataFrameCallbackSuite extends QueryTest
   }
 
   test("get observable metrics by callback") {
-    val metricMaps = ArrayBuffer.empty[Map[String, Row]]
-    val listener = new QueryExecutionListener {
-      override def onSuccess(funcName: String, qe: QueryExecution, duration: 
Long): Unit = {
-        metricMaps += qe.observedMetrics
-      }
+    val df = spark.range(100)
+      .observe(
+        name = "my_event",
+        min($"id").as("min_val"),
+        max($"id").as("max_val"),
+        // Test unresolved alias
+        sum($"id"),
+        count(when($"id" % 2 === 0, 1)).as("num_even"))
+      .observe(
+        name = "other_event",
+        avg($"id").cast("int").as("avg_val"))
+
+    validateObservedMetrics(df)
+  }
 
-      override def onFailure(funcName: String, qe: QueryExecution, exception: 
Exception): Unit = {
-        // No-op
-      }
-    }
-    spark.listenerManager.register(listener)
-    try {
+  test("SPARK-35296: observe should work even if a task contains multiple 
partitions") {
+    val df = spark.range(0, 100, 1, 3)
+      .observe(
+        name = "my_event",
+        min($"id").as("min_val"),
+        max($"id").as("max_val"),
+        // Test unresolved alias
+        sum($"id"),
+        count(when($"id" % 2 === 0, 1)).as("num_even"))
+      .observe(
+        name = "other_event",
+        avg($"id").cast("int").as("avg_val"))
+      .coalesce(2)
+
+    validateObservedMetrics(df)
+  }
+
+  test("SPARK-35695: get observable metrics with persist by callback") {
+    val df = spark.range(100)
+      .observe(
+        name = "my_event",
+        min($"id").as("min_val"),
+        max($"id").as("max_val"),
+        // Test unresolved alias
+        sum($"id"),
+        count(when($"id" % 2 === 0, 1)).as("num_even"))
+      .persist()
+      .observe(
+        name = "other_event",
+        avg($"id").cast("int").as("avg_val"))
+      .persist()
+
+    validateObservedMetrics(df)
+  }
+
+  test("SPARK-35695: get observable metrics with adaptive execution by 
callback") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
       val df = spark.range(100)
         .observe(
           name = "my_event",
@@ -255,35 +298,16 @@ class DataFrameCallbackSuite extends QueryTest
           // Test unresolved alias
           sum($"id"),
           count(when($"id" % 2 === 0, 1)).as("num_even"))
+        .repartition($"id")
         .observe(
           name = "other_event",
           avg($"id").cast("int").as("avg_val"))
 
-      def checkMetrics(metrics: Map[String, Row]): Unit = {
-        assert(metrics.size === 2)
-        assert(metrics("my_event") === Row(0L, 99L, 4950L, 50L))
-        assert(metrics("other_event") === Row(49))
-      }
-
-      // First run
-      df.collect()
-      sparkContext.listenerBus.waitUntilEmpty()
-      assert(metricMaps.size === 1)
-      checkMetrics(metricMaps.head)
-      metricMaps.clear()
-
-      // Second run should produce the same result as the first run.
-      df.collect()
-      sparkContext.listenerBus.waitUntilEmpty()
-      assert(metricMaps.size === 1)
-      checkMetrics(metricMaps.head)
-
-    } finally {
-      spark.listenerManager.unregister(listener)
+      validateObservedMetrics(df)
     }
   }
 
-  test("SPARK-35296: observe should work even if a task contains multiple 
partitions") {
+  private def validateObservedMetrics(df: Dataset[JLong]): Unit = {
     val metricMaps = ArrayBuffer.empty[Map[String, Row]]
     val listener = new QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, duration: 
Long): Unit = {
@@ -296,27 +320,31 @@ class DataFrameCallbackSuite extends QueryTest
     }
     spark.listenerManager.register(listener)
     try {
-      val df = spark.range(1, 4, 1, 3)
-        .observe(
-          name = "my_event",
-          count($"id").as("count_val"))
-        .coalesce(2)
-
       def checkMetrics(metrics: Map[String, Row]): Unit = {
-        assert(metrics.size === 1)
-        assert(metrics("my_event") === Row(3L))
+        assert(metrics.size === 2)
+        assert(metrics("my_event") === Row(0L, 99L, 4950L, 50L))
+        assert(metrics("other_event") === Row(49))
       }
 
+      // First run
       df.collect()
       sparkContext.listenerBus.waitUntilEmpty()
       assert(metricMaps.size === 1)
       checkMetrics(metricMaps.head)
       metricMaps.clear()
+
+      // Second run should produce the same result as the first run.
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty()
+      assert(metricMaps.size === 1)
+      checkMetrics(metricMaps.head)
+
     } finally {
       spark.listenerManager.unregister(listener)
     }
   }
 
+
   testQuietly("SPARK-31144: QueryExecutionListener should receive 
`java.lang.Error`") {
     var e: Exception = null
     val listener = new QueryExecutionListener {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to