Add SparkContext.submitJob and RDD.collectAsync, which submit a job without blocking and return a java.util.concurrent.Future of its result.

Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/37d8f37a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/37d8f37a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/37d8f37a

Branch: refs/heads/master
Commit: 37d8f37a8ec110416fba0d51d8ba70370ac380c1
Parents: 1cb42e6
Author: Reynold Xin <reyno...@gmail.com>
Authored: Tue Sep 17 21:13:59 2013 -0700
Committer: Reynold Xin <reyno...@gmail.com>
Committed: Tue Sep 17 21:13:59 2013 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/FutureJob.scala |  50 +++++++
 .../scala/org/apache/spark/SparkContext.scala   |  19 +++
 .../main/scala/org/apache/spark/rdd/RDD.scala   |  10 ++
 .../apache/spark/scheduler/DAGScheduler.scala   | 134 ++++++++++---------
 .../spark/scheduler/DAGSchedulerEvent.scala     |  20 +--
 .../spark/scheduler/DAGSchedulerSource.scala    |   2 +-
 .../org/apache/spark/scheduler/JobWaiter.scala  |  31 +++--
 .../spark/scheduler/DAGSchedulerSuite.scala     |   6 +-
 8 files changed, 185 insertions(+), 87 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/FutureJob.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/FutureJob.scala 
