Added task killing iterator to RDDs that take inputs.
Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/70953810
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/70953810
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/70953810

Branch: refs/heads/master
Commit: 70953810b4e012eedd29514e01501c846e8d08f1
Parents: f19984d
Author: Reynold Xin <reyno...@gmail.com>
Authored: Thu Sep 19 18:33:16 2013 -0700
Committer: Reynold Xin <reyno...@gmail.com>
Committed: Thu Sep 19 18:33:16 2013 -0700

----------------------------------------------------------------------
 .../apache/spark/BlockStoreShuffleFetcher.scala | 12 ++-
 .../scala/org/apache/spark/ShuffleFetcher.scala |  5 +-
 .../org/apache/spark/rdd/CoGroupedRDD.scala     |  2 +-
 .../org/apache/spark/rdd/NewHadoopRDD.scala     | 79 ++++++++++----------
 .../spark/rdd/ParallelCollectionRDD.scala       |  5 +-
 .../org/apache/spark/rdd/ShuffledRDD.scala      |  2 +-
 .../org/apache/spark/rdd/SubtractedRDD.scala    |  2 +-
 7 files changed, 60 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/70953810/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
index 908ff56..ca7f9f8 100644
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
@@ -28,7 +28,11 @@ import org.apache.spark.util.CompletionIterator
 
 private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
 
-  override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
+  override def fetch[T](
+      shuffleId: Int,
+      reduceId: Int,
+      context: TaskContext,
+      serializer: Serializer)
     : Iterator[T] =
   {
 
@@ -74,7 +78,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
 
     val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
     val itr = blockFetcherItr.flatMap(unpackBlock)
-    CompletionIterator[T, Iterator[T]](itr, {
+    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
       val shuffleMetrics = new ShuffleReadMetrics
       shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
       shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
@@ -83,7 +87,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
       shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
       shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
       shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
-      metrics.shuffleReadMetrics = Some(shuffleMetrics)
+      context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
     })
+
+    new InterruptibleIterator[T](context, completionIter)
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/70953810/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
index 307c383..a85aa50 100644
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
@@ -27,7 +27,10 @@ private[spark] abstract class ShuffleFetcher {
    * Fetch the shuffle outputs for a given ShuffleDependency.
    * @return An iterator over the elements of the fetched shuffle outputs.
    */
-  def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+  def fetch[T](
+      shuffleId: Int,
+      reduceId: Int,
+      context: TaskContext,
     serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
 
   /** Stop the fetcher */

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/70953810/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 0187256..bd4eba5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -129,7 +129,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
       case ShuffleCoGroupSplitDep(shuffleId) => {
         // Read map outputs of shuffle
         val fetcher = SparkEnv.get.shuffleFetcher
-        fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
+        fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
           kv => getSeq(kv._1)(depNum) += kv._2
         }
       }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/70953810/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 7b3a89f..2662d48 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce._
 
-import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
 
 
 private[spark]
@@ -71,49 +71,52 @@ class NewHadoopRDD[K, V](
     result
   }
 
-  override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
-    val split = theSplit.asInstanceOf[NewHadoopPartition]
-    logInfo("Input split: " + split.serializableHadoopSplit)
-    val conf = confBroadcast.value.value
-    val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
-    val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
-    val format = inputFormatClass.newInstance
-    if (format.isInstanceOf[Configurable]) {
-      format.asInstanceOf[Configurable].setConf(conf)
-    }
-    val reader = format.createRecordReader(
-      split.serializableHadoopSplit.value, hadoopAttemptContext)
-    reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
-
-    // Register an on-task-completion callback to close the input stream.
-    context.addOnCompleteCallback(() => close())
-
-    var havePair = false
-    var finished = false
-
-    override def hasNext: Boolean = {
-      if (!finished && !havePair) {
-        finished = !reader.nextKeyValue
-        havePair = !finished
+  override def compute(theSplit: Partition, context: TaskContext) = {
+    val iter = new Iterator[(K, V)] {
+      val split = theSplit.asInstanceOf[NewHadoopPartition]
+      logInfo("Input split: " + split.serializableHadoopSplit)
+      val conf = confBroadcast.value.value
+      val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
+      val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
+      val format = inputFormatClass.newInstance
+      if (format.isInstanceOf[Configurable]) {
+        format.asInstanceOf[Configurable].setConf(conf)
+      }
+      val reader = format.createRecordReader(
+        split.serializableHadoopSplit.value, hadoopAttemptContext)
+      reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+
+      // Register an on-task-completion callback to close the input stream.
+      context.addOnCompleteCallback(() => close())
+
+      var havePair = false
+      var finished = false
+
+      override def hasNext: Boolean = {
+        if (!finished && !havePair) {
+          finished = !reader.nextKeyValue
+          havePair = !finished
+        }
+        !finished
       }
-      !finished
-    }
 
-    override def next: (K, V) = {
-      if (!hasNext) {
-        throw new java.util.NoSuchElementException("End of stream")
+      override def next(): (K, V) = {
+        if (!hasNext) {
+          throw new java.util.NoSuchElementException("End of stream")
+        }
+        havePair = false
+        (reader.getCurrentKey, reader.getCurrentValue)
       }
-      havePair = false
-      return (reader.getCurrentKey, reader.getCurrentValue)
-    }
 
-    private def close() {
-      try {
-        reader.close()
-      } catch {
-        case e: Exception => logWarning("Exception in RecordReader.close()", e)
+      private def close() {
+        try {
+          reader.close()
+        } catch {
+          case e: Exception => logWarning("Exception in RecordReader.close()", e)
+        }
       }
     }
+    new InterruptibleIterator(context, iter)
   }
 
   override def getPreferredLocations(split: Partition): Seq[String] = {

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/70953810/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 6dbd430..cd96250 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -94,8 +94,9 @@ private[spark] class ParallelCollectionRDD[T: ClassManifest](
     slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
   }
 
-  override def compute(s: Partition, context: TaskContext) =
-    s.asInstanceOf[ParallelCollectionPartition[T]].iterator
+  override def compute(s: Partition, context: TaskContext) = {
+    new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
+  }
 
   override def getPreferredLocations(s: Partition): Seq[String] = {
     locationPrefs.getOrElse(s.index, Nil)

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/70953810/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 9537152..a5d751a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -56,7 +56,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest](
 
   override def compute(split: Partition, context: TaskContext): Iterator[P] = {
     val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
-    SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics,
+    SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
       SparkEnv.get.serializerManager.get(serializerClass))
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/70953810/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 8c1a29d..7af4d80 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -108,7 +108,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
       }
       case ShuffleCoGroupSplitDep(shuffleId) => {
         val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
-          context.taskMetrics, serializer)
+          context, serializer)
         iter.foreach(op)
       }
     }
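
A note on the wrapper used throughout this patch: InterruptibleIterator itself is not defined in this diff; the change only wraps the iterators returned by compute() and fetch() with it, so that a killed task stops producing records at the next element boundary. As a rough, self-contained sketch (not the actual Spark source), such a wrapper only needs to consult a kill flag on the task context in hasNext. The KillableTaskContext stand-in and its interrupted field below are illustrative assumptions, not Spark API:

// Sketch only: simplified stand-ins, not the org.apache.spark classes.
class KillableTaskContext {
  // An executor would flip this flag when the scheduler asks to kill the task.
  @volatile var interrupted: Boolean = false
}

// Wraps a delegate iterator and stops producing elements once the task is marked killed.
class InterruptibleIteratorSketch[+T](
    val context: KillableTaskContext,
    val delegate: Iterator[T])
  extends Iterator[T] {

  override def hasNext: Boolean = !context.interrupted && delegate.hasNext

  override def next(): T = delegate.next()
}

object InterruptibleIteratorSketchDemo extends App {
  val ctx = new KillableTaskContext
  val iter = new InterruptibleIteratorSketch(ctx, Iterator.range(0, 1000000))
  println(iter.take(3).toList)  // List(0, 1, 2)
  ctx.interrupted = true        // simulate a kill request from the scheduler
  println(iter.hasNext)         // false, even though the delegate still has elements
}

Checking the flag in hasNext keeps the cancellation cooperative: the wrapped reader is never interrupted in the middle of a record.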