Repository: spark
Updated Branches:
  refs/heads/master e9332f600 -> 9690eba16


[SPARK-25680][SQL] SQL execution listener shouldn't happen on execution thread

## What changes were proposed in this pull request?

The SQL execution listener framework was created from scratch (see
https://github.com/apache/spark/pull/9078). It didn't leverage what we already
have in the Spark listener framework, and one major problem is that the
listener runs on the Spark execution thread, which means a bad listener can
block Spark's query processing.

This PR re-implements the SQL execution listener framework. Now
`ExecutionListenerManager` is just a normal Spark listener, which watches the
`SparkListenerSQLExecutionEnd` events and posts events to the
user-provided SQL execution listeners.

## How was this patch tested?

existing tests.

Closes #22674 from cloud-fan/listener.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9690eba1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9690eba1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9690eba1

Branch: refs/heads/master
Commit: 9690eba16efe6d25261934d8b73a221972b684f3
Parents: e9332f6
Author: Wenchen Fan <[email protected]>
Authored: Wed Oct 17 16:06:07 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Wed Oct 17 16:06:07 2018 +0800

----------------------------------------------------------------------
 .../org/apache/spark/util/ListenerBus.scala     |   8 ++
 project/MimaExcludes.scala                      |   4 +-
 .../org/apache/spark/sql/DataFrameWriter.scala  |  13 +--
 .../scala/org/apache/spark/sql/Dataset.scala    |  14 +--
 .../spark/sql/execution/SQLExecution.scala      |  34 ++++--
 .../spark/sql/execution/ui/SQLListener.scala    |  21 +++-
 .../sql/internal/BaseSessionStateBuilder.scala  |   4 +-
 .../spark/sql/util/QueryExecutionListener.scala | 108 +++++++++----------
 .../apache/spark/sql/SessionStateSuite.scala    |   3 +
 .../scala/org/apache/spark/sql/UDFSuite.scala   |   3 +
 .../sql/execution/SQLJsonProtocolSuite.scala    |  33 +++++-
 .../spark/sql/util/DataFrameCallbackSuite.scala |  13 +++
 .../util/ExecutionListenerManagerSuite.scala    |  14 +--
 13 files changed, 168 insertions(+), 104 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala 
b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
index a8f1068..2e51770 100644
--- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
@@ -61,6 +61,14 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends 
Logging {
   }
 
   /**
+   * Remove all listeners and they won't receive any events. This method is 
thread-safe and can be
+   * called in any thread.
+   */
+  final def removeAllListeners(): Unit = {
+    listenersPlusTimers.clear()
+  }
+
+  /**
    * This can be overridden by subclasses if there is any extra cleanup to do 
when removing a
    * listener.  In particular AsyncEventQueues can clean up queues in the 
LiveListenerBus.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 851fa23..a5d6d63 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -38,7 +38,9 @@ object MimaExcludes {
   lazy val v30excludes = v24excludes ++ Seq(
     
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version"),
     
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.java.JavaPairRDD.flatMapValues"),
-    
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues")
+    
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues"),
+    
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"),
+    
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this")
   )
 
   // Exclude rules for 2.4.x

http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 5d0feec..5a28870 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -672,17 +672,8 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) {
    */
   private def runCommand(session: SparkSession, name: String)(command: 
LogicalPlan): Unit = {
     val qe = session.sessionState.executePlan(command)
-    try {
-      val start = System.nanoTime()
-      // call `QueryExecution.toRDD` to trigger the execution of commands.
-      SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
-      val end = System.nanoTime()
-      session.listenerManager.onSuccess(name, qe, end - start)
-    } catch {
-      case e: Exception =>
-        session.listenerManager.onFailure(name, qe, e)
-        throw e
-    }
+    // call `QueryExecution.toRDD` to trigger the execution of commands.
+    SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd)
   }
 
   