b/core/src/main/scala/org/apache/spark/FutureJob.scala
new file mode 100644
index 0000000..ec3e0c3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/FutureJob.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.concurrent.{CancellationException, ExecutionException, Future, TimeUnit,
+  TimeoutException}
+
+import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
+
+/**
+ * A [[java.util.concurrent.Future]] backed by a job submitted to the DAGScheduler.
+ *
+ * Completion is tracked through `jobWaiter`. Once the job succeeds, `resultFunc` is invoked on
+ * the thread calling `get` to build the value handed back to the caller.
+ *
+ * @param jobWaiter waiter of the submitted job; also used to kill the job on `cancel`
+ * @param resultFunc builds the final result once every partition result has been handled
+ */
+class FutureJob[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: () => T)
+  extends Future[T] {
+
+  // Set by cancel(); read by isCancelled and get() to honour the Future cancellation contract.
+  @volatile private var cancelled = false
+
+  override def isDone: Boolean = jobWaiter.jobFinished
+
+  /**
+   * Kills the underlying job and returns true, or returns false without doing anything if this
+   * future was already cancelled or the job has already finished. `mayInterruptIfRunning` is
+   * ignored. Once this has returned true, `get` throws CancellationException.
+   */
+  override def cancel(mayInterruptIfRunning: Boolean): Boolean = synchronized {
+    if (cancelled || isDone) {
+      false
+    } else {
+      cancelled = true
+      jobWaiter.kill()
+      true
+    }
+  }
+
+  override def isCancelled: Boolean = cancelled
+
+  /**
+   * Blocks until the job finishes and returns the value produced by `resultFunc`.
+   *
+   * @throws CancellationException if this future was cancelled
+   * @throws ExecutionException if the job failed; its cause is the exception the job failed with
+   */
+  override def get(): T = {
+    throwIfCancelled()
+    val result = jobWaiter.awaitResult()
+    throwIfCancelled()
+    result match {
+      case JobSucceeded =>
+        resultFunc()
+      case null =>
+        // A JobWaiter created for zero tasks starts out finished without a JobResult; a job
+        // with nothing to run cannot fail, so treat it as a success.
+        resultFunc()
+      case JobFailed(e: Exception, _) =>
+        throw new ExecutionException(e)
+    }
+  }
+
+  /**
+   * Waits at most `timeout` (in `unit`) for the job to finish, then behaves like `get()`.
+   *
+   * @throws TimeoutException if the job has not finished before the timeout elapses
+   */
+  override def get(timeout: Long, unit: TimeUnit): T = {
+    throwIfCancelled()
+    val deadline = System.nanoTime() + unit.toNanos(timeout)
+    // JobWaiter calls notifyAll() on itself when the job finishes, so waiting on its monitor
+    // wakes this thread as soon as a result is available.
+    jobWaiter.synchronized {
+      while (!jobWaiter.jobFinished) {
+        val remaining = deadline - System.nanoTime()
+        if (remaining <= 0) {
+          throw new TimeoutException("Job did not finish within " + timeout + " " + unit)
+        }
+        TimeUnit.NANOSECONDS.timedWait(jobWaiter, remaining)
+      }
+    }
+    get()
+  }
+
+  // Enforces the Future contract that get() throws once the future has been cancelled.
+  private def throwIfCancelled() {
+    if (cancelled) {
+      throw new CancellationException("Job was cancelled")
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala 
b/core/src/main/scala/org/apache/spark/SparkContext.scala
index ff3e780..ceb898d 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -20,6 +20,7 @@ package org.apache.spark
 import java.io._
 import java.net.URI
 import java.util.Properties
+import java.util.concurrent.Future
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.Map
@@ -812,6 +813,34 @@ class SparkContext(
     result
   }
 
+  /**
+   * Submit a job that runs `processPartition` on every partition of `rdd` without blocking, and
+   * return a Future for the job's result. The job is submitted with `allowLocal = false`.
+   *
+   * @param rdd the RDD whose partitions are processed
+   * @param processPartition function applied to the iterator of each partition
+   * @param partitionResultHandler called with a partition's index and its `processPartition`
+   *                               output each time a task succeeds
+   * @param resultFunc produces the future's value; invoked by `get` after the job succeeds
+   */
+  def submitJob[T, U, R](
+      rdd: RDD[T],
+      processPartition: Iterator[T] => U,
+      partitionResultHandler: (Int, U) => Unit,
+      resultFunc: () => R): Future[R] =
+  {
+    val callSite = Utils.formatSparkCallSite
+    val waiter = dagScheduler.submitJob(
+      rdd,
+      (context: TaskContext, iter: Iterator[T]) => processPartition(iter),
+      0 until rdd.partitions.size,
+      callSite,
+      allowLocal = false,
+      partitionResultHandler,
+      null)
+    new FutureJob(waiter, resultFunc)
+  }
+
   /**
    * Kill a running job.
    */

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1082cba..7cba393 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.rdd
 
 import java.util.Random
+import java.util.concurrent.Future
 
 import scala.collection.Map
 import scala.collection.JavaConversions.mapAsScalaMap
@@ -562,6 +563,22 @@ abstract class RDD[T: ClassManifest](
   }
 
   /**
+   * Return a future for retrieving the results of a collect in an asynchronous fashion.
+   * Elements are returned in partition order, regardless of the order in which the
+   * partitions' tasks finish.
+   */
+  def collectAsync(): Future[Seq[T]] = {
+    // One slot per partition. SparkContext.submitJob runs partitions 0 until partitions.size, so
+    // the handler's index selects the slot and the output order does not depend on timing.
+    val results = new Array[Array[T]](partitions.size)
+    sc.submitJob[T, Array[T], Seq[T]](
+      this,
+      _.toArray,
+      (index, data) => results(index) = data,
+      () => Array.concat(results: _*).toSeq)
+  }
+
+  /**
    * Return an array that contains all of the elements in this RDD.
    */
   def toArray(): Array[T] = collect()

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index d44a3f2..efe258a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -105,13 +105,15 @@ class DAGScheduler(
 
   private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
 
-  val nextJobId = new AtomicInteger(0)
+  private[scheduler] val nextJobId = new AtomicInteger(0)
 
-  val nextStageId = new AtomicInteger(0)
+  def numTotalJobs: Int = nextJobId.get()
 
-  val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+  private val nextStageId = new AtomicInteger(0)
 
-  val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+  private val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+
+  private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
 
   private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
 
@@ -263,32 +265,42 @@ class DAGScheduler(
   }
 
   /**
-   * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
-   * JobWaiter whose getResult() method will return the result of the job when it is complete.
-   *
-   * The job is assumed to have at least one partition; zero partition jobs should be handled
-   * without a JobSubmitted event.
+   * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
+   * can be used to block until the job finishes executing or can be used to kill the job.
+   * If `partitions` is empty, no job is submitted and the returned JobWaiter is already finished.
    */
-  private[scheduler] def prepareJob[T, U: ClassManifest](
-      finalRdd: RDD[T],
+  def submitJob[T, U](
+      rdd: RDD[T],
       func: (TaskContext, Iterator[T]) => U,
       partitions: Seq[Int],
       callSite: String,
       allowLocal: Boolean,
       resultHandler: (Int, U) => Unit,
-      properties: Properties = null)
-    : (JobSubmitted, JobWaiter[U]) =
+      properties: Properties = null): JobWaiter[U] =
   {
+    val jobId = nextJobId.getAndIncrement()
+    if (partitions.size == 0) {
+      return new JobWaiter[U](this, jobId, 0, resultHandler)
+    }
+
+    // Check to make sure we are not launching a task on a partition that does not exist.
+    val maxPartitions = rdd.partitions.length
+    partitions.find(p => p < 0 || p >= maxPartitions).foreach { p =>
+      throw new IllegalArgumentException(
+        "Attempting to access a non-existent partition: " + p + ". " +
+          "Total number of partitions: " + maxPartitions)
+    }
+
     assert(partitions.size > 0)
-    val waiter = new JobWaiter(partitions.size, resultHandler)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
-    val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
-      properties)
-    (toSubmit, waiter)
+    val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
+    eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite,
+      waiter, properties))
+    waiter
   }
 
   def runJob[T, U: ClassManifest](
-      finalRdd: RDD[T],
+      rdd: RDD[T],
       func: (TaskContext, Iterator[T]) => U,
       partitions: Seq[Int],
       callSite: String,
@@ -296,21 +308,7 @@ class DAGScheduler(
       resultHandler: (Int, U) => Unit,
       properties: Properties = null)
   {
-    if (partitions.size == 0) {
-      return
-    }
-
-    // Check to make sure we are not launching a task on a partition that does 
not exist.
-    val maxPartitions = finalRdd.partitions.length
-    partitions.find(p => p >= maxPartitions).foreach { p =>
-      throw new IllegalArgumentException(
-        "Attempting to access a non-existent partition: " + p + ". " +
-        "Total number of partitions: " + maxPartitions)
-    }
-
-    val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
-        finalRdd, func, partitions, callSite, allowLocal, resultHandler, 
properties)
-    eventQueue.put(toSubmit)
+    val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, 
resultHandler, properties)
     waiter.awaitResult() match {
       case JobSucceeded => {}
       case JobFailed(exception: Exception, _) =>
@@ -331,45 +329,50 @@ class DAGScheduler(
     val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val partitions = (0 until rdd.partitions.size).toArray
-    eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, 
callSite, listener, properties))
+    val jobId = nextJobId.getAndIncrement()
+    eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = 
false, callSite,
+      listener, properties))
     listener.awaitResult()    // Will throw an exception if the job fails
   }
 
