Repository: spark
Updated Branches:
  refs/heads/master 4741c0780 -> 0ffa7c488


http://git-wip-us.apache.org/repos/asf/spark/blob/0ffa7c48/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 1055f09..eba8d55 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.ui
 
 import java.util.Properties
 
+import scala.collection.mutable.ListBuffer
+
 import org.json4s.jackson.JsonMethods._
-import org.mockito.Mockito.mock
 
 import org.apache.spark._
 import org.apache.spark.LocalSparkContext._
-import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.config
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler._
@@ -36,13 +36,14 @@ import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, 
SparkPlanInfo, SQLExecution}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.ui.SparkUI
+import org.apache.spark.status.config._
 import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, 
LongAccumulator}
-
+import org.apache.spark.util.kvstore.InMemoryStore
 
 class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with 
JsonTestUtils {
   import testImplicits._
-  import org.apache.spark.AccumulatorSuite.makeInfo
+
+  override protected def sparkConf = 
super.sparkConf.set(LIVE_ENTITY_UPDATE_PERIOD, 0L)
 
   private def createTestDataFrame: DataFrame = {
     Seq(
@@ -68,44 +69,67 @@ class SQLListenerSuite extends SparkFunSuite with 
SharedSQLContext with JsonTest
     details = ""
   )
 
-  private def createTaskInfo(taskId: Int, attemptNumber: Int): TaskInfo = new 
TaskInfo(
-    taskId = taskId,
-    attemptNumber = attemptNumber,
-    // The following fields are not used in tests
-    index = 0,
-    launchTime = 0,
-    executorId = "",
-    host = "",
-    taskLocality = null,
-    speculative = false
-  )
+  private def createTaskInfo(
+      taskId: Int,
+      attemptNumber: Int,
+      accums: Map[Long, Long] = Map()): TaskInfo = {
+    val info = new TaskInfo(
+      taskId = taskId,
+      attemptNumber = attemptNumber,
+      // The following fields are not used in tests
+      index = 0,
+      launchTime = 0,
+      executorId = "",
+      host = "",
+      taskLocality = null,
+      speculative = false)
+    info.markFinished(TaskState.FINISHED, 1L)
+    info.setAccumulables(createAccumulatorInfos(accums))
+    info
+  }
 
-  private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): 
TaskMetrics = {
-    val metrics = TaskMetrics.empty
-    accumulatorUpdates.foreach { case (id, update) =>
+  private def createAccumulatorInfos(accumulatorUpdates: Map[Long, Long]): 
Seq[AccumulableInfo] = {
+    accumulatorUpdates.map { case (id, value) =>
       val acc = new LongAccumulator
-      acc.metadata = AccumulatorMetadata(id, Some(""), true)
-      acc.add(update)
-      metrics.registerAccumulator(acc)
+      acc.metadata = AccumulatorMetadata(id, None, false)
+      acc.toInfo(Some(value), None)
+    }.toSeq
+  }
+
+  /** Return the shared SQL store from the active SparkSession. */
+  private def statusStore: SQLAppStatusStore =
+    new SQLAppStatusStore(spark.sparkContext.statusStore.store)
+
+  /**
+   * Runs a test with a temporary SQLAppStatusStore tied to a listener bus. 
Events can be sent to
+   * the listener bus to update the store, and all data will be cleaned up at 
the end of the test.
+   */
+  private def sqlStoreTest(name: String)
+      (fn: (SQLAppStatusStore, SparkListenerBus) => Unit): Unit = {
+    test(name) {
+      val store = new InMemoryStore()
+      val bus = new ReplayListenerBus()
+      val listener = new SQLAppStatusListener(sparkConf, store, true)
+      bus.addListener(listener)
+      val sqlStore = new SQLAppStatusStore(store, Some(listener))
+      fn(sqlStore, bus)
     }
-    metrics
   }
 
-  test("basic") {
+  sqlStoreTest("basic") { (store, bus) =>
     def checkAnswer(actual: Map[Long, String], expected: Map[Long, Long]): 
Unit = {
       assert(actual.size == expected.size)
-      expected.foreach { e =>
+      expected.foreach { case (id, value) =>
         // The values in actual can be SQL metrics meaning that they contain 
additional formatting
         // when converted to string. Verify that they start with the expected 
value.
         // TODO: this is brittle. There is no requirement that the actual 
string needs to start
         // with the accumulator value.
-        assert(actual.contains(e._1))
-        val v = actual.get(e._1).get.trim
-        assert(v.startsWith(e._2.toString))
+        assert(actual.contains(id))
+        val v = actual.get(id).get.trim
+        assert(v.startsWith(value.toString), s"Wrong value for accumulator 
$id")
       }
     }
 
-    val listener = new SQLListener(spark.sparkContext.conf)
     val executionId = 0
     val df = createTestDataFrame
     val accumulatorIds =
@@ -118,7 +142,7 @@ class SQLListenerSuite extends SparkFunSuite with 
SharedSQLContext with JsonTest
       (id, accumulatorValue)
     }.toMap
 
-    listener.onOtherEvent(SparkListenerSQLExecutionStart(
+    bus.postToAll(SparkListenerSQLExecutionStart(
       executionId,
       "test",
       "test",
@@ -126,9 +150,7 @@ class SQLListenerSuite extends SparkFunSuite with 
SharedSQLContext with JsonTest
       SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan),
       System.currentTimeMillis()))
 
-    val executionUIData = listener.executionIdToData(0)
-
-    listener.onJobStart(SparkListenerJobStart(
+    bus.postToAll(SparkListenerJobStart(
       jobId = 0,
       time = System.currentTimeMillis(),
       stageInfos = Seq(
@@ -136,291 +158,270 @@ class SQLListenerSuite extends SparkFunSuite with 
SharedSQLContext with JsonTest
         createStageInfo(1, 0)
       ),
       createProperties(executionId)))
-    listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 
0)))
+    bus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 0)))
 
-    assert(listener.getExecutionMetrics(0).isEmpty)
+    assert(store.executionMetrics(0).isEmpty)
 
-    listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", 
Seq(
+    bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 0, 0, 
createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
-      (1L, 0, 0, 
createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
+      (0L, 0, 0, createAccumulatorInfos(accumulatorUpdates)),
+      (1L, 0, 0, createAccumulatorInfos(accumulatorUpdates))
     )))
 
-    checkAnswer(listener.getExecutionMetrics(0), 
accumulatorUpdates.mapValues(_ * 2))
+    checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
 
     // Driver accumulator updates don't belong to this execution should be 
filtered and no
     // exception will be thrown.
-    listener.onOtherEvent(SparkListenerDriverAccumUpdates(0, Seq((999L, 2L))))
-    checkAnswer(listener.getExecutionMetrics(0), 
accumulatorUpdates.mapValues(_ * 2))
+    bus.postToAll(SparkListenerDriverAccumUpdates(0, Seq((999L, 2L))))
 
-    listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", 
Seq(
+    checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
+
+    bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 0, 0, 
createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
-      (1L, 0, 0,
-        createTaskMetrics(accumulatorUpdates.mapValues(_ * 
2)).accumulators().map(makeInfo))
+      (0L, 0, 0, createAccumulatorInfos(accumulatorUpdates)),
+      (1L, 0, 0, createAccumulatorInfos(accumulatorUpdates.mapValues(_ * 2)))
     )))
 
