Repository: spark
Updated Branches:
  refs/heads/master a63d9edcf -> 765a48849


[SPARK-9026][SPARK-4514] Modifications to JobWaiter, FutureAction, and AsyncRDDActions to support non-blocking operation

These changes rework the implementations of `SimpleFutureAction`, 
`ComplexFutureAction`, `JobWaiter`, and `AsyncRDDActions` such that 
asynchronous callbacks on the generated `Futures` NEVER block waiting for a job 
to complete. A small amount of mutex synchronization is necessary to protect 
the internal fields that manage cancellation, but these locks are only held 
very briefly and in practice should almost never cause any blocking to occur. 
The existing blocking APIs of these classes are retained, but they simply 
delegate to the underlying non-blocking API and `Await` the results with 
indefinite timeouts.

Associated JIRA ticket: https://issues.apache.org/jira/browse/SPARK-9026
Also fixes: https://issues.apache.org/jira/browse/SPARK-4514

This pull request contains all my own original work, which I release to the 
Spark project under its open source license.

Author: Richard W. Eggert II <richard.egg...@gmail.com>

Closes #9264 from reggert/fix-futureaction.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/765a4884
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/765a4884
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/765a4884

Branch: refs/heads/master
Commit: 765a488494dac0ed38d2b81742c06467b79d96b2
Parents: a63d9ed
Author: Richard W. Eggert II <richard.egg...@gmail.com>
Authored: Tue Dec 15 18:22:58 2015 -0800
Committer: Andrew Or <and...@databricks.com>
Committed: Tue Dec 15 18:22:58 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/FutureAction.scala   | 164 +++++++------------
 .../org/apache/spark/rdd/AsyncRDDActions.scala  |  48 +++---
 .../apache/spark/scheduler/DAGScheduler.scala   |   8 +-
 .../org/apache/spark/scheduler/JobWaiter.scala  |  48 +++---
 .../test/scala/org/apache/spark/Smuggle.scala   |  82 ++++++++++
 .../org/apache/spark/StatusTrackerSuite.scala   |  26 +++
 .../apache/spark/rdd/AsyncRDDActionsSuite.scala |  33 +++-
 7 files changed, 251 insertions(+), 158 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/765a4884/core/src/main/scala/org/apache/spark/FutureAction.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala 
b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 48792a9..2a8220f 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -20,13 +20,15 @@ package org.apache.spark
 import java.util.Collections
 import java.util.concurrent.TimeUnit
 
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.Try
+
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.java.JavaFutureAction
 import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobWaiter
 
-import scala.concurrent._
-import scala.concurrent.duration.Duration
-import scala.util.{Failure, Try}
 
 /**
  * A future for the result of an action to support cancellation. This is an 
extension of the
@@ -105,6 +107,7 @@ trait FutureAction[T] extends Future[T] {
  * A [[FutureAction]] holding the result of an action that triggers a single 
job. Examples include
  * count, collect, reduce.
  */
+@DeveloperApi
 class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], 