+  /**
+   * Kill the job with the given ID if it is in activeJobs; unknown IDs are ignored.
+   */
   def killJob(jobId: Int): Unit = this.synchronized {
     activeJobs.find(job => job.jobId == jobId).foreach(job => killJob(job))
-  }
 
-  private def killJob(job: ActiveJob): Unit = this.synchronized {
-    logInfo("Killing Job and cleaning up stages %d".format(job.jobId))
-    activeJobs.remove(job)
-    idToActiveJob.remove(job.jobId)
-    val stage = job.finalStage
-    resultStageToJob.remove(stage)
-    killStage(job, stage)
-    val e = new SparkException("Job killed")
-    job.listener.jobFailed(e)
-    listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, None)))
-  }
-
-  private def killStage(job: ActiveJob, stage: Stage): Unit = 
this.synchronized {
-    // TODO: Can we reuse taskSetFailed?
-    logInfo("Killing Stage %s".format(stage.id))
-    stageIdToStage.remove(stage.id)
-    if (stage.isShuffleMap) {
-      shuffleToMapStage.remove(stage.id)
-    }
-    waiting.remove(stage)
-    pendingTasks.remove(stage)
-    taskSched.killTasks(stage.id)
-
-    if (running.contains(stage)) {
-      running.remove(stage)
+    def killJob(job: ActiveJob): Unit = this.synchronized {  // local helper; shadows the outer killJob(jobId)
+      logInfo("Killing Job and cleaning up stages %d".format(job.jobId))
+      activeJobs.remove(job)
+      idToActiveJob.remove(job.jobId)
+      val stage = job.finalStage
+      resultStageToJob.remove(stage)
+      killStage(job, stage)
       val e = new SparkException("Job killed")
-      listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, Some(stage))))
+      job.listener.jobFailed(e)
+      listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, None)))
     }
 