-    checkAnswer(listener.getExecutionMetrics(0), 
accumulatorUpdates.mapValues(_ * 3))
+    checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 3))
 
     // Retrying a stage should reset the metrics
-    listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 
1)))
+    bus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1)))
 
-    listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", 
Seq(
+    bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 0, 1, 
createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
-      (1L, 0, 1, 
createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
+      (0L, 0, 1, createAccumulatorInfos(accumulatorUpdates)),
+      (1L, 0, 1, createAccumulatorInfos(accumulatorUpdates))
     )))
 
-    checkAnswer(listener.getExecutionMetrics(0), 
accumulatorUpdates.mapValues(_ * 2))
+    checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
 
     // Ignore the task end for the first attempt
-    listener.onTaskEnd(SparkListenerTaskEnd(
+    bus.postToAll(SparkListenerTaskEnd(
       stageId = 0,
       stageAttemptId = 0,
       taskType = "",
       reason = null,
-      createTaskInfo(0, 0),
-      createTaskMetrics(accumulatorUpdates.mapValues(_ * 100))))
+      createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 100)),
+      null))
 
-    checkAnswer(listener.getExecutionMetrics(0), 
accumulatorUpdates.mapValues(_ * 2))
+    checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
 
     // Finish two tasks
-    listener.onTaskEnd(SparkListenerTaskEnd(
+    bus.postToAll(SparkListenerTaskEnd(
       stageId = 0,
       stageAttemptId = 1,
       taskType = "",
       reason = null,
-      createTaskInfo(0, 0),
-      createTaskMetrics(accumulatorUpdates.mapValues(_ * 2))))
-    listener.onTaskEnd(SparkListenerTaskEnd(
+      createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 2)),
+      null))
+    bus.postToAll(SparkListenerTaskEnd(
       stageId = 0,
       stageAttemptId = 1,
       taskType = "",
       reason = null,
-      createTaskInfo(1, 0),
-      createTaskMetrics(accumulatorUpdates.mapValues(_ * 3))))
+      createTaskInfo(1, 0, accums = accumulatorUpdates.mapValues(_ * 3)),
+      null))
 