resultFunc: => T)
   extends FutureAction[T] {
 
@@ -116,142 +119,96 @@ class SimpleFutureAction[T] private[spark](jobWaiter: 
JobWaiter[_], resultFunc:
   }
 
   override def ready(atMost: Duration)(implicit permit: CanAwait): 
SimpleFutureAction.this.type = {
-    if (!atMost.isFinite()) {
-      awaitResult()
-    } else jobWaiter.synchronized {
-      val finishTime = System.currentTimeMillis() + atMost.toMillis
-      while (!isCompleted) {
-        val time = System.currentTimeMillis()
-        if (time >= finishTime) {
-          throw new TimeoutException
-        } else {
-          jobWaiter.wait(finishTime - time)
-        }
-      }
-    }
+    jobWaiter.completionFuture.ready(atMost)
     this
   }
 
   @throws(classOf[Exception])
   override def result(atMost: Duration)(implicit permit: CanAwait): T = {
-    ready(atMost)(permit)
-    awaitResult() match {
-      case scala.util.Success(res) => res
-      case scala.util.Failure(e) => throw e
-    }
+    jobWaiter.completionFuture.ready(atMost)
+    assert(value.isDefined, "Future has not completed properly")
+    value.get.get
   }
 
   override def onComplete[U](func: (Try[T]) => U)(implicit executor: 
ExecutionContext) {
-    executor.execute(new Runnable {
-      override def run() {
-        func(awaitResult())
-      }
-    })
+    jobWaiter.completionFuture onComplete {_ => func(value.get)}
   }
 
   override def isCompleted: Boolean = jobWaiter.jobFinished
 
   override def isCancelled: Boolean = _cancelled
 
-  override def value: Option[Try[T]] = {
-    if (jobWaiter.jobFinished) {
-      Some(awaitResult())
-    } else {
-      None
-    }
-  }
-
-  private def awaitResult(): Try[T] = {
-    jobWaiter.awaitResult() match {
-      case JobSucceeded => scala.util.Success(resultFunc)
-      case JobFailed(e: Exception) => scala.util.Failure(e)
-    }
-  }
+  override def value: Option[Try[T]] =
+    jobWaiter.completionFuture.value.map {res => res.map(_ => resultFunc)}
 
   def jobIds: Seq[Int] = Seq(jobWaiter.jobId)
 }
 
 
 /**
+  * Handle via which a "run" function passed to a [[ComplexFutureAction]]
+  * can submit jobs for execution. Jobs submitted through the handle are recorded by
+  * the owning action, so cancelling the action also cancels every job submitted so far.
+  */
+@DeveloperApi
+trait JobSubmitter {
+  /**
+    * Submit a job for execution and return a FutureAction holding the result.
+    * This is a wrapper around the same functionality provided by SparkContext
+    * to enable cancellation.
+    *
+    * @param rdd the RDD whose partitions are computed by the job.
+    * @param processPartition function applied to the iterator of each computed partition.
+    * @param partitions indices of the partitions of `rdd` to compute.
+    * @param resultHandler callback receiving an index and that partition's output of
+    *                      `processPartition`. NOTE(review): the index appears to be the
+    *                      position within `partitions` (takeAsync sizes its buffer by
+    *                      `partitions.size`) -- confirm against SparkContext.submitJob.
+    * @param resultFunc by-name result of the returned FutureAction; presumably evaluated
+    *                   only after the job succeeds.
+    * @return a FutureAction tracking the submitted job.
+    * @throws SparkException in ComplexFutureAction's implementation, if the owning
+    *                        action has already been cancelled.
+    */
+  def submitJob[T, U, R](
+    rdd: RDD[T],
+    processPartition: Iterator[T] => U,
+    partitions: Seq[Int],
+    resultHandler: (Int, U) => Unit,
+    resultFunc: => R): FutureAction[R]
+}
+
+
+/**
  * A [[FutureAction]] for actions that could trigger multiple Spark jobs. 
Examples include take,
- * takeSample. Cancellation works by setting the cancelled flag to true and 
interrupting the
- * action thread if it is being blocked by a job.
+ * takeSample. Cancellation works by setting the cancelled flag to true and 
cancelling any pending
+ * jobs.
  */