-    stage.parents.foreach(parentStage => killStage(job, parentStage))
-    //stageToInfos -= stage
+    def killStage(job: ActiveJob, stage: Stage): Unit = this.synchronized {  // also kills parent stages recursively
+      // TODO: Can we reuse taskSetFailed?
+      logInfo("Killing Stage %s".format(stage.id))
+      stageIdToStage.remove(stage.id)
+      if (stage.isShuffleMap) {
+        shuffleToMapStage.remove(stage.id)  // NOTE(review): removes by stage.id; confirm this map is not keyed by shuffle ID
+      }
+      waiting.remove(stage)
+      pendingTasks.remove(stage)
+      taskSched.killTasks(stage.id)
+
+      if (running.contains(stage)) {
+        running.remove(stage)
+        val e = new SparkException("Job killed")
+        listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, Some(stage))))  // NOTE(review): killJob posts another JobEnd for the same job, so listeners may see several
+      }
+
+      stage.parents.foreach(parentStage => killStage(job, parentStage))
+      //stageToInfos -= stage
+    }
   }
 
   /**
@@ -378,9 +381,8 @@ class DAGScheduler(
    */
   private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
     event match {
-      case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, 
listener, properties) =>
-        val jobId = nextJobId.getAndIncrement()
-        val finalStage = newStage(finalRDD, None, jobId, Some(callSite))
+      case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, 
listener, properties) =>
+        val finalStage = newStage(rdd, None, jobId, Some(callSite))
         val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, 
listener, properties)
         clearCacheLocs()
         logInfo("Got job " + job.jobId + " (" + callSite + ") with " + 
partitions.length +

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 0d99670..e4b60c4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -32,9 +32,10 @@ import org.apache.spark.executor.TaskMetrics
  * submitted) but there is a single "logic" thread that reads these events and 
takes decisions.
  * This greatly simplifies synchronization.
  */
-private[spark] sealed trait DAGSchedulerEvent
+private[scheduler] sealed trait DAGSchedulerEvent
 