///////////////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fa14aa1..0fb3301 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3356,21 +3356,11 @@ class Dataset[T] private[sql](
    * user-registered callback functions.
    */
   private def withAction[U](name: String, qe: QueryExecution)(action: 
SparkPlan => U) = {
-    try {
+    SQLExecution.withNewExecutionId(sparkSession, qe, Some(name)) {
       qe.executedPlan.foreach { plan =>
         plan.resetMetrics()
       }
-      val start = System.nanoTime()
-      val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
-        action(qe.executedPlan)
-      }
-      val end = System.nanoTime()
-      sparkSession.listenerManager.onSuccess(name, qe, end - start)
-      result
-    } catch {
-      case e: Exception =>
-        sparkSession.listenerManager.onFailure(name, qe, e)
-        throw e
+      action(qe.executedPlan)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 439932b..dda7cb5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -58,7 +58,8 @@ object SQLExecution {
    */
   def withNewExecutionId[T](
       sparkSession: SparkSession,
-      queryExecution: QueryExecution)(body: => T): T = {
+      queryExecution: QueryExecution,
+      name: Option[String] = None)(body: => T): T = {
     val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
     val executionId = SQLExecution.nextExecutionId
@@ -71,14 +72,35 @@ object SQLExecution {
       val callSite = sc.getCallSite()
 
       withSQLConfPropagated(sparkSession) {
-        sc.listenerBus.post(SparkListenerSQLExecutionStart(
-          executionId, callSite.shortForm, callSite.longForm, 
queryExecution.toString,
-          SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), 
System.currentTimeMillis()))
+        var ex: Option[Exception] = None
+        val startTime = System.nanoTime()
         try {
+          sc.listenerBus.post(SparkListenerSQLExecutionStart(
+            executionId = executionId,
+            description = callSite.shortForm,
+            details = callSite.longForm,
+            physicalPlanDescription = queryExecution.toString,
+            // `queryExecution.executedPlan` triggers query planning. If it 
fails, the exception
+            // will be caught and reported in the 
`SparkListenerSQLExecutionEnd`
+            sparkPlanInfo = 
SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),
+            time = System.currentTimeMillis()))
           body
+        } catch {
+          case e: Exception =>
+            ex = Some(e)
+            throw e
         } finally {
-          sc.listenerBus.post(SparkListenerSQLExecutionEnd(
-            executionId, System.currentTimeMillis()))
+          val endTime = System.nanoTime()
+          val event = SparkListenerSQLExecutionEnd(executionId, 
System.currentTimeMillis())
+          // Currently only `Dataset.withAction` and 
`DataFrameWriter.runCommand` specify the `name`
+          // parameter. The `ExecutionListenerManager` only watches SQL 
executions with name. We
+          // can specify the execution name in more places in the future, so 
that
+          // `QueryExecutionListener` can track more cases.
+          event.executionName = name
+          event.duration = endTime - startTime
+          event.qe = queryExecution
+          event.executionFailure = ex
+          sc.listenerBus.post(event)
         }
       }
     } finally {

http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index b58b8c6..c04a31c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.ui
 
+import com.fasterxml.jackson.annotation.JsonIgnore
 import com.fasterxml.jackson.databind.JavaType
 import com.fasterxml.jackson.databind.`type`.TypeFactory
 import com.fasterxml.jackson.databind.annotation.JsonDeserialize
@@ -24,8 +25,7 @@ import com.fasterxml.jackson.databind.util.Converter
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.scheduler._
-import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.execution.metric._
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlanInfo}
 
 @DeveloperApi
 case class SparkListenerSQLExecutionStart(
@@ -39,7 +39,22 @@ case class SparkListenerSQLExecutionStart(
 
 @DeveloperApi
 case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long)
-  extends SparkListenerEvent
+  extends SparkListenerEvent {
+
+  // The name of the execution, e.g. `df.collect` will trigger a SQL execution 
with name "collect".
+  @JsonIgnore private[sql] var executionName: Option[String] = None
+
+  // The following 3 fields are only accessed when `executionName` is defined.
+
+  // The duration of the SQL execution, in nanoseconds.
+  @JsonIgnore private[sql] var duration: Long = 0L
+
+  // The `QueryExecution` instance that represents the SQL execution
+  @JsonIgnore private[sql] var qe: QueryExecution = null
+
+  // The exception object that caused this execution to fail. None if the 
execution doesn't fail.
+  @JsonIgnore private[sql] var executionFailure: Option[Exception] = None
+}
 
 /**
  * A message used to update SQL metric value for driver-side updates (which 
doesn't get reflected

http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 3a0db7e..60bba5e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -266,8 +266,8 @@ abstract class BaseSessionStateBuilder(
    * This gets cloned from parent if available, otherwise a new instance is 
created.
    */
   protected def listenerManager: ExecutionListenerManager = {
-    parentState.map(_.listenerManager.clone()).getOrElse(
-      new ExecutionListenerManager(session.sparkContext.conf))
+    parentState.map(_.listenerManager.clone(session)).getOrElse(
+      new ExecutionListenerManager(session, loadExtensions = true))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
index 2b46233..1310fdf 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -17,17 +17,16 @@
 
 package org.apache.spark.sql.util
 
-import java.util.concurrent.locks.ReentrantReadWriteLock
+import scala.collection.JavaConverters._
 
-import scala.collection.mutable.ListBuffer
-import scala.util.control.NonFatal
-
-import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{DeveloperApi, Experimental, 
InterfaceStability}
 import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
 import org.apache.spark.sql.internal.StaticSQLConf._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ListenerBus, Utils}
 
 /**
  * :: Experimental ::
@@ -75,10 +74,18 @@ trait QueryExecutionListener {
  */
 @Experimental
 @InterfaceStability.Evolving
-class ExecutionListenerManager private extends Logging {
-
-  private[sql] def this(conf: SparkConf) = {
-    this()
+// The `session` is used to indicate which session carries this listener 
manager, and we only
+// catch SQL executions which are launched by the same session.
+// The `loadExtensions` flag is used to indicate whether we should load the 
pre-defined,
+// user-specified listeners during construction. We should not do it when 
cloning this listener
+// manager, as we will copy all listeners to the cloned listener manager.
+class ExecutionListenerManager private[sql](session: SparkSession, 
loadExtensions: Boolean)
+  extends Logging {
+
+  private val listenerBus = new ExecutionListenerBus(session)
+
+  if (loadExtensions) {
+    val conf = session.sparkContext.conf
     conf.get(QUERY_EXECUTION_LISTENERS).foreach { classNames =>
       Utils.loadExtensions(classOf[QueryExecutionListener], classNames, 
conf).foreach(register)
     }
@@ -88,82 +95,63 @@ class ExecutionListenerManager private extends Logging {
    * Registers the specified [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def register(listener: QueryExecutionListener): Unit = writeLock {
-    listeners += listener
+  def register(listener: QueryExecutionListener): Unit = {
+    listenerBus.addListener(listener)
   }
 
   /**
    * Unregisters the specified [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def unregister(listener: QueryExecutionListener): Unit = writeLock {
-    listeners -= listener
+  def unregister(listener: QueryExecutionListener): Unit = {
+    listenerBus.removeListener(listener)
   }
 
   /**
    * Removes all the registered [[QueryExecutionListener]].
    */
   @DeveloperApi
-  def clear(): Unit = writeLock {
-    listeners.clear()
+  def clear(): Unit = {
+    listenerBus.removeAllListeners()
   }
 
   /**
    * Get an identical copy of this listener manager.
    */
-  @DeveloperApi
-  override def clone(): ExecutionListenerManager = writeLock {
-    val newListenerManager = new ExecutionListenerManager
-    listeners.foreach(newListenerManager.register)
+  private[sql] def clone(session: SparkSession): ExecutionListenerManager = {
+    val newListenerManager = new ExecutionListenerManager(session, 
loadExtensions = false)
+    listenerBus.listeners.asScala.foreach(newListenerManager.register)
     newListenerManager
   }
+}
 
-  private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: 
Long): Unit = {
-    readLock {
-      withErrorHandling { listener =>
-        listener.onSuccess(funcName, qe, duration)
-      }
-    }
-  }
-
-  private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: 
Exception): Unit = {
-    readLock {
-      withErrorHandling { listener =>
-        listener.onFailure(funcName, qe, exception)
-      }
-    }
-  }
-
-  private[this] val listeners = ListBuffer.empty[QueryExecutionListener]
+private[sql] class ExecutionListenerBus(session: SparkSession)
+  extends SparkListener with ListenerBus[QueryExecutionListener, 
SparkListenerSQLExecutionEnd] {
 
-  /** A lock to prevent updating the list of listeners while we are traversing 
through them. */
-  private[this] val lock = new ReentrantReadWriteLock()
+  session.sparkContext.listenerBus.addToSharedQueue(this)
 
-  private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = {
-    for (listener <- listeners) {
-      try {
-        f(listener)
-      } catch {
-        case NonFatal(e) => logWarning("Error executing query execution 
listener", e)
-      }
-    }
+  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+    case e: SparkListenerSQLExecutionEnd => postToAll(e)
+    case _ =>
   }
 
-  /** Acquires a read lock on the cache for the duration of `f`. */
-  private def readLock[A](f: => A): A = {
-    val rl = lock.readLock()
-    rl.lock()
-    try f finally {
-      rl.unlock()
+  override protected def doPostEvent(
+      listener: QueryExecutionListener,
+      event: SparkListenerSQLExecutionEnd): Unit = {
+    if (shouldReport(event)) {
+      val funcName = event.executionName.get
+      event.executionFailure match {
+        case Some(ex) =>
+          listener.onFailure(funcName, event.qe, ex)
+        case _ =>
+          listener.onSuccess(funcName, event.qe, event.duration)
+      }
     }
   }
 
-  /** Acquires a write lock on the cache for the duration of `f`. */
-  private def writeLock[A](f: => A): A = {
-    val wl = lock.writeLock()
-    wl.lock()
-    try f finally {
-      wl.unlock()
-    }
+  private def shouldReport(e: SparkListenerSQLExecutionEnd): Boolean = {
+    // Only catch SQL execution with a name, and triggered by the same spark 
session that this
+    // listener manager belongs.
+    e.executionName.isDefined && e.qe.sparkSession.eq(this.session)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
index e1b5eba..6317cd2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
@@ -155,6 +155,7 @@ class SessionStateSuite extends SparkFunSuite {
       assert(forkedSession ne activeSession)
       assert(forkedSession.listenerManager ne activeSession.listenerManager)
       runCollectQueryOn(forkedSession)
+      activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(collectorA.commands.length == 1) // forked should callback to A
       assert(collectorA.commands(0) == "collect")
 
@@ -162,12 +163,14 @@ class SessionStateSuite extends SparkFunSuite {
       // => changes to forked do not affect original
       forkedSession.listenerManager.register(collectorB)
       runCollectQueryOn(activeSession)
+      activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(collectorB.commands.isEmpty) // original should not callback to B
       assert(collectorA.commands.length == 2) // original should still 
callback to A
       assert(collectorA.commands(1) == "collect")
       // <= changes to original do not affect forked
       activeSession.listenerManager.register(collectorC)
       runCollectQueryOn(forkedSession)
+      activeSession.sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(collectorC.commands.isEmpty) // forked should not callback to C
       assert(collectorA.commands.length == 3) // forked should still callback 
to A
       assert(collectorB.commands.length == 1) // forked should still callback 
to B

http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 30dca94..269600d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -356,10 +356,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
           .withColumn("b", udf1($"a", lit(10)))
         df.cache()
         df.write.saveAsTable("t")
+        sparkContext.listenerBus.waitUntilEmpty(1000)
         assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable")
         df.write.insertInto("t")
+        sparkContext.listenerBus.waitUntilEmpty(1000)
         assert(numTotalCachedHit == 2, "expected to be cached in insertInto")
         df.write.save(path.getCanonicalPath)
+        sparkContext.listenerBus.waitUntilEmpty(1000)
         assert(numTotalCachedHit == 3, "expected to be cached in save for 
native")
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
index 08e40e2..08789e6 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.execution
 
-import org.json4s.jackson.JsonMethods.parse
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart
+import org.apache.spark.sql.LocalSparkSession
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, 
SparkListenerSQLExecutionStart}
+import org.apache.spark.sql.test.TestSparkSession
 import org.apache.spark.util.JsonProtocol
 
-class SQLJsonProtocolSuite extends SparkFunSuite {
+class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession {
 
   test("SparkPlanGraph backward compatibility: metadata") {
     val SQLExecutionStartJsonString =
@@ -49,4 +51,29 @@ class SQLJsonProtocolSuite extends SparkFunSuite {
       new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0)
     assert(reconstructedEvent == expectedEvent)
   }
+
+  test("SparkListenerSQLExecutionEnd backward compatibility") {
+    spark = new TestSparkSession()
+    val qe = spark.sql("select 1").queryExecution
+    val event = SparkListenerSQLExecutionEnd(1, 10)
+    event.duration = 1000
+    event.executionName = Some("test")
+    event.qe = qe
+    event.executionFailure = Some(new RuntimeException("test"))
+    val json = JsonProtocol.sparkEventToJson(event)
+    assert(json == parse(
+      """
+        |{
+        |  "Event" : 
"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd",
+        |  "executionId" : 1,
+        |  "time" : 10
+        |}
+      """.stripMargin))
+    val readBack = JsonProtocol.sparkEventFromJson(json)
+    event.duration = 0
+    event.executionName = None
+    event.qe = null
+    event.executionFailure = None
+    assert(readBack == event)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index a239e39..e8710ae 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -48,6 +48,7 @@ class DataFrameCallbackSuite extends QueryTest with 
SharedSQLContext {
     df.select("i").collect()
     df.filter($"i" > 0).count()
 
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 2)
 
     assert(metrics(0)._1 == "collect")
@@ -78,6 +79,7 @@ class DataFrameCallbackSuite extends QueryTest with 
SharedSQLContext {
 
     val e = intercept[SparkException](df.select(errorUdf($"i")).collect())
 
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 1)
     assert(metrics(0)._1 == "collect")
     assert(metrics(0)._2.analyzed.isInstanceOf[Project])
@@ -103,10 +105,16 @@ class DataFrameCallbackSuite extends QueryTest with 
SharedSQLContext {
     spark.listenerManager.register(listener)
 
     val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
+
     df.collect()
+    // Wait for the first `collect` to be caught by our listener. Otherwise 
the next `collect` will
+    // reset the plan metrics.
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     df.collect()
+
     Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
 
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 3)
     assert(metrics(0) === 1)
     assert(metrics(1) === 1)
@@ -154,6 +162,7 @@ class DataFrameCallbackSuite extends QueryTest with 
SharedSQLContext {
 
     // For this simple case, the peakExecutionMemory of a stage should be the 
data size of the
     // aggregate operator, as we only have one memory consuming operator per 
stage.
+    sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(metrics.length == 2)
     assert(metrics(0) == topAggDataSize)
     assert(metrics(1) == bottomAggDataSize)
@@ -177,6 +186,7 @@ class DataFrameCallbackSuite extends QueryTest with 
SharedSQLContext {
 
     withTempPath { path =>
       spark.range(10).write.format("json").save(path.getCanonicalPath)
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(commands.length == 1)
       assert(commands.head._1 == "save")
       assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand])
@@ -187,6 +197,7 @@ class DataFrameCallbackSuite extends QueryTest with 
SharedSQLContext {
     withTable("tab") {
       sql("CREATE TABLE tab(i long) using parquet") // adds commands(1) via 
onSuccess
       spark.range(10).write.insertInto("tab")
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(commands.length == 3)
       assert(commands(2)._1 == "insertInto")
       assert(commands(2)._2.isInstanceOf[InsertIntoTable])
@@ -197,6 +208,7 @@ class DataFrameCallbackSuite extends QueryTest with 
SharedSQLContext {
 
     withTable("tab") {
       spark.range(10).select($"id", $"id" % 5 as 
"p").write.partitionBy("p").saveAsTable("tab")
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(commands.length == 5)
       assert(commands(4)._1 == "saveAsTable")
       assert(commands(4)._2.isInstanceOf[CreateTable])
@@ -208,6 +220,7 @@ class DataFrameCallbackSuite extends QueryTest with 
SharedSQLContext {
       val e = intercept[AnalysisException] {
         spark.range(10).select($"id", $"id").write.insertInto("tab")
       }
+      sparkContext.listenerBus.waitUntilEmpty(1000)
       assert(exceptions.length == 1)
       assert(exceptions.head._1 == "insertInto")
       assert(exceptions.head._2 == e)

http://git-wip-us.apache.org/repos/asf/spark/blob/9690eba1/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
index 4205e23..da414f4 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala
@@ -20,26 +20,28 @@ package org.apache.spark.sql.util
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark._
+import org.apache.spark.sql.{LocalSparkSession, SparkSession}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.internal.StaticSQLConf._
 
-class ExecutionListenerManagerSuite extends SparkFunSuite {
+class ExecutionListenerManagerSuite extends SparkFunSuite with 
LocalSparkSession {
 
   import CountingQueryExecutionListener._
 
   test("register query execution listeners using configuration") {
     val conf = new SparkConf(false)
       .set(QUERY_EXECUTION_LISTENERS, 
Seq(classOf[CountingQueryExecutionListener].getName()))
+    spark = 
SparkSession.builder().master("local").appName("test").config(conf).getOrCreate()
 
-    val mgr = new ExecutionListenerManager(conf)
+    spark.sql("select 1").collect()
+    spark.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(INSTANCE_COUNT.get() === 1)
-    mgr.onSuccess(null, null, 42L)
     assert(CALLBACK_COUNT.get() === 1)
 
-    val clone = mgr.clone()
+    val cloned = spark.cloneSession()
+    cloned.sql("select 1").collect()
+    spark.sparkContext.listenerBus.waitUntilEmpty(1000)
     assert(INSTANCE_COUNT.get() === 1)
-
-    clone.onSuccess(null, null, 42L)
     assert(CALLBACK_COUNT.get() === 2)
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to