This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 0d4a2e6  [SPARK-35296][SQL] Allow Dataset.observe to work even if CollectMetricsExec in a task handles multiple partitions
0d4a2e6 is described below

commit 0d4a2e67677cea4e144ab14480f321068fc00961
Author: Kousuke Saruta <saru...@oss.nttdata.com>
AuthorDate: Fri Jun 11 01:20:35 2021 +0800

    [SPARK-35296][SQL] Allow Dataset.observe to work even if CollectMetricsExec in a task handles multiple partitions
    
    ### What changes were proposed in this pull request?
    
    This PR fixes an issue where `Dataset.observe` doesn't work if `CollectMetricsExec` handles multiple partitions within a single task.
    If `coalesce` follows `observe` and reduces the number of partitions, a single task can run `CollectMetricsExec` over several partitions.
    
    ### Why are the changes needed?
    
    The current implementation of `CollectMetricsExec` doesn't consider the case where it handles multiple partitions.
    Because a new `updater` is created for each partition, even though those partitions belong to the same task, `collector.setState(updater)` raises an assertion error.
    Here is a simple example that reproduces the issue:
    ```
    $ bin/spark-shell --master "local[1]"
    scala> spark.range(1, 4, 1, 3).observe("my_event", count($"id").as("count_val")).coalesce(2).collect
    ```
    ```
    java.lang.AssertionError: assertion failed
        at scala.Predef$.assert(Predef.scala:208)
        at org.apache.spark.sql.execution.AggregatingAccumulator.setState(AggregatingAccumulator.scala:204)
        at org.apache.spark.sql.execution.CollectMetricsExec.$anonfun$doExecute$2(CollectMetricsExec.scala:72)
        at org.apache.spark.sql.execution.CollectMetricsExec.$anonfun$doExecute$2$adapted(CollectMetricsExec.scala:71)
        at org.apache.spark.TaskContext$$anon$1.onTaskCompletion(TaskContext.scala:125)
        at org.apache.spark.TaskContextImpl.$anonfun$markTaskCompleted$1(TaskContextImpl.scala:124)
        at org.apache.spark.TaskContextImpl.$anonfun$markTaskCompleted$1$adapted(TaskContextImpl.scala:124)
        at org.apache.spark.TaskContextImpl.$anonfun$invokeListeners$1(TaskContextImpl.scala:137)
        at org.apache.spark.TaskContextImpl.$anonfun$invokeListeners$1$adapted(TaskContextImpl.scala:135)
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
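    Added a new test, "SPARK-35296: observe should work even if a task contains multiple partitions", to `DataFrameCallbackSuite`.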
    
    Closes #32786 from sarutak/fix-collectmetricsexec.
    
    Authored-by: Kousuke Saruta <saru...@oss.nttdata.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 44b695fbb06b0d89783b4838941c68543c5a5c8b)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/execution/AggregatingAccumulator.scala     | 16 +++++-----
 .../spark/sql/execution/CollectMetricsExec.scala   |  6 +++-
 .../spark/sql/util/DataFrameCallbackSuite.scala    | 34 ++++++++++++++++++++++
 3 files changed, 48 insertions(+), 8 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
index 94e159c..0fa4e6c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/AggregatingAccumulator.scala
@@ -33,7 +33,7 @@ class AggregatingAccumulator private(
     bufferSchema: Seq[DataType],
     initialValues: Seq[Expression],
     updateExpressions: Seq[Expression],
-    @transient private val mergeExpressions: Seq[Expression],
+    mergeExpressions: Seq[Expression],
     @transient private val resultExpressions: Seq[Expression],
     imperatives: Array[ImperativeAggregate],
     typedImperatives: Array[TypedImperativeAggregate[_]],
@@ -95,13 +95,14 @@ class AggregatingAccumulator private(
 
   /**
    * Driver side operations like `merge` and `value` are executed in the 
DAGScheduler thread. This
-   * thread does not have a SQL configuration so we attach our own here. Note 
that we can't (and
-   * shouldn't) call `merge` or `value` on an accumulator originating from an 
executor so we just
-   * return a default value here.
+   * thread does not have a SQL configuration so we attach our own here.
    */
-  private[this] def withSQLConf[T](default: => T)(body: => T): T = {
+  private[this] def withSQLConf[T](canRunOnExecutor: Boolean, default: => 
T)(body: => T): T = {
     if (conf != null) {
+      // When we can reach here, we are on the driver side.
       SQLConf.withExistingConf(conf)(body)
+    } else if (canRunOnExecutor) {
+      body
     } else {
       default
     }
@@ -147,7 +148,8 @@ class AggregatingAccumulator private(
     }
   }
 
-  override def merge(other: AccumulatorV2[InternalRow, InternalRow]): Unit = 
withSQLConf(()) {
+  override def merge(
+      other: AccumulatorV2[InternalRow, InternalRow]): Unit = 
withSQLConf(true, ()) {
     if (!other.isZero) {
       other match {
         case agg: AggregatingAccumulator =>
@@ -171,7 +173,7 @@ class AggregatingAccumulator private(
     }
   }
 
-  override def value: InternalRow = withSQLConf(InternalRow.empty) {
+  override def value: InternalRow = withSQLConf(false, InternalRow.empty) {
     // Either use the existing buffer or create a temporary one.
     val input = if (!isZero) {
       buffer
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
index b0bbb52..2883bc0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
@@ -69,7 +69,11 @@ case class CollectMetricsExec(
       // - Performance issues due to excessive serialization.
       val updater = collector.copyAndReset()
       TaskContext.get().addTaskCompletionListener[Unit] { _ =>
-        collector.setState(updater)
+        if (collector.isZero) {
+          collector.setState(updater)
+        } else {
+          collector.merge(updater)
+        }
       }
 
       rows.map { r =>
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index b17c935..296c810 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -281,6 +281,40 @@ class DataFrameCallbackSuite extends QueryTest
     }
   }
 
+  test("SPARK-35296: observe should work even if a task contains multiple 
partitions") {
+    val metricMaps = ArrayBuffer.empty[Map[String, Row]]
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: 
Long): Unit = {
+        metricMaps += qe.observedMetrics
+      }
+
+      override def onFailure(funcName: String, qe: QueryExecution, exception: 
Exception): Unit = {
+        // No-op
+      }
+    }
+    spark.listenerManager.register(listener)
+    try {
+      val df = spark.range(1, 4, 1, 3)
+        .observe(
+          name = "my_event",
+          count($"id").as("count_val"))
+        .coalesce(2)
+
+      def checkMetrics(metrics: Map[String, Row]): Unit = {
+        assert(metrics.size === 1)
+        assert(metrics("my_event") === Row(3L))
+      }
+
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty()
+      assert(metricMaps.size === 1)
+      checkMetrics(metricMaps.head)
+      metricMaps.clear()
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
+
   testQuietly("SPARK-31144: QueryExecutionListener should receive 
`java.lang.Error`") {
     var e: Exception = null
     val listener = new QueryExecutionListener {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to