-private[spark] case class JobSubmitted(
+private[scheduler] case class JobSubmitted(
+    jobId: Int,
     finalRDD: RDD[_],
     func: (TaskContext, Iterator[_]) => _,
     partitions: Array[Int],
@@ -44,9 +45,10 @@ private[spark] case class JobSubmitted(
     properties: Properties = null)
   extends DAGSchedulerEvent
 
-private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) 
extends DAGSchedulerEvent
+private[scheduler]
+case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends 
DAGSchedulerEvent
 
-private[spark] case class CompletionEvent(
+private[scheduler] case class CompletionEvent(
     task: Task[_],
     reason: TaskEndReason,
     result: Any,
@@ -55,10 +57,12 @@ private[spark] case class CompletionEvent(
     taskMetrics: TaskMetrics)
   extends DAGSchedulerEvent
 
-private[spark] case class ExecutorGained(execId: String, host: String) extends 
DAGSchedulerEvent
+private[scheduler]
+case class ExecutorGained(execId: String, host: String) extends 
DAGSchedulerEvent
 
-private[spark] case class ExecutorLost(execId: String) extends 
DAGSchedulerEvent
+private[scheduler] case class ExecutorLost(execId: String) extends 
DAGSchedulerEvent
 
-private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) 
extends DAGSchedulerEvent
+private[scheduler]
+case class TaskSetFailed(taskSet: TaskSet, reason: String) extends 
DAGSchedulerEvent
 
-private[spark] case object StopDAGScheduler extends DAGSchedulerEvent
+private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
index 446d490..9fe7002 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
@@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: 
DAGScheduler, sc: Spar
   })
 
   metricRegistry.register(MetricRegistry.name("job", "allJobs", "number"), new 
Gauge[Int] {
-    override def getValue: Int = dagScheduler.nextJobId.get()
+    override def getValue: Int = dagScheduler.numTotalJobs
   })
 
   metricRegistry.register(MetricRegistry.name("job", "activeJobs", "number"), 
new Gauge[Int] {

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala 
b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 200d881..290dcc8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -17,29 +17,40 @@
 
 package org.apache.spark.scheduler
 
-import scala.collection.mutable.ArrayBuffer
-
 /**
  * An object that waits for a DAGScheduler job to complete. As tasks finish, 
it passes their
  * results to the given handler function.
  */
-private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => 
Unit)
+private[spark] class JobWaiter[T](
+    dagScheduler: DAGScheduler,
+    jobId: Int,
+    totalTasks: Int,
+    resultHandler: (Int, T) => Unit)
   extends JobListener {
 
   private var finishedTasks = 0
 
-  private var jobFinished = false          // Is the job as a whole finished (succeeded or failed)?
-  private var jobResult: JobResult = null  // If the job is finished, this will be its result
+  // Is the job as a whole finished (succeeded or failed)? Volatile: jobFinished reads it unlocked.
+  @volatile private var _jobFinished = totalTasks == 0
+
+  def jobFinished: Boolean = _jobFinished
+
+  // Result once finished. A zero-task waiter starts finished, so it succeeds (never returns null).
+  private var jobResult: JobResult = if (_jobFinished) JobSucceeded else null
+
+  def kill() {
+    dagScheduler.killJob(jobId)
+  }
 
   override def taskSucceeded(index: Int, result: Any) {
     synchronized {
-      if (jobFinished) {
+      if (_jobFinished) {
         throw new UnsupportedOperationException("taskSucceeded() called on a 
finished JobWaiter")
       }
       resultHandler(index, result.asInstanceOf[T])
       finishedTasks += 1
       if (finishedTasks == totalTasks) {
-        jobFinished = true
+        _jobFinished = true
         jobResult = JobSucceeded
         this.notifyAll()
       }
@@ -48,17 +59,17 @@ private[spark] class JobWaiter[T](totalTasks: Int, 
resultHandler: (Int, T) => Un
 
   override def jobFailed(exception: Exception) {
     synchronized {
-      if (jobFinished) {
+      if (_jobFinished) {
         throw new UnsupportedOperationException("jobFailed() called on a 
finished JobWaiter")
       }
-      jobFinished = true
+      _jobFinished = true
       jobResult = JobFailed(exception, None)
       this.notifyAll()
     }
   }
 
   def awaitResult(): JobResult = synchronized {
-    while (!jobFinished) {
+    while (!_jobFinished) {
       this.wait()
     }
     return jobResult

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 180211c..f39e863 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -190,7 +190,8 @@ class DAGSchedulerSuite extends FunSuite with 
BeforeAndAfter with LocalSparkCont
       func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
       allowLocal: Boolean = false,
       listener: JobListener = listener) {
-    runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener))
+    val jobId = scheduler.nextJobId.getAndIncrement()
+    runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, 
listener))
   }
 
   /** Sends TaskSetFailed to the scheduler. */
@@ -224,7 +225,8 @@ class DAGSchedulerSuite extends FunSuite with 
BeforeAndAfter with LocalSparkCont
       override def getPreferredLocations(split: Partition) = Nil
       override def toString = "DAGSchedulerSuite Local RDD"
     }
-    runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
+    val jobId = scheduler.nextJobId.getAndIncrement()
+    runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, 
listener))
     assert(results === Map(0 -> 42))
   }
 

Reply via email to