-class ComplexFutureAction[T] extends FutureAction[T] {
+@DeveloperApi
+class ComplexFutureAction[T](run : JobSubmitter => Future[T])
+  extends FutureAction[T] { self =>
 
-  // Pointer to the thread that is executing the action. It is set when the 
action is run.
-  @volatile private var thread: Thread = _
+  @volatile private var _cancelled = false
 
-  // A flag indicating whether the future has been cancelled. This is used in 
case the future
-  // is cancelled before the action was even run (and thus we have no thread 
to interrupt).
-  @volatile private var _cancelled: Boolean = false
-
-  @volatile private var jobs: Seq[Int] = Nil
+  @volatile private var subActions: List[FutureAction[_]] = Nil
 
   // A promise used to signal the future.
-  private val p = promise[T]()
+  private val p = Promise[T]().tryCompleteWith(run(jobSubmitter))
 
-  override def cancel(): Unit = this.synchronized {
+  override def cancel(): Unit = synchronized {
     _cancelled = true
-    if (thread != null) {
-      thread.interrupt()
-    }
-  }
-
-  /**
-   * Executes some action enclosed in the closure. To properly enable 
cancellation, the closure
-   * should use runJob implementation in this promise. See takeAsync for 
example.
-   */
-  def run(func: => T)(implicit executor: ExecutionContext): this.type = {
-    scala.concurrent.future {
-      thread = Thread.currentThread
-      try {
-        p.success(func)
-      } catch {
-        case e: Exception => p.failure(e)
-      } finally {
-        // This lock guarantees when calling `thread.interrupt()` in `cancel`,
-        // thread won't be set to null.
-        ComplexFutureAction.this.synchronized {
-          thread = null
-        }
-      }
-    }
-    this
+    p.tryFailure(new SparkException("Action has been cancelled"))
+    subActions.foreach(_.cancel())
   }
 
-  /**
-   * Runs a Spark job. This is a wrapper around the same functionality 
provided by SparkContext
-   * to enable cancellation.
-   */
-  def runJob[T, U, R](
+  private def jobSubmitter = new JobSubmitter {
+    def submitJob[T, U, R](
       rdd: RDD[T],
       processPartition: Iterator[T] => U,
       partitions: Seq[Int],
       resultHandler: (Int, U) => Unit,
-      resultFunc: => R) {
-    // If the action hasn't been cancelled yet, submit the job. The check and 
the submitJob
-    // command need to be in an atomic block.
-    val job = this.synchronized {
+      resultFunc: => R): FutureAction[R] = self.synchronized {
+      // If the action hasn't been cancelled yet, submit the job. The check 
and the submitJob
+      // command need to be in an atomic block.
       if (!isCancelled) {
-        rdd.context.submitJob(rdd, processPartition, partitions, 
resultHandler, resultFunc)
+        val job = rdd.context.submitJob(
+          rdd,
+          processPartition,
+          partitions,
+          resultHandler,
+          resultFunc)
+        subActions = job :: subActions
+        job
       } else {
         throw new SparkException("Action has been cancelled")
       }
     }
-
-    this.jobs = jobs ++ job.jobIds
-
-    // Wait for the job to complete. If the action is cancelled (with an 
interrupt),
-    // cancel the job and stop the execution. This is not in a synchronized 
block because
-    // Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
-    try {
-      Await.ready(job, Duration.Inf)
-    } catch {
-      case e: InterruptedException =>
-        job.cancel()
-        throw new SparkException("Action has been cancelled")
-    }
   }
 
   override def isCancelled: Boolean = _cancelled
@@ -276,10 +233,11 @@ class ComplexFutureAction[T] extends FutureAction[T] {
 
   override def value: Option[Try[T]] = p.future.value
 
-  def jobIds: Seq[Int] = jobs
+  def jobIds: Seq[Int] = subActions.flatMap(_.jobIds)
 
 }
 
+
 private[spark]
 class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: 
S => T)
   extends JavaFutureAction[T] {
@@ -303,7 +261,7 @@ class JavaFutureActionWrapper[S, T](futureAction: 
FutureAction[S], converter: S
     Await.ready(futureAction, timeout)
     futureAction.value.get match {
       case scala.util.Success(value) => converter(value)
-      case Failure(exception) =>
+      case scala.util.Failure(exception) =>
         if (isCancelled) {
           throw new CancellationException("Job cancelled").initCause(exception)
         } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/765a4884/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala 
b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index d5e8536..14f541f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -19,13 +19,12 @@ package org.apache.spark.rdd
 
 import java.util.concurrent.atomic.AtomicLong
 
-import org.apache.spark.util.ThreadUtils
-
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
+import scala.concurrent.{Future, ExecutionContext}
 import scala.reflect.ClassTag
 
-import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
+import org.apache.spark.{JobSubmitter, ComplexFutureAction, FutureAction, 
Logging}
+import org.apache.spark.util.ThreadUtils
 
 /**
  * A set of asynchronous RDD actions available through an implicit conversion.
@@ -65,17 +64,23 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends 
Serializable with Loggi
    * Returns a future for retrieving the first num elements of the RDD.
    */
   def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope {
-    val f = new ComplexFutureAction[Seq[T]]
     val callSite = self.context.getCallSite
-
-    f.run {
-      // This is a blocking action so we should use 
"AsyncRDDActions.futureExecutionContext" which
-      // is a cached thread pool.
-      val results = new ArrayBuffer[T](num)
-      val totalParts = self.partitions.length
-      var partsScanned = 0
-      self.context.setCallSite(callSite)
-      while (results.size < num && partsScanned < totalParts) {
+    val localProperties = self.context.getLocalProperties
+    // Cached thread pool to handle aggregation of subtasks.
+    implicit val executionContext = AsyncRDDActions.futureExecutionContext
+    val results = new ArrayBuffer[T](num)
+    val totalParts = self.partitions.length
+
+    /*
+      Recursively triggers jobs to scan partitions until either the requested
+      number of elements are retrieved, or the partitions to scan are 
exhausted.
+      This implementation is non-blocking, asynchronously handling the
+      results of each job and triggering the next job using callbacks on 
futures.
+     */
+    def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : 
Future[Seq[T]] =
+      if (results.size >= num || partsScanned >= totalParts) {
+        Future.successful(results.toSeq)
+      } else {
         // The number of partitions to try in this iteration. It is ok for 
this number to be
         // greater than totalParts because we actually cap it at totalParts in 
runJob.
         var numPartsToTry = 1
@@ -97,19 +102,20 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends 
Serializable with Loggi
         val p = partsScanned until math.min(partsScanned + numPartsToTry, 
totalParts)
 
         val buf = new Array[Array[T]](p.size)
-        f.runJob(self,
+        self.context.setCallSite(callSite)
+        self.context.setLocalProperties(localProperties)
+        val job = jobSubmitter.submitJob(self,
           (it: Iterator[T]) => it.take(left).toArray,
           p,
           (index: Int, data: Array[T]) => buf(index) = data,
           Unit)
-
-        buf.foreach(results ++= _.take(num - results.size))
-        partsScanned += numPartsToTry
+        job.flatMap {_ =>
+          buf.foreach(results ++= _.take(num - results.size))
+          continue(partsScanned + numPartsToTry)
+        }
       }
-      results.toSeq
-    }(AsyncRDDActions.futureExecutionContext)
 
-    f
+    new ComplexFutureAction[Seq[T]](continue(0)(_))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/765a4884/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5582720..8d0e0c8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.Map
 import scala.collection.mutable.{HashMap, HashSet, Stack}
+import scala.concurrent.Await
 import scala.concurrent.duration._
 import scala.language.existentials
 import scala.language.postfixOps
@@ -610,11 +611,12 @@ class DAGScheduler(
       properties: Properties): Unit = {
     val start = System.nanoTime
     val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, 
properties)
-    waiter.awaitResult() match {
-      case JobSucceeded =>
+    Await.ready(waiter.completionFuture, atMost = Duration.Inf)
+    waiter.completionFuture.value.get match {
+      case scala.util.Success(_) =>
         logInfo("Job %d finished: %s, took %f s".format
           (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
-      case JobFailed(exception: Exception) =>
+      case scala.util.Failure(exception) =>
         logInfo("Job %d failed: %s, took %f s".format
           (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
         // SPARK-8644: Include user stack trace in exceptions coming from 
DAGScheduler.

http://git-wip-us.apache.org/repos/asf/spark/blob/765a4884/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala 
b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 382b094..4326135 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.scheduler
 
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.concurrent.{Future, Promise}
+
 /**
  * An object that waits for a DAGScheduler job to complete. As tasks finish, 
it passes their
  * results to the given handler function.
@@ -28,17 +32,15 @@ private[spark] class JobWaiter[T](
     resultHandler: (Int, T) => Unit)
   extends JobListener {
 
-  private var finishedTasks = 0
-
-  // Is the job as a whole finished (succeeded or failed)?
-  @volatile
-  private var _jobFinished = totalTasks == 0
-
-  def jobFinished: Boolean = _jobFinished
-
+  private val finishedTasks = new AtomicInteger(0)
   // If the job is finished, this will be its result. In the case of 0 task 
jobs (e.g. zero
   // partition RDDs), we set the jobResult directly to JobSucceeded.
-  private var jobResult: JobResult = if (jobFinished) JobSucceeded else null
+  private val jobPromise: Promise[Unit] =
+    if (totalTasks == 0) Promise.successful(()) else Promise()
+
+  def jobFinished: Boolean = jobPromise.isCompleted
+
+  def completionFuture: Future[Unit] = jobPromise.future
 
   /**
    * Sends a signal to the DAGScheduler to cancel the job. The cancellation 
itself is handled
@@ -49,29 +51,17 @@ private[spark] class JobWaiter[T](
     dagScheduler.cancelJob(jobId)
   }
 
-  override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
-    if (_jobFinished) {
-      throw new UnsupportedOperationException("taskSucceeded() called on a 
finished JobWaiter")
+  override def taskSucceeded(index: Int, result: Any): Unit = {
+    // resultHandler call must be synchronized in case resultHandler itself is 
not thread safe.
+    synchronized {
+      resultHandler(index, result.asInstanceOf[T])
     }
-    resultHandler(index, result.asInstanceOf[T])
-    finishedTasks += 1
-    if (finishedTasks == totalTasks) {
-      _jobFinished = true
-      jobResult = JobSucceeded
-      this.notifyAll()
+    if (finishedTasks.incrementAndGet() == totalTasks) {
+      jobPromise.success(())
     }
   }
 
-  override def jobFailed(exception: Exception): Unit = synchronized {
-    _jobFinished = true
-    jobResult = JobFailed(exception)
-    this.notifyAll()
-  }
+  /**
+   * Marks the job as failed with `exception`.
+   *
+   * Uses `tryFailure` rather than `failure`: the promise may already be completed
+   * (it is created already-successful for a zero-task job, it is completed once the
+   * last task succeeds, or a failure may be reported more than once), and
+   * `Promise.failure` would throw IllegalStateException in that case.
+   */
+  override def jobFailed(exception: Exception): Unit = {
+    // A `false` result only means the job's outcome was already decided; ignore it.
+    jobPromise.tryFailure(exception)
+  }
 
-  def awaitResult(): JobResult = synchronized {
-    while (!_jobFinished) {
-      this.wait()
-    }
-    return jobResult
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/765a4884/core/src/test/scala/org/apache/spark/Smuggle.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala 
b/core/src/test/scala/org/apache/spark/Smuggle.scala
new file mode 100644
index 0000000..01694a6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/Smuggle.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.UUID
+import java.util.concurrent.locks.ReentrantReadWriteLock
+
+import scala.collection.mutable
+
+/**
+  * Utility wrapper to "smuggle" objects into tasks while bypassing serialization.
+  * This is intended for testing purposes, primarily to make locks, semaphores, and
+  * other constructs that would not survive serialization available from within tasks.
+  * A Smuggle reference is itself serializable, but after being serialized and
+  * deserialized, it still refers to the same underlying "smuggled" object, as long
+  * as it was deserialized within the same JVM. This can be useful for tests that
+  * depend on the timing of task completion to be deterministic, since one can "smuggle"
+  * a lock or semaphore into the task, and then the task can block until the test gives
+  * the go-ahead to proceed via the lock. Only the Symbol `key` is serialized.
+  */
+class Smuggle[T] private(val key: Symbol) extends Serializable {
+  def smuggledObject: T = Smuggle.get(key)
+}
+
+
+object Smuggle {
+  /**
+    * Wraps the specified object to be smuggled into a serialized task without
+    * being serialized itself.
+    *
+    * @param smuggledObject the object to register under a fresh random key.
+    * @tparam T the type of the smuggled object.
+    * @return Smuggle wrapper around smuggledObject.
+    */
+  def apply[T](smuggledObject: T): Smuggle[T] = {
+    val key = Symbol(UUID.randomUUID().toString)
+    lock.writeLock().lock()
+    try {
+      smuggledObjects += key -> smuggledObject
+    } finally {
+      lock.writeLock().unlock()
+    }
+    new Smuggle(key)
+  }
+
+  // Guards smuggledObjects: apply() takes the write lock, get() takes the read lock.
+  private val lock = new ReentrantReadWriteLock
+  private val smuggledObjects = mutable.WeakHashMap.empty[Symbol, Any]
+
+  private def get[T](key: Symbol) : T = {
+    lock.readLock().lock()
+    try {
+      smuggledObjects(key).asInstanceOf[T] // NoSuchElementException if key is unknown
+    } finally {
+      lock.readLock().unlock()
+    }
+  }
+
+  /**
+    * Implicit conversion of a Smuggle wrapper to the object being smuggled.
+    *
+    * @param smuggle the wrapper to unpack.
+    * @tparam T the type of the smuggled object.
+    * @return the smuggled object represented by the wrapper.
+    */
+  implicit def unpackSmuggledObject[T](smuggle : Smuggle[T]): T = smuggle.smuggledObject
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/765a4884/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala 
b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
index 46516e8..5483f2b 100644
--- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
@@ -86,4 +86,30 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers 
with LocalSparkCont
         Set(firstJobId, secondJobId))
     }
   }
+
+  test("getJobIdsForGroup() with takeAsync()") {
+    sc = new SparkContext("local", "test", new SparkConf(false))
+    sc.setJobGroup("my-job-group2", "description")
+    sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty
+    // The first scanned partition already holds the single requested element, so
+    // takeAsync submits exactly one job, which must carry the caller's job group.
+    val firstJobFuture = sc.parallelize(1 to 1000, 1).takeAsync(1)
+    // Wait until takeAsync has actually submitted its job.
+    val firstJobId = eventually(timeout(10 seconds)) {
+      firstJobFuture.jobIds.head
+    }
+    eventually(timeout(10 seconds)) {
+      sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq(firstJobId))
+    }
+  }
+
+  test("getJobIdsForGroup() with takeAsync() across multiple partitions") {
+    sc = new SparkContext("local", "test", new SparkConf(false))
+    sc.setJobGroup("my-job-group2", "description")
+    sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty
+    // 999 elements cannot come from the first scanned partition alone, so takeAsync
+    // submits a follow-up job from a future callback; that second job must also carry
+    // the caller's job group (takeAsync re-applies the captured local properties).
+    val firstJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999)
+    // Wait until takeAsync has submitted at least its first job.
+    val firstJobId = eventually(timeout(10 seconds)) {
+      firstJobFuture.jobIds.head
+    }
+    eventually(timeout(10 seconds)) {
+      sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/765a4884/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala 
b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index ec99f2a..de015eb 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.rdd
 
 import java.util.concurrent.Semaphore
 
-import scala.concurrent.{Await, TimeoutException}
+import scala.concurrent._
 import scala.concurrent.duration.Duration
 import scala.concurrent.ExecutionContext.Implicits.global
 
@@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.concurrent.Timeouts
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, 
SparkFunSuite}
+import org.apache.spark._
 
 class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with 
Timeouts {
 
@@ -197,4 +197,33 @@ class AsyncRDDActionsSuite extends SparkFunSuite with 
BeforeAndAfterAll with Tim
       Await.result(f, Duration(20, "milliseconds"))
     }
   }
+
+  /**
+   * Verifies that registering an `onComplete` callback on the FutureAction returned by
+   * `action` hands nothing to the callback's ExecutionContext while the underlying
+   * job(s) are still running, and that the callback is dispatched once they finish.
+   * Every task blocks on a smuggled semaphore until the test releases it.
+   */
+  private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = {
+    val executionContextInvoked = Promise[Unit]
+    // Only records that it was asked to run a callback; the Runnable is never run.
+    val fakeExecutionContext = new ExecutionContext {
+      override def execute(runnable: Runnable): Unit = {
+        executionContextInvoked.success(())
+      }
+      override def reportFailure(t: Throwable): Unit = ()
+    }
+    // The semaphore must reach the tasks without being serialized, hence Smuggle.
+    val starter = Smuggle(new Semaphore(0))
+    starter.drainPermits()
+    // Each task takes one permit before producing its partition.
+    val rdd = sc.parallelize(1 to 100, 4).mapPartitions {itr => starter.acquire(1); itr}
+    val f = action(rdd)
+    f.onComplete(_ => ())(fakeExecutionContext)
+    // Here we verify that registering the callback didn't cause a thread to be consumed.
+    assert(!executionContextInvoked.isCompleted)
+    // Now allow the executors to proceed with task processing (one permit per partition).
+    starter.release(rdd.partitions.length)
+    // Waiting for the result verifies that the tasks were successfully processed.
+    Await.result(executionContextInvoked.future, atMost = 15.seconds)
+  }
+
+  test("SimpleFutureAction callback must not consume a thread while waiting") {
+    testAsyncAction(_.countAsync())
+  }
+
+  test("ComplexFutureAction callback must not consume a thread while waiting") {
+    testAsyncAction((_.takeAsync(100)))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to