-    checkAnswer(listener.getExecutionMetrics(0), 
accumulatorUpdates.mapValues(_ * 5))
+    checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 5))
 
     // Summit a new stage
-    listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 
0)))
+    bus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 0)))
 
-    listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", 
Seq(
+    bus.postToAll(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 1, 0, 
createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
-      (1L, 1, 0, 
createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
+      (0L, 1, 0, createAccumulatorInfos(accumulatorUpdates)),
+      (1L, 1, 0, createAccumulatorInfos(accumulatorUpdates))
     )))
 
-    checkAnswer(listener.getExecutionMetrics(0), 
accumulatorUpdates.mapValues(_ * 7))
+    checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 7))
 
     // Finish two tasks
-    listener.onTaskEnd(SparkListenerTaskEnd(
+    bus.postToAll(SparkListenerTaskEnd(
       stageId = 1,
       stageAttemptId = 0,
       taskType = "",
       reason = null,
-      createTaskInfo(0, 0),
-      createTaskMetrics(accumulatorUpdates.mapValues(_ * 3))))
-    listener.onTaskEnd(SparkListenerTaskEnd(
+      createTaskInfo(0, 0, accums = accumulatorUpdates.mapValues(_ * 3)),
+      null))
+    bus.postToAll(SparkListenerTaskEnd(
       stageId = 1,
       stageAttemptId = 0,
       taskType = "",
       reason = null,
-      createTaskInfo(1, 0),
-      createTaskMetrics(accumulatorUpdates.mapValues(_ * 3))))
+      createTaskInfo(1, 0, accums = accumulatorUpdates.mapValues(_ * 3)),
+      null))
 
-    checkAnswer(listener.getExecutionMetrics(0), 
accumulatorUpdates.mapValues(_ * 11))
+    checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 
11))
 
-    assert(executionUIData.runningJobs === Seq(0))
-    assert(executionUIData.succeededJobs.isEmpty)
-    assert(executionUIData.failedJobs.isEmpty)
+    assertJobs(store.execution(0), running = Seq(0))
 
