Added a submitJob interface that returns a Future of the result.
Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/37d8f37a Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/37d8f37a Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/37d8f37a Branch: refs/heads/master Commit: 37d8f37a8ec110416fba0d51d8ba70370ac380c1 Parents: 1cb42e6 Author: Reynold Xin <reyno...@gmail.com> Authored: Tue Sep 17 21:13:59 2013 -0700 Committer: Reynold Xin <reyno...@gmail.com> Committed: Tue Sep 17 21:13:59 2013 -0700 ---------------------------------------------------------------------- .../main/scala/org/apache/spark/FutureJob.scala | 50 +++++++ .../scala/org/apache/spark/SparkContext.scala | 19 +++ .../main/scala/org/apache/spark/rdd/RDD.scala | 10 ++ .../apache/spark/scheduler/DAGScheduler.scala | 134 ++++++++++--------- .../spark/scheduler/DAGSchedulerEvent.scala | 20 +-- .../spark/scheduler/DAGSchedulerSource.scala | 2 +- .../org/apache/spark/scheduler/JobWaiter.scala | 31 +++-- .../spark/scheduler/DAGSchedulerSuite.scala | 6 +- 8 files changed, 185 insertions(+), 87 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/FutureJob.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/FutureJob.scala b/core/src/main/scala/org/apache/spark/FutureJob.scala new file mode 100644 index 0000000..ec3e0c3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/FutureJob.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.concurrent.{ExecutionException, TimeUnit, Future} + +import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter} + +class FutureJob[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: () => T) + extends Future[T] { + + override def isDone: Boolean = jobWaiter.jobFinished + + override def cancel(mayInterruptIfRunning: Boolean): Boolean = { + jobWaiter.kill() + true + } + + override def isCancelled: Boolean = { + throw new UnsupportedOperationException + } + + override def get(): T = { + jobWaiter.awaitResult() match { + case JobSucceeded => + resultFunc() + case JobFailed(e: Exception, _) => + throw new ExecutionException(e) + } + } + + override def get(timeout: Long, unit: TimeUnit): T = { + throw new UnsupportedOperationException + } +} http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/SparkContext.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ff3e780..ceb898d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.io._ import java.net.URI import java.util.Properties +import 
java.util.concurrent.Future import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map @@ -812,6 +813,24 @@ class SparkContext( result } + def submitJob[T, U, R]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + partitionResultHandler: (Int, U) => Unit, + resultFunc: () => R): Future[R] = + { + val callSite = Utils.formatSparkCallSite + val waiter = dagScheduler.submitJob( + rdd, + (context: TaskContext, iter: Iterator[T]) => processPartition(iter), + 0 until rdd.partitions.size, + callSite, + allowLocal = false, + partitionResultHandler, + null) + new FutureJob(waiter, resultFunc) + } + /** * Kill a running job. */ http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/rdd/RDD.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 1082cba..7cba393 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import java.util.Random +import java.util.concurrent.Future import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap @@ -562,6 +563,15 @@ abstract class RDD[T: ClassManifest]( } /** + * Return a future for retrieving the results of a collect in an asynchronous fashion. + */ + def collectAsync(): Future[Seq[T]] = { + val results = new ArrayBuffer[T] + sc.submitJob[T, Array[T], Seq[T]]( + this, _.toArray, (index, data) => results ++= data, () => results) + } + + /** * Return an array that contains all of the elements in this RDD. 
*/ def toArray(): Array[T] = collect() http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d44a3f2..efe258a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -105,13 +105,15 @@ class DAGScheduler( private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] - val nextJobId = new AtomicInteger(0) + private[scheduler] val nextJobId = new AtomicInteger(0) - val nextStageId = new AtomicInteger(0) + def numTotalJobs: Int = nextJobId.get() - val stageIdToStage = new TimeStampedHashMap[Int, Stage] + private val nextStageId = new AtomicInteger(0) - val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] + private val stageIdToStage = new TimeStampedHashMap[Int, Stage] + + private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] @@ -263,32 +265,42 @@ class DAGScheduler( } /** - * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a - * JobWaiter whose getResult() method will return the result of the job when it is complete. - * - * The job is assumed to have at least one partition; zero partition jobs should be handled - * without a JobSubmitted event. + * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object + * can be used to block until the job finishes executing or can be used to kill the job. + * If the given list of partitions is empty, the function returns a JobWaiter that is already finished.
*/ - private[scheduler] def prepareJob[T, U: ClassManifest]( - finalRdd: RDD[T], + def submitJob[T, U]( + rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: String, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null) - : (JobSubmitted, JobWaiter[U]) = + properties: Properties = null): JobWaiter[U] = { + val jobId = nextJobId.getAndIncrement() + if (partitions.size == 0) { + return new JobWaiter[U](this, jobId, 0, resultHandler) + } + + // Check to make sure we are not launching a task on a partition that does not exist. + val maxPartitions = rdd.partitions.length + partitions.find(p => p >= maxPartitions).foreach { p => + throw new IllegalArgumentException( + "Attempting to access a non-existent partition: " + p + ". " + + "Total number of partitions: " + maxPartitions) + } + assert(partitions.size > 0) - val waiter = new JobWaiter(partitions.size, resultHandler) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, - properties) - (toSubmit, waiter) + val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) + eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, + waiter, properties)) + waiter } def runJob[T, U: ClassManifest]( - finalRdd: RDD[T], + rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: String, @@ -296,21 +308,7 @@ class DAGScheduler( resultHandler: (Int, U) => Unit, properties: Properties = null) { - if (partitions.size == 0) { - return - } - - // Check to make sure we are not launching a task on a partition that does not exist. - val maxPartitions = finalRdd.partitions.length - partitions.find(p => p >= maxPartitions).foreach { p => - throw new IllegalArgumentException( - "Attempting to access a non-existent partition: " + p + ". 
" + - "Total number of partitions: " + maxPartitions) - } - - val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob( - finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties) - eventQueue.put(toSubmit) + val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { case JobSucceeded => {} case JobFailed(exception: Exception, _) => @@ -331,45 +329,50 @@ class DAGScheduler( val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray - eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) + val jobId = nextJobId.getAndIncrement() + eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, + listener, properties)) listener.awaitResult() // Will throw an exception if the job fails } + /** + * Kill a job that is running or waiting in the queue. + */ def killJob(jobId: Int): Unit = this.synchronized { activeJobs.find(job => job.jobId == jobId).foreach(job => killJob(job)) - } - private def killJob(job: ActiveJob): Unit = this.synchronized { - logInfo("Killing Job and cleaning up stages %d".format(job.jobId)) - activeJobs.remove(job) - idToActiveJob.remove(job.jobId) - val stage = job.finalStage - resultStageToJob.remove(stage) - killStage(job, stage) - val e = new SparkException("Job killed") - job.listener.jobFailed(e) - listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, None))) - } - - private def killStage(job: ActiveJob, stage: Stage): Unit = this.synchronized { - // TODO: Can we reuse taskSetFailed? 
- logInfo("Killing Stage %s".format(stage.id)) - stageIdToStage.remove(stage.id) - if (stage.isShuffleMap) { - shuffleToMapStage.remove(stage.id) - } - waiting.remove(stage) - pendingTasks.remove(stage) - taskSched.killTasks(stage.id) - - if (running.contains(stage)) { - running.remove(stage) + def killJob(job: ActiveJob): Unit = this.synchronized { + logInfo("Killing Job and cleaning up stages %d".format(job.jobId)) + activeJobs.remove(job) + idToActiveJob.remove(job.jobId) + val stage = job.finalStage + resultStageToJob.remove(stage) + killStage(job, stage) val e = new SparkException("Job killed") - listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, Some(stage)))) + job.listener.jobFailed(e) + listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, None))) } - stage.parents.foreach(parentStage => killStage(job, parentStage)) - //stageToInfos -= stage + def killStage(job: ActiveJob, stage: Stage): Unit = this.synchronized { + // TODO: Can we reuse taskSetFailed? + logInfo("Killing Stage %s".format(stage.id)) + stageIdToStage.remove(stage.id) + if (stage.isShuffleMap) { + shuffleToMapStage.remove(stage.id) + } + waiting.remove(stage) + pendingTasks.remove(stage) + taskSched.killTasks(stage.id) + + if (running.contains(stage)) { + running.remove(stage) + val e = new SparkException("Job killed") + listenerBus.post(SparkListenerJobEnd(job, JobFailed(e, Some(stage)))) + } + + stage.parents.foreach(parentStage => killStage(job, parentStage)) + //stageToInfos -= stage + } } /** @@ -378,9 +381,8 @@ class DAGScheduler( */ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) => - val jobId = nextJobId.getAndIncrement() - val finalStage = newStage(finalRDD, None, jobId, Some(callSite)) + case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => + val finalStage = newStage(rdd, None, jobId, Some(callSite)) val 
job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length + http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 0d99670..e4b60c4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -32,9 +32,10 @@ import org.apache.spark.executor.TaskMetrics * submitted) but there is a single "logic" thread that reads these events and takes decisions. * This greatly simplifies synchronization. */ -private[spark] sealed trait DAGSchedulerEvent +private[scheduler] sealed trait DAGSchedulerEvent -private[spark] case class JobSubmitted( +private[scheduler] case class JobSubmitted( + jobId: Int, finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], @@ -44,9 +45,10 @@ private[spark] case class JobSubmitted( properties: Properties = null) extends DAGSchedulerEvent -private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent +private[scheduler] +case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent -private[spark] case class CompletionEvent( +private[scheduler] case class CompletionEvent( task: Task[_], reason: TaskEndReason, result: Any, @@ -55,10 +57,12 @@ private[spark] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent -private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent +private[scheduler] +case class ExecutorGained(execId: String, host: String) extends 
DAGSchedulerEvent -private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent +private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent -private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +private[scheduler] +case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent -private[spark] case object StopDAGScheduler extends DAGSchedulerEvent +private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 446d490..9fe7002 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar }) metricRegistry.register(MetricRegistry.name("job", "allJobs", "number"), new Gauge[Int] { - override def getValue: Int = dagScheduler.nextJobId.get() + override def getValue: Int = dagScheduler.numTotalJobs }) metricRegistry.register(MetricRegistry.name("job", "activeJobs", "number"), new Gauge[Int] { http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 200d881..290dcc8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -17,29 +17,40 @@ package org.apache.spark.scheduler -import scala.collection.mutable.ArrayBuffer - /** * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their * results to the given handler function. */ -private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) +private[spark] class JobWaiter[T]( + dagScheduler: DAGScheduler, + jobId: Int, + totalTasks: Int, + resultHandler: (Int, T) => Unit) extends JobListener { private var finishedTasks = 0 - private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? - private var jobResult: JobResult = null // If the job is finished, this will be its result + // Is the job as a whole finished (succeeded or failed)? + private var _jobFinished = totalTasks == 0 + + def jobFinished = _jobFinished + + // If the job is finished, this will be its result + private var jobResult: JobResult = null + + def kill() { + dagScheduler.killJob(jobId) + } override def taskSucceeded(index: Int, result: Any) { synchronized { - if (jobFinished) { + if (_jobFinished) { throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") } resultHandler(index, result.asInstanceOf[T]) finishedTasks += 1 if (finishedTasks == totalTasks) { - jobFinished = true + _jobFinished = true jobResult = JobSucceeded this.notifyAll() } @@ -48,17 +59,17 @@ private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Un override def jobFailed(exception: Exception) { synchronized { - if (jobFinished) { + if (_jobFinished) { throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter") } - jobFinished = true + _jobFinished = true jobResult = JobFailed(exception, None) this.notifyAll() } } def awaitResult(): JobResult = synchronized { - while (!jobFinished) { + while (!_jobFinished) { this.wait() } return jobResult 
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/37d8f37a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 180211c..f39e863 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -190,7 +190,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, allowLocal: Boolean = false, listener: JobListener = listener) { - runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener)) + val jobId = scheduler.nextJobId.getAndIncrement() + runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, listener)) } /** Sends TaskSetFailed to the scheduler. */ @@ -224,7 +225,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont override def getPreferredLocations(split: Partition) = Nil override def toString = "DAGSchedulerSuite Local RDD" } - runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener)) + val jobId = scheduler.nextJobId.getAndIncrement() + runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener)) assert(results === Map(0 -> 42)) }