Job cancellation: address Matei's code review feedback.

Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/6b288b75
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/6b288b75
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/6b288b75

Branch: refs/heads/master
Commit: 6b288b75d4c05f42ad3612813dc77ff824bb6203
Parents: ab0940f
Author: Reynold Xin <r...@apache.org>
Authored: Sat Oct 12 15:53:31 2013 -0700
Committer: Reynold Xin <r...@apache.org>
Committed: Sat Oct 12 15:53:31 2013 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/FutureAction.scala   | 78 +++++++---------
 .../scala/org/apache/spark/SparkContext.scala   | 17 ++--
 .../scala/org/apache/spark/TaskContext.scala    |  9 +-
 .../spark/executor/MesosExecutorBackend.scala   | 18 +++-
 .../org/apache/spark/rdd/AsyncRDDActions.scala  | 10 +-
 .../org/apache/spark/rdd/CheckpointRDD.scala    |  2 +-
 .../org/apache/spark/rdd/InterruptibleRDD.scala | 36 -------
 .../spark/rdd/MapPartitionsWithContextRDD.scala | 41 ++++++++
 .../spark/rdd/MapPartitionsWithIndexRDD.scala   | 41 --------
 .../org/apache/spark/rdd/PairRDDFunctions.scala | 19 ++--
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 98 +++++++++++---------
 .../apache/spark/scheduler/DAGScheduler.scala   |  4 +-
 .../org/apache/spark/scheduler/ResultTask.scala | 27 ++++--
 .../apache/spark/scheduler/ShuffleMapTask.scala | 28 ++++--
 .../scala/org/apache/spark/scheduler/Task.scala | 16 +++-
 .../org/apache/spark/CheckpointSuite.scala      |  4 +-
 .../org/apache/spark/JobCancellationSuite.scala | 39 +++++++-
 .../apache/spark/rdd/AsyncRDDActionsSuite.scala | 47 ++++------
 18 files changed, 283 insertions(+), 251 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/FutureAction.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala 
b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 85018cb..1ad9240 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -31,6 +31,8 @@ import org.apache.spark.rdd.RDD
  * support cancellation.
  */
 trait FutureAction[T] extends Future[T] {
+  // Note that we redefine methods of the Future trait here explicitly so we can provide
+  // different documentation (with reference to the word "action").
 
   /**
    * Cancels the execution of this action.
@@ -87,14 +89,14 @@ trait FutureAction[T] extends Future[T] {
  * The future holding the result of an action that triggers a single job. 
Examples include
  * count, collect, reduce.
  */
-class FutureJob[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
+class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], 
resultFunc: => T)
   extends FutureAction[T] {
 
   override def cancel() {
     jobWaiter.cancel()
   }
 
-  override def ready(atMost: Duration)(implicit permit: CanAwait): 
FutureJob.this.type = {
+  override def ready(atMost: Duration)(implicit permit: CanAwait): 
SimpleFutureAction.this.type = {
     if (!atMost.isFinite()) {
       awaitResult()
     } else {
@@ -149,19 +151,20 @@ class FutureJob[T] private[spark](jobWaiter: 
JobWaiter[_], resultFunc: => T)
 
 /**
  * A FutureAction for actions that could trigger multiple Spark jobs. Examples 
include take,
- * takeSample.
- *
- * This is implemented as a Scala Promise that can be cancelled. Note that the 
promise itself is
- * also its own Future (i.e. this.future returns this). See the implementation 
of takeAsync for
- * usage.
+ * takeSample. Cancellation works by setting the cancelled flag to true and 
interrupting the
+ * action thread if it is being blocked by a job.
  */
-class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
-  // Cancellation works by setting the cancelled flag to true and interrupt 
the action thread
-  // if it is in progress. Before executing the action, the execution thread 
needs to check the
-  // cancelled flag in case cancel() is called before the thread even starts 
to execute. Because
-  // this and the execution thread is synchronized on the same promise object 
(this), the actual
-  // cancellation/interrupt event can only be triggered when the execution 
thread is waiting for
-  // the result of a job.
+class ComplexFutureAction[T] extends FutureAction[T] {
+
+  // Pointer to the thread that is executing the action. It is set when the 
action is run.
+  @volatile private var thread: Thread = _
+
+  // A flag indicating whether the future has been cancelled. This is used in 
case the future
+  // is cancelled before the action was even run (and thus we have no thread 
to interrupt).
+  @volatile private var _cancelled: Boolean = false
+
+  // A promise used to signal the future.
+  private val p = promise[T]()
 
   override def cancel(): Unit = this.synchronized {
     _cancelled = true
@@ -174,15 +177,18 @@ class CancellablePromise[T] extends FutureAction[T] with 
Promise[T] {
    * Executes some action enclosed in the closure. To properly enable 
cancellation, the closure
    * should use runJob implementation in this promise. See takeAsync for 
example.
    */
-  def run(func: => T)(implicit executor: ExecutionContext): Unit = 
scala.concurrent.future {
-    thread = Thread.currentThread
-    try {
-      this.success(func)
-    } catch {
-      case e: Exception => this.failure(e)
-    } finally {
-      thread = null
+  def run(func: => T)(implicit executor: ExecutionContext): this.type = {
+    scala.concurrent.future {
+      thread = Thread.currentThread
+      try {
+        p.success(func)
+      } catch {
+        case e: Exception => p.failure(e)
+      } finally {
+        thread = null
+      }
     }
+    this
   }
 
   /**
@@ -193,15 +199,15 @@ class CancellablePromise[T] extends FutureAction[T] with 
Promise[T] {
       rdd: RDD[T],
       processPartition: Iterator[T] => U,
       partitions: Seq[Int],
-      partitionResultHandler: (Int, U) => Unit,
+      resultHandler: (Int, U) => Unit,
       resultFunc: => R) {
     // If the action hasn't been cancelled yet, submit the job. The check and 
the submitJob
     // command need to be in an atomic block.
     val job = this.synchronized {
       if (!cancelled) {
-        rdd.context.submitJob(rdd, processPartition, partitions, 
partitionResultHandler, resultFunc)
+        rdd.context.submitJob(rdd, processPartition, partitions, 
resultHandler, resultFunc)
       } else {
-        throw new SparkException("action has been cancelled")
+        throw new SparkException("Action has been cancelled")
       }
     }
 
@@ -213,7 +219,7 @@ class CancellablePromise[T] extends FutureAction[T] with 
Promise[T] {
     } catch {
       case e: InterruptedException =>
         job.cancel()
-        throw new SparkException("action has been cancelled")
+        throw new SparkException("Action has been cancelled")
     }
   }
 
@@ -222,28 +228,14 @@ class CancellablePromise[T] extends FutureAction[T] with 
Promise[T] {
    */
   def cancelled: Boolean = _cancelled
 
-  // Pointer to the thread that is executing the action. It is set when the 
action is run.
-  @volatile private var thread: Thread = _
-
-  // A flag indicating whether the future has been cancelled. This is used in 
case the future
-  // is cancelled before the action was even run (and thus we have no thread 
to interrupt).
-  @volatile private var _cancelled: Boolean = false
-
-  // Internally, we delegate most functionality to this promise.
-  private val p = promise[T]()
-
-  override def future: this.type = this
-
-  override def tryComplete(result: Try[T]): Boolean = p.tryComplete(result)
-
-  @scala.throws(classOf[InterruptedException])
-  @scala.throws(classOf[scala.concurrent.TimeoutException])
+  @throws(classOf[InterruptedException])
+  @throws(classOf[scala.concurrent.TimeoutException])
   override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = 
{
     p.future.ready(atMost)(permit)
     this
   }
 
-  @scala.throws(classOf[Exception])
+  @throws(classOf[Exception])
   override def result(atMost: Duration)(implicit permit: CanAwait): T = {
     p.future.result(atMost)(permit)
   }
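
A minimal sketch of the pattern ComplexFutureAction is designed for, mirroring takeAsync: the driver-side logic runs inside run(), and every blocking job submission goes through runJob() so that cancel() can interrupt the waiting thread. The helper name and its placement in org.apache.spark.rdd are illustrative assumptions, not part of this patch.

    // Hypothetical helper, shown only to illustrate the run()/runJob() pattern.
    package org.apache.spark.rdd

    import scala.concurrent.ExecutionContext.Implicits.global

    import org.apache.spark.{ComplexFutureAction, FutureAction}

    object FirstPartitionExample {
      /** Asynchronously collects the first partition of an RDD. */
      def firstPartitionAsync[T: ClassManifest](rdd: RDD[T]): FutureAction[Seq[T]] = {
        val f = new ComplexFutureAction[Seq[T]]
        f.run {
          val buf = new Array[Array[T]](1)
          // Blocks until the single-partition job finishes; throws
          // SparkException("Action has been cancelled") if cancel() was called first.
          f.runJob(rdd,
            (it: Iterator[T]) => it.toArray,
            Seq(0),
            (index: Int, data: Array[T]) => buf(index) = data,
            buf(0))
          buf(0).toSeq
        }
      }
    }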

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala 
b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 52fc4dd..96a2f1f 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -760,10 +760,11 @@ class SparkContext(
       allowLocal: Boolean,
       resultHandler: (Int, U) => Unit) {
     val callSite = Utils.formatSparkCallSite
+    val cleanedFunc = clean(func)
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
-    val result = dagScheduler.runJob(rdd, func, partitions, callSite, 
allowLocal, resultHandler,
-      localProperties.get)
+    val result = dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, 
allowLocal,
+      resultHandler, localProperties.get)
     logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - 
start) / 1e9 + " s")
     rdd.doCheckpoint()
     result
@@ -853,16 +854,14 @@ class SparkContext(
   }
 
   /**
-   * Submit a job for execution and return a FutureJob holding the result. 
Note that the
-   * processPartition closure will be "cleaned" so the caller doesn't have to 
clean the closure
-   * explicitly.
+   * Submit a job for execution and return a FutureJob holding the result.
    */
   def submitJob[T, U, R](
       rdd: RDD[T],
       processPartition: Iterator[T] => U,
       partitions: Seq[Int],
-      partitionResultHandler: (Int, U) => Unit,
-      resultFunc: => R): FutureJob[R] =
+      resultHandler: (Int, U) => Unit,
+      resultFunc: => R): SimpleFutureAction[R] =
   {
     val cleanF = clean(processPartition)
     val callSite = Utils.formatSparkCallSite
@@ -872,9 +871,9 @@ class SparkContext(
       partitions,
       callSite,
       allowLocal = false,
-      partitionResultHandler,
+      resultHandler,
       null)
-    new FutureJob(waiter, resultFunc)
+    new SimpleFutureAction(waiter, resultFunc)
   }
 
   /**
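
A rough usage sketch of the renamed submitJob API, assuming an already-running SparkContext named sc: it returns a SimpleFutureAction whose value is assembled by resultFunc once the job completes, and which can be cancelled before then.

    // Illustrative only; assumes an already-running SparkContext `sc`.
    import org.apache.spark.SimpleFutureAction

    val rdd = sc.parallelize(1 to 1000, 4)
    val counts = new Array[Long](rdd.partitions.size)

    val action: SimpleFutureAction[Long] = sc.submitJob(
      rdd,
      (it: Iterator[Int]) => it.size.toLong,       // processPartition, runs on the executors
      0 until rdd.partitions.size,                 // partitions to run on
      (index: Int, c: Long) => counts(index) = c,  // resultHandler, runs on the driver
      counts.sum)                                  // resultFunc, evaluated once the job finishes

    // action.cancel() would cancel the job through its JobWaiter; here we just block.
    println(action.get())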

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/TaskContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala 
b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 86370d5..51584d6 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -22,14 +22,17 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.executor.TaskMetrics
 
 class TaskContext(
-  val stageId: Int,
-  val splitId: Int,
+  private[spark] val stageId: Int,
+  val partitionId: Int,
   val attemptId: Long,
   val runningLocally: Boolean = false,
   @volatile var interrupted: Boolean = false,
-  val taskMetrics: TaskMetrics = TaskMetrics.empty()
+  private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty()
 ) extends Serializable {
 
+  @deprecated("use partitionId", "0.8.1")
+  def splitId = partitionId
+
   // List of callback functions to execute when the task completes.
   @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
 

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala 
b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index da62091..b56d8c9 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -18,14 +18,18 @@
 package org.apache.spark.executor
 
 import java.nio.ByteBuffer
-import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, 
MesosNativeLibrary, ExecutorDriver}
-import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => 
MesosTaskStatus, _}
-import org.apache.spark.TaskState.TaskState
+
 import com.google.protobuf.ByteString
-import org.apache.spark.{Logging}
+
+import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, 
MesosNativeLibrary, ExecutorDriver}
+import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
+
+import org.apache.spark.Logging
 import org.apache.spark.TaskState
+import org.apache.spark.TaskState.TaskState
 import org.apache.spark.util.Utils
 
+
 private[spark] class MesosExecutorBackend
   extends MesosExecutor
   with ExecutorBackend
@@ -71,7 +75,11 @@ private[spark] class MesosExecutorBackend
   }
 
   override def killTask(d: ExecutorDriver, t: TaskID) {
-    logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not 
yet implemented)")
+    if (executor == null) {
+      logError("Received KillTask but executor was null")
+    } else {
+      executor.killTask(t.getValue.toLong)
+    }
   }
 
   override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala 
b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 1f24ee8..faaf837 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.atomic.AtomicLong
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.ExecutionContext.Implicits.global
 
-import org.apache.spark.{Logging, CancellablePromise, FutureAction}
+import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
 
 /**
  * A set of asynchronous RDD actions available through an implicit conversion.
@@ -63,9 +63,9 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends 
Serializable with
    * Returns a future for retrieving the first num elements of the RDD.
    */
   def takeAsync(num: Int): FutureAction[Seq[T]] = {
-    val promise = new CancellablePromise[Seq[T]]
+    val f = new ComplexFutureAction[Seq[T]]
 
-    promise.run {
+    f.run {
       val results = new ArrayBuffer[T](num)
       val totalParts = self.partitions.length
       var partsScanned = 0
@@ -89,7 +89,7 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends 
Serializable with
         val p = partsScanned until math.min(partsScanned + numPartsToTry, 
totalParts)
 
         val buf = new Array[Array[T]](p.size)
-        promise.runJob(self,
+        f.runJob(self,
           (it: Iterator[T]) => it.take(left).toArray,
           p,
           (index: Int, data: Array[T]) => buf(index) = data,
@@ -101,7 +101,7 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) 
extends Serializable with
       results.toSeq
     }
 
-    promise.future
+    f
   }
 
   /**
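
A hedged sketch of how these async actions are consumed from the driver side, assuming a SparkContext named sc: the returned FutureAction behaves like a Scala Future and additionally supports cancel().

    // Illustrative only; assumes an already-running SparkContext `sc`.
    import scala.concurrent.ExecutionContext.Implicits.global

    import org.apache.spark.SparkContext._  // implicit conversion providing the *Async methods

    val rdd = sc.parallelize(1 to 1000000, 8).map(_ * 2)

    val firstTen = rdd.takeAsync(10)   // may trigger one or more jobs internally
    val total = rdd.countAsync()       // triggers a single job

    firstTen.onComplete {
      case scala.util.Success(xs) => println("take(10) = " + xs)
      case scala.util.Failure(e)  => println("takeAsync failed (possibly cancelled): " + e)
    }

    // Cancelling interrupts the thread driving the multi-job action and cancels
    // whichever job it is currently blocked on.
    firstTen.cancel()

    println("count = " + total.get())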

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 3311757..ccaaecb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -85,7 +85,7 @@ private[spark] object CheckpointRDD extends Logging {
     val outputDir = new Path(path)
     val fs = outputDir.getFileSystem(env.hadoop.newConfiguration())
 
-    val finalOutputName = splitIdToFile(ctx.splitId)
+    val finalOutputName = splitIdToFile(ctx.partitionId)
     val finalOutputPath = new Path(outputDir, finalOutputName)
     val tempOutputPath = new Path(outputDir, "." + finalOutputName + 
"-attempt-" + ctx.attemptId)
 

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala
deleted file mode 100644
index e731deb..0000000
--- a/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import org.apache.spark.{InterruptibleIterator, Partition, TaskContext}
-
-
-/**
- * Wraps around an existing RDD to make it interruptible (can be killed).
- */
-private[spark]
-class InterruptibleRDD[T: ClassManifest](prev: RDD[T]) extends RDD[T](prev) {
-
-  override def getPartitions: Array[Partition] = firstParent[T].partitions
-
-  override val partitioner = prev.partitioner
-
-  override def compute(split: Partition, context: TaskContext) = {
-    new InterruptibleIterator(context, firstParent[T].iterator(split, context))
-  }
-}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
new file mode 100644
index 0000000..aea08ff
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{Partition, TaskContext}
+
+
+/**
+ * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
+ * TaskContext, the closure can either get access to the interrupted flag or get the index
+ * of the partition in the RDD.
+ */
+private[spark]
+class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest](
+    prev: RDD[T],
+    f: (TaskContext, Iterator[T]) => Iterator[U],
+    preservesPartitioning: Boolean
+  ) extends RDD[U](prev) {
+
+  override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+  override val partitioner = if (preservesPartitioning) prev.partitioner else 
None
+
+  override def compute(split: Partition, context: TaskContext) =
+    f(context, firstParent[T].iterator(split, context))
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala
deleted file mode 100644
index 3ed8339..0000000
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import org.apache.spark.{Partition, TaskContext}
-
-
-/**
- * A variant of the MapPartitionsRDD that passes the partition index into the
- * closure. This can be used to generate or collect partition specific
- * information such as the number of tuples in a partition.
- */
-private[spark]
-class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest](
-    prev: RDD[T],
-    f: (Int, Iterator[T]) => Iterator[U],
-    preservesPartitioning: Boolean
-  ) extends RDD[U](prev) {
-
-  override def getPartitions: Array[Partition] = firstParent[T].partitions
-
-  override val partitioner = if (preservesPartitioning) prev.partitioner else 
None
-
-  override def compute(split: Partition, context: TaskContext) =
-    f(split.index, firstParent[T].iterator(split, context))
-}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala 
b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index ee17794..93b78e1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -84,21 +84,24 @@ class PairRDDFunctions[K: ClassManifest, V: 
ClassManifest](self: RDD[(K, V)])
     }
     val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, 
mergeCombiners)
     if (self.partitioner == Some(partitioner)) {
-      self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning 
= true)
-        .interruptible()
+      self.mapPartitionsWithContext((context, iter) => {
+        new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+      }, preservesPartitioning = true)
     } else if (mapSideCombine) {
       val combined = self.mapPartitions(aggregator.combineValuesByKey, 
preservesPartitioning = true)
       val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
         .setSerializer(serializerClass)
-      partitioned.mapPartitions(aggregator.combineCombinersByKey, 
preservesPartitioning = true)
-        .interruptible()
+      partitioned.mapPartitionsWithContext((context, iter) => {
+        new InterruptibleIterator(context, 
aggregator.combineCombinersByKey(iter))
+      }, preservesPartitioning = true)
     } else {
       // Don't apply map-side combiner.
       // A sanity check to make sure mergeCombiners is not defined.
       assert(mergeCombiners == null)
       val values = new ShuffledRDD[K, V, (K, V)](self, 
partitioner).setSerializer(serializerClass)
-      values.mapPartitions(aggregator.combineValuesByKey, 
preservesPartitioning = true)
-        .interruptible()
+      values.mapPartitionsWithContext((context, iter) => {
+        new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+      }, preservesPartitioning = true)
     }
   }
 
@@ -567,7 +570,7 @@ class PairRDDFunctions[K: ClassManifest, V: 
ClassManifest](self: RDD[(K, V)])
       // around by taking a mod. We expect that no task will be attempted 2 
billion times.
       val attemptNumber = (context.attemptId % Int.MaxValue).toInt
       /* "reduce task" <split #> <attempt # = spark task #> */
-      val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, 
context.splitId, attemptNumber)
+      val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, 
context.partitionId, attemptNumber)
       val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
       val format = outputFormatClass.newInstance
       val committer = format.getOutputCommitter(hadoopContext)
@@ -666,7 +669,7 @@ class PairRDDFunctions[K: ClassManifest, V: 
ClassManifest](self: RDD[(K, V)])
       // around by taking a mod. We expect that no task will be attempted 2 
billion times.
       val attemptNumber = (context.attemptId % Int.MaxValue).toInt
 
-      writer.setup(context.stageId, context.splitId, attemptNumber)
+      writer.setup(context.stageId, context.partitionId, attemptNumber)
       writer.open()
 
       var count = 0
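
The combineByKey change above drops the separate InterruptibleRDD wrapper in favor of wrapping each partition's iterator inside mapPartitionsWithContext. A minimal sketch of that pattern in isolation, placed in org.apache.spark.rdd so InterruptibleIterator is visible whatever its access modifier; the object and method names are illustrative assumptions:

    package org.apache.spark.rdd

    import org.apache.spark.InterruptibleIterator

    object InterruptiblePatternExample {
      /** Doubles the values of a pair RDD while staying responsive to task kills. */
      def doubled(rdd: RDD[(Int, Int)]): RDD[(Int, Int)] = {
        rdd.mapPartitionsWithContext((context, iter) => {
          // InterruptibleIterator checks context.interrupted as the partition is
          // consumed, so a killed task stops early instead of running to completion.
          new InterruptibleIterator(context, iter.map { case (k, v) => (k, v * 2) })
        }, preservesPartitioning = true)
      }
    }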

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 4be506b..0355618 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -418,26 +418,39 @@ abstract class RDD[T: ClassManifest](
       command: Seq[String],
       env: Map[String, String] = Map(),
       printPipeContext: (String => Unit) => Unit = null,
-      printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
+      printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = {
     new PipedRDD(this, command, env,
       if (printPipeContext ne null) sc.clean(printPipeContext) else null,
       if (printRDDElement ne null) sc.clean(printRDDElement) else null)
+  }
 
   /**
    * Return a new RDD by applying a function to each partition of this RDD.
    */
-  def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U],
-    preservesPartitioning: Boolean = false): RDD[U] =
+  def mapPartitions[U: ClassManifest](
+      f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): 
RDD[U] = {
     new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+  }
 
   /**
    * Return a new RDD by applying a function to each partition of this RDD, 
while tracking the index
    * of the original partition.
    */
   def mapPartitionsWithIndex[U: ClassManifest](
-    f: (Int, Iterator[T]) => Iterator[U],
-    preservesPartitioning: Boolean = false): RDD[U] =
-    new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+      f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = 
false): RDD[U] = {
+    val func = (context: TaskContext, iter: Iterator[T]) => 
f(context.partitionId, iter)
+    new MapPartitionsWithContextRDD(this, sc.clean(func), 
preservesPartitioning)
+  }
+
+  /**
+   * Return a new RDD by applying a function to each partition of this RDD. 
This is a variant of
+   * mapPartitions that also passes the TaskContext into the closure.
+   */
+  def mapPartitionsWithContext[U: ClassManifest](
+      f: (TaskContext, Iterator[T]) => Iterator[U],
+      preservesPartitioning: Boolean = false): RDD[U] = {
+    new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+  }
 
   /**
    * Return a new RDD by applying a function to each partition of this RDD, 
while tracking the index
@@ -445,22 +458,23 @@ abstract class RDD[T: ClassManifest](
    */
   @deprecated("use mapPartitionsWithIndex", "0.7.0")
   def mapPartitionsWithSplit[U: ClassManifest](
-    f: (Int, Iterator[T]) => Iterator[U],
-    preservesPartitioning: Boolean = false): RDD[U] =
-    new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+      f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = 
false): RDD[U] = {
+    mapPartitionsWithIndex(f, preservesPartitioning)
+  }
 
   /**
    * Maps f over this RDD, where f takes an additional parameter of type A.  
This
    * additional parameter is produced by constructA, which is called in each
    * partition with the index of that partition.
    */
-  def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, 
preservesPartitioning: Boolean = false)
-    (f:(T, A) => U): RDD[U] = {
-      def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
-        val a = constructA(index)
-        iter.map(t => f(t, a))
-      }
-    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), 
preservesPartitioning)
+  def mapWith[A: ClassManifest, U: ClassManifest]
+      (constructA: Int => A, preservesPartitioning: Boolean = false)
+      (f: (T, A) => U): RDD[U] = {
+    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
+      val a = constructA(context.partitionId)
+      iter.map(t => f(t, a))
+    }
+    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), 
preservesPartitioning)
   }
 
   /**
@@ -468,13 +482,14 @@ abstract class RDD[T: ClassManifest](
    * additional parameter is produced by constructA, which is called in each
    * partition with the index of that partition.
    */
-  def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, 
preservesPartitioning: Boolean = false)
-    (f:(T, A) => Seq[U]): RDD[U] = {
-      def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
-        val a = constructA(index)
-        iter.flatMap(t => f(t, a))
-      }
-    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), 
preservesPartitioning)
+  def flatMapWith[A: ClassManifest, U: ClassManifest]
+      (constructA: Int => A, preservesPartitioning: Boolean = false)
+      (f: (T, A) => Seq[U]): RDD[U] = {
+    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
+      val a = constructA(context.partitionId)
+      iter.flatMap(t => f(t, a))
+    }
+    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), 
preservesPartitioning)
   }
 
   /**
@@ -482,13 +497,12 @@ abstract class RDD[T: ClassManifest](
    * This additional parameter is produced by constructA, which is called in 
each
    * partition with the index of that partition.
    */
-  def foreachWith[A: ClassManifest](constructA: Int => A)
-    (f:(T, A) => Unit) {
-      def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
-        val a = constructA(index)
-        iter.map(t => {f(t, a); t})
-      }
-    (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ 
=> {})
+  def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) {
+    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
+      val a = constructA(context.partitionId)
+      iter.map(t => {f(t, a); t})
+    }
+    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ 
=> {})
   }
 
   /**
@@ -496,13 +510,12 @@ abstract class RDD[T: ClassManifest](
    * additional parameter is produced by constructA, which is called in each
    * partition with the index of that partition.
    */
-  def filterWith[A: ClassManifest](constructA: Int => A)
-    (p:(T, A) => Boolean): RDD[T] = {
-      def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
-        val a = constructA(index)
-        iter.filter(t => p(t, a))
-      }
-    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
+  def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => 
Boolean): RDD[T] = {
+    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
+      val a = constructA(context.partitionId)
+      iter.filter(t => p(t, a))
+    }
+    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
   }
 
   /**
@@ -541,16 +554,14 @@ abstract class RDD[T: ClassManifest](
    * Applies a function f to all elements of this RDD.
    */
   def foreach(f: T => Unit) {
-    val cleanF = sc.clean(f)
-    sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
+    sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
   }
 
   /**
    * Applies a function f to each partition of this RDD.
    */
   def foreachPartition(f: Iterator[T] => Unit) {
-    val cleanF = sc.clean(f)
-    sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
+    sc.runJob(this, (iter: Iterator[T]) => f(iter))
   }
 
   /**
@@ -862,11 +873,6 @@ abstract class RDD[T: ClassManifest](
     map(x => (f(x), x))
   }
 
-  /**
-   * Creates an interruptible version of this RDD.
-   */
-  def interruptible(): RDD[T] = new InterruptibleRDD(this)
-
   /** A private method for tests, to look at the contents of each partition */
   private[spark] def collectPartitions(): Array[Array[T]] = {
     sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
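
A small caller-side sketch of the two partition-mapping variants touched above, assuming a SparkContext named sc: mapPartitionsWithIndex still receives just the partition index, while mapPartitionsWithContext exposes the full TaskContext (partitionId, attemptId, the interrupted flag).

    // Illustrative only; assumes an already-running SparkContext `sc`.
    val rdd = sc.parallelize(1 to 100, 4)

    // Index-only variant (now implemented on top of MapPartitionsWithContextRDD).
    val tagged = rdd.mapPartitionsWithIndex((index, iter) => iter.map(x => (index, x)))

    // Context variant: the closure can also read, e.g., context.interrupted.
    val doubled = rdd.mapPartitionsWithContext((context, iter) =>
      iter.map(x => (context.partitionId, x * 2)))

    println(tagged.take(3).mkString(", "))
    println(doubled.count())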

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c5b28b8..2a8fbe8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -377,7 +377,7 @@ class DAGScheduler(
 
       case JobCancelled(jobId) =>
         // Cancel a job: find all the running stages that are linked to this 
job, and cancel them.
-        running.find(_.jobId == jobId).foreach { stage =>
+        running.filter(_.jobId == jobId).foreach { stage =>
           taskSched.cancelTasks(stage.id)
         }
 
@@ -658,7 +658,7 @@ class DAGScheduler(
             if (failedEpoch.contains(execId) && smt.epoch <= 
failedEpoch(execId)) {
               logInfo("Ignoring possibly bogus ShuffleMapTask completion from 
" + execId)
             } else {
-              stage.addOutputLoc(smt.partition, status)
+              stage.addOutputLoc(smt.partitionId, status)
             }
             if (running.contains(stage) && pendingTasks(stage).isEmpty) {
               markStageAsFinished(stage)

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala 
b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index c084059..625c84f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -71,18 +71,31 @@ private[spark] object ResultTask {
 }
 
 
+/**
+ * A task that sends the output back to the driver application.
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param rdd input to func
+ * @param func a function to apply on a partition of the RDD
+ * @param _partitionId index of the partition in the RDD
+ * @param locs preferred task execution locations for locality scheduling
+ * @param outputId index of the task in this job (a job can launch tasks on 
only a subset of the
+ *                 input RDD's partitions).
+ */
 private[spark] class ResultTask[T, U](
     stageId: Int,
     var rdd: RDD[T],
     var func: (TaskContext, Iterator[T]) => U,
-    _partition: Int,
+    _partitionId: Int,
     @transient locs: Seq[TaskLocation],
     var outputId: Int)
-  extends Task[U](stageId, _partition) with Externalizable {
+  extends Task[U](stageId, _partitionId) with Externalizable {
 
   def this() = this(0, null, null, 0, null, 0)
 
-  var split = if (rdd == null) null else rdd.partitions(partition)
+  var split = if (rdd == null) null else rdd.partitions(partitionId)
 
   @transient private val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
@@ -99,17 +112,17 @@ private[spark] class ResultTask[T, U](
 
   override def preferredLocations: Seq[TaskLocation] = preferredLocs
 
-  override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+  override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
 
   override def writeExternal(out: ObjectOutput) {
     RDDCheckpointData.synchronized {
-      split = rdd.partitions(partition)
+      split = rdd.partitions(partitionId)
       out.writeInt(stageId)
       val bytes = ResultTask.serializeInfo(
         stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
       out.writeInt(bytes.length)
       out.write(bytes)
-      out.writeInt(partition)
+      out.writeInt(partitionId)
       out.writeInt(outputId)
       out.writeLong(epoch)
       out.writeObject(split)
@@ -124,7 +137,7 @@ private[spark] class ResultTask[T, U](
     val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
     rdd = rdd_.asInstanceOf[RDD[T]]
     func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
-    partition = in.readInt()
+    partitionId = in.readInt()
     outputId = in.readInt()
     epoch = in.readLong()
     split = in.readObject().asInstanceOf[Partition]

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala 
b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 1904ee8..66c1eae 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -85,13 +85,25 @@ private[spark] object ShuffleMapTask {
   }
 }
 
+/**
+ * A ShuffleMapTask divides the elements of an RDD into multiple buckets 
(based on a partitioner
+ * specified in the ShuffleDependency).
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param rdd the final RDD in this stage
+ * @param dep the ShuffleDependency
+ * @param _partitionId index of the partition in the RDD
+ * @param locs preferred task execution locations for locality scheduling
+ */
 private[spark] class ShuffleMapTask(
     stageId: Int,
     var rdd: RDD[_],
     var dep: ShuffleDependency[_,_],
-    _partition: Int,
+    _partitionId: Int,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, _partition)
+  extends Task[MapStatus](stageId, _partitionId)
   with Externalizable
   with Logging {
 
@@ -101,16 +113,16 @@ private[spark] class ShuffleMapTask(
     if (locs == null) Nil else locs.toSet.toSeq
   }
 
-  var split = if (rdd == null) null else rdd.partitions(partition)
+  var split = if (rdd == null) null else rdd.partitions(partitionId)
 
   override def writeExternal(out: ObjectOutput) {
     RDDCheckpointData.synchronized {
-      split = rdd.partitions(partition)
+      split = rdd.partitions(partitionId)
       out.writeInt(stageId)
       val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
       out.writeInt(bytes.length)
       out.write(bytes)
-      out.writeInt(partition)
+      out.writeInt(partitionId)
       out.writeLong(epoch)
       out.writeObject(split)
     }
@@ -124,7 +136,7 @@ private[spark] class ShuffleMapTask(
     val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
     rdd = rdd_
     dep = dep_
-    partition = in.readInt()
+    partitionId = in.readInt()
     epoch = in.readLong()
     split = in.readObject().asInstanceOf[Partition]
   }
@@ -141,7 +153,7 @@ private[spark] class ShuffleMapTask(
       // Obtain all the block writers for shuffle blocks.
       val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
       shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, 
numOutputSplits, ser)
-      buckets = shuffle.acquireWriters(partition)
+      buckets = shuffle.acquireWriters(partitionId)
 
       // Write the map output to its associated buckets.
       for (elem <- rdd.iterator(split, context)) {
@@ -185,5 +197,5 @@ private[spark] class ShuffleMapTask(
 
   override def preferredLocations: Seq[TaskLocation] = preferredLocs
 
-  override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
+  override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId)
 }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala 
b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 2c65d82..1fe0d0e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -31,12 +31,22 @@ import org.apache.spark.util.ByteBufferInputStream
 
 
 /**
- * A task to execute on a worker node.
+ * A unit of execution. We have two kinds of Tasks in Spark:
+ * - [[org.apache.spark.scheduler.ShuffleMapTask]]
+ * - [[org.apache.spark.scheduler.ResultTask]]
+ *
+ * A Spark job consists of one or more stages. The very last stage in a job consists of multiple
+ * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task
+ * and sends the task output back to the driver application. A ShuffleMapTask executes the task
+ * and divides the task output into multiple buckets (based on the task's partitioner).
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param partitionId index of the partition in the RDD
  */
-private[spark] abstract class Task[T](val stageId: Int, var partition: Int) 
extends Serializable {
+private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) 
extends Serializable {
 
   def run(attemptId: Long): T = {
-    context = new TaskContext(stageId, partition, attemptId, runningLocally = 
false)
+    context = new TaskContext(stageId, partitionId, attemptId, runningLocally 
= false)
     if (_killed) {
       kill()
     }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala 
b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index d9103ae..70c1acc 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -62,8 +62,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext 
with Logging {
     testCheckpointing(_.sample(false, 0.5, 0))
     testCheckpointing(_.glom())
     testCheckpointing(_.mapPartitions(_.map(_.toString)))
-    testCheckpointing(r => new MapPartitionsWithIndexRDD(r,
-      (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
+    testCheckpointing(r => new MapPartitionsWithContextRDD(r,
+      (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), 
false ))
     testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + 
_).mapValues(_.toString))
     testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + 
_).flatMapValues(x => 1 to x))
     testCheckpointing(_.pipe(Seq("cat")))

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala 
b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 53c2253..a192651 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -39,6 +39,7 @@ class JobCancellationSuite extends FunSuite with 
ShouldMatchers with BeforeAndAf
 
   override def afterEach() {
     super.afterEach()
+    resetSparkContext()
     System.clearProperty("spark.scheduler.mode")
   }
 
@@ -49,7 +50,6 @@ class JobCancellationSuite extends FunSuite with 
ShouldMatchers with BeforeAndAf
     testTake()
     // Make sure we can still launch tasks.
     assert(sc.parallelize(1 to 10, 2).count === 10)
-    resetSparkContext()
   }
 
   test("local mode, fair scheduler") {
@@ -61,7 +61,6 @@ class JobCancellationSuite extends FunSuite with 
ShouldMatchers with BeforeAndAf
     testTake()
     // Make sure we can still launch tasks.
     assert(sc.parallelize(1 to 10, 2).count === 10)
-    resetSparkContext()
   }
 
   test("cluster mode, FIFO scheduler") {
@@ -71,7 +70,6 @@ class JobCancellationSuite extends FunSuite with 
ShouldMatchers with BeforeAndAf
     testTake()
     // Make sure we can still launch tasks.
     assert(sc.parallelize(1 to 10, 2).count === 10)
-    resetSparkContext()
   }
 
   test("cluster mode, fair scheduler") {
@@ -83,7 +81,40 @@ class JobCancellationSuite extends FunSuite with 
ShouldMatchers with BeforeAndAf
     testTake()
     // Make sure we can still launch tasks.
     assert(sc.parallelize(1 to 10, 2).count === 10)
-    resetSparkContext()
+  }
+
+  test("two jobs sharing the same stage") {
+    // sem1: make sure cancel is issued after some tasks are launched
+    // sem2: make sure the first stage is not finished until cancel is issued
+    val sem1 = new Semaphore(0)
+    val sem2 = new Semaphore(0)
+
+    sc = new SparkContext("local[2]", "test")
+    sc.dagScheduler.addSparkListener(new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart) {
+        sem1.release()
+      }
+    })
+
+    // Create two actions that would share the same stages.
+    val rdd = sc.parallelize(1 to 10, 2).map { i =>
+      sem2.acquire()
+      (i, i)
+    }.reduceByKey(_+_)
+    val f1 = rdd.collectAsync()
+    val f2 = rdd.countAsync()
+
+    // Kill one of the actions.
+    future {
+      sem1.acquire()
+      f1.cancel()
+      sem2.release(10)
+    }
+
+    // Expect both to fail now.
+    // TODO: update this test when we change Spark so cancelling f1 wouldn't 
affect f2.
+    intercept[SparkException] { f1.get() }
+    intercept[SparkException] { f2.get() }
   }
 
   def testCount() {

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/6b288b75/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala 
b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 3ef000d..da032b1 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -18,19 +18,18 @@
 package org.apache.spark.rdd
 
 import java.util.concurrent.Semaphore
-import java.util.concurrent.atomic.AtomicInteger
 
-import scala.concurrent.future
 import scala.concurrent.ExecutionContext.Implicits.global
 
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkContext._
 import org.apache.spark.{SparkContext, SparkException, LocalSparkContext}
-import org.apache.spark.scheduler._
 
 
-class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
+class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with 
Timeouts {
 
   @transient private var sc: SparkContext = _
 
@@ -114,29 +113,29 @@ class AsyncRDDActionsSuite extends FunSuite with 
BeforeAndAfterAll {
   test("async success handling") {
     val f = sc.parallelize(1 to 10, 2).countAsync()
 
-    // This semaphore is used to make sure our final assert waits until 
onComplete / onSuccess
-    // finishes execution.
+    // Use a semaphore to make sure onSuccess and onComplete's success path 
will be called.
+    // If not, the test will hang.
     val sem = new Semaphore(0)
 
-    AsyncRDDActionsSuite.asyncSuccessHappened.set(0)
     f.onComplete {
       case scala.util.Success(res) =>
-        AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
         sem.release()
       case scala.util.Failure(e) =>
+        info("Should not have reached this code path (onComplete matching 
Failure)")
         throw new Exception("Task should succeed")
-        sem.release()
     }
     f.onSuccess { case a: Any =>
-      AsyncRDDActionsSuite.asyncSuccessHappened.incrementAndGet()
       sem.release()
     }
     f.onFailure { case t =>
+      info("Should not have reached this code path (onFailure)")
       throw new Exception("Task should succeed")
     }
     assert(f.get() === 10)
-    sem.acquire(2)
-    assert(AsyncRDDActionsSuite.asyncSuccessHappened.get() === 2)
+
+    failAfter(10 seconds) {
+      sem.acquire(2)
+    }
   }
 
   /**
@@ -148,38 +147,30 @@ class AsyncRDDActionsSuite extends FunSuite with 
BeforeAndAfterAll {
       throw new Exception("intentional"); i
     }.countAsync()
 
-    // This semaphore is used to make sure our final assert waits until 
onComplete / onFailure
-    // finishes execution.
+    // Use a semaphore to make sure onFailure and onComplete's failure path 
will be called.
+    // If not, the test will hang.
     val sem = new Semaphore(0)
 
-    AsyncRDDActionsSuite.asyncFailureHappend.set(0)
     f.onComplete {
       case scala.util.Success(res) =>
+        info("Should not have reached this code path (onComplete matching 
Success)")
         throw new Exception("Task should fail")
-        sem.release()
       case scala.util.Failure(e) =>
-        AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
         sem.release()
     }
     f.onSuccess { case a: Any =>
+      info("Should not have reached this code path (onSuccess)")
       throw new Exception("Task should fail")
     }
     f.onFailure { case t =>
-      AsyncRDDActionsSuite.asyncFailureHappend.incrementAndGet()
       sem.release()
     }
     intercept[SparkException] {
       f.get()
     }
-    sem.acquire(2)
-    assert(AsyncRDDActionsSuite.asyncFailureHappend.get() === 2)
-  }
-}
-
-object AsyncRDDActionsSuite {
-  // Some counters used in the test cases above.
-  var asyncSuccessHappened = new AtomicInteger
 
-  var asyncFailureHappend = new AtomicInteger
+    failAfter(10 seconds) {
+      sem.acquire(2)
+    }
+  }
 }
-