-    listener.onJobEnd(SparkListenerJobEnd(
+    bus.postToAll(SparkListenerJobEnd(
       jobId = 0,
       time = System.currentTimeMillis(),
       JobSucceeded
     ))
-    listener.onOtherEvent(SparkListenerSQLExecutionEnd(
+    bus.postToAll(SparkListenerSQLExecutionEnd(
       executionId, System.currentTimeMillis()))
 
-    assert(executionUIData.runningJobs.isEmpty)
-    assert(executionUIData.succeededJobs === Seq(0))
-    assert(executionUIData.failedJobs.isEmpty)
-
-    checkAnswer(listener.getExecutionMetrics(0), 
accumulatorUpdates.mapValues(_ * 11))
+    assertJobs(store.execution(0), completed = Seq(0))
+    checkAnswer(store.executionMetrics(0), accumulatorUpdates.mapValues(_ * 
11))
   }
 
-  test("onExecutionEnd happens before onJobEnd(JobSucceeded)") {
-    val listener = new SQLListener(spark.sparkContext.conf)
+  sqlStoreTest("onExecutionEnd happens before onJobEnd(JobSucceeded)") { 
(store, bus) =>
     val executionId = 0
     val df = createTestDataFrame
-    listener.onOtherEvent(SparkListenerSQLExecutionStart(
+    bus.postToAll(SparkListenerSQLExecutionStart(
       executionId,
       "test",
       "test",
       df.queryExecution.toString,
       SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan),
       System.currentTimeMillis()))
-    listener.onJobStart(SparkListenerJobStart(
+    bus.postToAll(SparkListenerJobStart(
       jobId = 0,
       time = System.currentTimeMillis(),
       stageInfos = Nil,
       createProperties(executionId)))
-    listener.onOtherEvent(SparkListenerSQLExecutionEnd(
+    bus.postToAll(SparkListenerSQLExecutionEnd(
       executionId, System.currentTimeMillis()))
-    listener.onJobEnd(SparkListenerJobEnd(
+    bus.postToAll(SparkListenerJobEnd(
       jobId = 0,
       time = System.currentTimeMillis(),
       JobSucceeded
     ))
 
-    val executionUIData = listener.executionIdToData(0)
-    assert(executionUIData.runningJobs.isEmpty)
-    assert(executionUIData.succeededJobs === Seq(0))
-    assert(executionUIData.failedJobs.isEmpty)
+    assertJobs(store.execution(0), completed = Seq(0))
   }
 
-  test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") {
-    val listener = new SQLListener(spark.sparkContext.conf)
+  sqlStoreTest("onExecutionEnd happens before multiple 
onJobEnd(JobSucceeded)s") { (store, bus) =>
     val executionId = 0
     val df = createTestDataFrame
-    listener.onOtherEvent(SparkListenerSQLExecutionStart(
+    bus.postToAll(SparkListenerSQLExecutionStart(
       executionId,
       "test",
       "test",
       df.queryExecution.toString,
       SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan),
       System.currentTimeMillis()))
-    listener.onJobStart(SparkListenerJobStart(
+    bus.postToAll(SparkListenerJobStart(
       jobId = 0,
       time = System.currentTimeMillis(),
       stageInfos = Nil,
       createProperties(executionId)))
-    listener.onJobEnd(SparkListenerJobEnd(
+    bus.postToAll(SparkListenerJobEnd(
         jobId = 0,
         time = System.currentTimeMillis(),
         JobSucceeded
     ))
 
-    listener.onJobStart(SparkListenerJobStart(
+    bus.postToAll(SparkListenerJobStart(
       jobId = 1,
       time = System.currentTimeMillis(),
       stageInfos = Nil,
       createProperties(executionId)))
-    listener.onOtherEvent(SparkListenerSQLExecutionEnd(
+    bus.postToAll(SparkListenerSQLExecutionEnd(
       executionId, System.currentTimeMillis()))
-    listener.onJobEnd(SparkListenerJobEnd(
+    bus.postToAll(SparkListenerJobEnd(
       jobId = 1,
       time = System.currentTimeMillis(),
       JobSucceeded
     ))
 
-    val executionUIData = listener.executionIdToData(0)
-    assert(executionUIData.runningJobs.isEmpty)
-    assert(executionUIData.succeededJobs.sorted === Seq(0, 1))
-    assert(executionUIData.failedJobs.isEmpty)
+    assertJobs(store.execution(0), completed = Seq(0, 1))
   }
 
-  test("onExecutionEnd happens before onJobEnd(JobFailed)") {
-    val listener = new SQLListener(spark.sparkContext.conf)
+  sqlStoreTest("onExecutionEnd happens before onJobEnd(JobFailed)") { (store, 
bus) =>
     val executionId = 0
     val df = createTestDataFrame
-    listener.onOtherEvent(SparkListenerSQLExecutionStart(
+    bus.postToAll(SparkListenerSQLExecutionStart(
       executionId,
       "test",
       "test",
       df.queryExecution.toString,
       SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan),
       System.currentTimeMillis()))
-    listener.onJobStart(SparkListenerJobStart(
+    bus.postToAll(SparkListenerJobStart(
       jobId = 0,
       time = System.currentTimeMillis(),
       stageInfos = Seq.empty,
       createProperties(executionId)))
-    listener.onOtherEvent(SparkListenerSQLExecutionEnd(
+    bus.postToAll(SparkListenerSQLExecutionEnd(
       executionId, System.currentTimeMillis()))
-    listener.onJobEnd(SparkListenerJobEnd(
+    bus.postToAll(SparkListenerJobEnd(
       jobId = 0,
       time = System.currentTimeMillis(),
       JobFailed(new RuntimeException("Oops"))
     ))
 
-    val executionUIData = listener.executionIdToData(0)
-    assert(executionUIData.runningJobs.isEmpty)
-    assert(executionUIData.succeededJobs.isEmpty)
-    assert(executionUIData.failedJobs === Seq(0))
+    assertJobs(store.execution(0), failed = Seq(0))
   }
 
   test("SPARK-11126: no memory leak when running non SQL jobs") {
-    val previousStageNumber = 
spark.sharedState.listener.stageIdToStageMetrics.size
+    val previousStageNumber = statusStore.executionsList().size
     spark.sparkContext.parallelize(1 to 10).foreach(i => ())
     spark.sparkContext.listenerBus.waitUntilEmpty(10000)
     // listener should ignore the non SQL stage
-    assert(spark.sharedState.listener.stageIdToStageMetrics.size == 
previousStageNumber)
+    assert(statusStore.executionsList().size == previousStageNumber)
 
     spark.sparkContext.parallelize(1 to 10).toDF().foreach(i => ())
     spark.sparkContext.listenerBus.waitUntilEmpty(10000)
     // listener should save the SQL stage
-    assert(spark.sharedState.listener.stageIdToStageMetrics.size == 
previousStageNumber + 1)
-  }
-
-  test("SPARK-13055: history listener only tracks SQL metrics") {
-    val listener = new SQLHistoryListener(sparkContext.conf, 
mock(classOf[SparkUI]))
-    // We need to post other events for the listener to track our accumulators.
-    // These are largely just boilerplate unrelated to what we're trying to 
test.
-    val df = createTestDataFrame
-    val executionStart = SparkListenerSQLExecutionStart(
-      0, "", "", "", 
SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), 0)
-    val stageInfo = createStageInfo(0, 0)
-    val jobStart = SparkListenerJobStart(0, 0, Seq(stageInfo), 
createProperties(0))
-    val stageSubmitted = SparkListenerStageSubmitted(stageInfo)
-    // This task has both accumulators that are SQL metrics and accumulators 
that are not.
-    // The listener should only track the ones that are actually SQL metrics.
-    val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella")
-    val nonSqlMetric = sparkContext.longAccumulator("baseball")
-    val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.value), None)
-    val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.value), None)
-    val taskInfo = createTaskInfo(0, 0)
-    taskInfo.setAccumulables(List(sqlMetricInfo, nonSqlMetricInfo))
-    val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, 
null)
-    listener.onOtherEvent(executionStart)
-    listener.onJobStart(jobStart)
-    listener.onStageSubmitted(stageSubmitted)
-    // Before SPARK-13055, this throws ClassCastException because the history 
listener would
-    // assume that the accumulator value is of type Long, but this may not be 
true for
-    // accumulators that are not SQL metrics.
-    listener.onTaskEnd(taskEnd)
-    val trackedAccums = listener.stageIdToStageMetrics.values.flatMap { 
stageMetrics =>
-      stageMetrics.taskIdToMetricUpdates.values.flatMap(_.accumulatorUpdates)
-    }
-    // Listener tracks only SQL metrics, not other accumulators
-    assert(trackedAccums.size === 1)
-    assert(trackedAccums.head === ((sqlMetricInfo.id, 
sqlMetricInfo.update.get)))
+    assert(statusStore.executionsList().size == previousStageNumber + 1)
   }
 
   test("driver side SQL metrics") {
-    val listener = new SQLListener(spark.sparkContext.conf)
-    val expectedAccumValue = 12345
+    val oldCount = statusStore.executionsList().size
+    val expectedAccumValue = 12345L
     val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue)
-    sqlContext.sparkContext.addSparkListener(listener)
     val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) {
       override lazy val sparkPlan = physicalPlan
       override lazy val executedPlan = physicalPlan
     }
+
     SQLExecution.withNewExecutionId(spark, dummyQueryExecution) {
       physicalPlan.execute().collect()
     }

-    def waitTillExecutionFinished(): Unit = {
-      while (listener.getCompletedExecutions.isEmpty) {
-        Thread.sleep(100)
+    // Wait until the store tracks at least one execution more than before this test.
+    while (statusStore.executionsList().size <= oldCount) {
+      Thread.sleep(100)
+    }
+    // Wait for listener to finish computing the metrics for the execution.
+    while (statusStore.executionsList().last.metricValues == null) {
+      Thread.sleep(100)
+    }
+
+    val execId = statusStore.executionsList().last.executionId
+    val metrics = statusStore.executionMetrics(execId)
+    val driverMetric = physicalPlan.metrics("dummy")
+    val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, 
Seq(expectedAccumValue))
+
+    assert(metrics.contains(driverMetric.id))
+    assert(metrics(driverMetric.id) === expectedValue)
+  }
+
+  private def assertJobs(
+      exec: Option[SQLExecutionUIData],
+      running: Seq[Int] = Nil,
+      completed: Seq[Int] = Nil,
+      failed: Seq[Int] = Nil): Unit = {
+    // Bucket the execution's job ids by status, then compare each bucket to the expectation.
+    val actualRunning = new ListBuffer[Int]()
+    val actualCompleted = new ListBuffer[Int]()
+    val actualFailed = new ListBuffer[Int]()
+
+    exec.getOrElse(fail("Execution not found")).jobs.foreach { case (jobId, jobStatus) =>
+      jobStatus match {
+        case JobExecutionStatus.RUNNING => actualRunning += jobId
+        case JobExecutionStatus.SUCCEEDED => actualCompleted += jobId
+        case JobExecutionStatus.FAILED => actualFailed += jobId
+        case _ => fail(s"Unexpected status $jobStatus")
       }
     }
-    waitTillExecutionFinished()

-    val driverUpdates = listener.getCompletedExecutions.head.driverAccumUpdates
-    assert(driverUpdates.size == 1)
-    assert(driverUpdates(physicalPlan.longMetric("dummy").id) == 
expectedAccumValue)
+    assert(actualRunning.toSeq.sorted === running)
+    assert(actualCompleted.toSeq.sorted === completed)
+    assert(actualFailed.toSeq.sorted === failed)
   }
 
   test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol 
(SPARK-18462)") {
@@ -490,7 +491,8 @@ private case class MyPlan(sc: SparkContext, expectedValue: 
Long) extends LeafExe
 
 class SQLListenerMemoryLeakSuite extends SparkFunSuite {
 
-  test("no memory leak") {
+  // TODO: this feature is not yet available in SQLAppStatusStore.
+  ignore("no memory leak") {
     quietly {
       val conf = new SparkConf()
         .setMaster("local")
@@ -498,7 +500,6 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite {
         .set(config.MAX_TASK_FAILURES, 1) // Don't retry the tasks to run this 
test quickly
         .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run 
this test quickly
       withSpark(new SparkContext(conf)) { sc =>
-        SparkSession.sqlListener.set(null)
         val spark = new SparkSession(sc)
         import spark.implicits._
         // Run 100 successful executions and 100 failed executions.
@@ -516,12 +517,9 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite {
           }
         }
         sc.listenerBus.waitUntilEmpty(10000)
-        assert(spark.sharedState.listener.getCompletedExecutions.size <= 50)
-        assert(spark.sharedState.listener.getFailedExecutions.size <= 50)
-        // 50 for successful executions and 50 for failed executions
-        assert(spark.sharedState.listener.executionIdToData.size <= 100)
-        assert(spark.sharedState.listener.jobIdToExecutionId.size <= 100)
-        assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 100)
+
+        val statusStore = new SQLAppStatusStore(sc.statusStore.store)
+        assert(statusStore.executionsList().size <= 50)
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0ffa7c48/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index e0568a3..0b4629a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -73,7 +73,6 @@ trait SharedSparkSession
    * call 'beforeAll'.
    */
   protected def initializeSession(): Unit = {
-    SparkSession.sqlListener.set(null)
     if (_spark == null) {
       _spark = createSparkSession
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to