Made output of CoGroup and aggregations interruptible.
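The patch below wraps the output iterators of CoGroup and the combineByKey aggregation paths in org.apache.spark.InterruptibleIterator, which is referenced by the diff but not defined in it. For reference, here is a minimal sketch of what such a wrapper might look like, assuming TaskContext exposes an `interrupted` flag; the field name and signature are assumptions, not taken from this commit:

package org.apache.spark

// Sketch only: delegates to an underlying iterator and stops producing
// elements once the owning task has been marked as interrupted (killed).
class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
  extends Iterator[T] {

  // Assumes TaskContext carries an `interrupted` flag set when the task is killed.
  def hasNext: Boolean = !context.interrupted && delegate.hasNext

  def next(): T = delegate.next()
}
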
Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/1d87616b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/1d87616b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/1d87616b

Branch: refs/heads/master
Commit: 1d87616b61cb89e0713c1e375f656a2e2eebc7c4
Parents: c5e4095
Author: Reynold Xin <reyno...@gmail.com>
Authored: Thu Sep 19 23:31:36 2013 -0700
Committer: Reynold Xin <reyno...@gmail.com>
Committed: Thu Sep 19 23:31:36 2013 -0700

----------------------------------------------------------------------
 .../org/apache/spark/rdd/CoGroupedRDD.scala     |  4 +--
 .../org/apache/spark/rdd/InterruptibleRDD.scala | 36 ++++++++++++++++++++
 .../org/apache/spark/rdd/PairRDDFunctions.scala |  3 ++
 .../main/scala/org/apache/spark/rdd/RDD.scala   |  5 +++
 4 files changed, 46 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/1d87616b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index bd4eba5..2015c33 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -23,7 +23,7 @@ import java.util.{HashMap => JHashMap}
 import scala.collection.JavaConversions
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
 import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
 
 
@@ -134,7 +134,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
         }
       }
     }
-    JavaConversions.mapAsScalaMap(map).iterator
+    new InterruptibleIterator(context, JavaConversions.mapAsScalaMap(map).iterator)
   }
 
   override def clearDependencies() {

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/1d87616b/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala b/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala
new file mode 100644
index 0000000..e731deb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{InterruptibleIterator, Partition, TaskContext}
+
+
+/**
+ * Wraps around an existing RDD to make it interruptible (can be killed).
+ */
+private[spark]
+class InterruptibleRDD[T: ClassManifest](prev: RDD[T]) extends RDD[T](prev) {
+
+  override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+  override val partitioner = prev.partitioner
+
+  override def compute(split: Partition, context: TaskContext) = {
+    new InterruptibleIterator(context, firstParent[T].iterator(split, context))
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/1d87616b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index a47c512..ee17794 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -85,17 +85,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
     val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
     if (self.partitioner == Some(partitioner)) {
       self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+        .interruptible()
     } else if (mapSideCombine) {
       val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
       val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
         .setSerializer(serializerClass)
       partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
+        .interruptible()
     } else {
       // Don't apply map-side combiner.
       // A sanity check to make sure mergeCombiners is not defined.
       assert(mergeCombiners == null)
       val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
       values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+        .interruptible()
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/1d87616b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 72a5a20..841fd61 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -852,6 +852,11 @@ abstract class RDD[T: ClassManifest](
     map(x => (f(x), x))
   }
 
+  /**
+   * Creates an interruptible version of this RDD.
+   */
+  def interruptible(): RDD[T] = new InterruptibleRDD(this)
+
   /** A private method for tests, to look at the contents of each partition */
   private[spark] def collectPartitions(): Array[Array[T]] = {
     sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
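
For context, a rough usage sketch of how the new interruptible() hook composes in user code: reduceByKey is built on combineByKey, so its output is now an InterruptibleRDD and a kill request stops the task while it is still draining the aggregation's output iterator. The local SparkContext and sample data below are illustrative assumptions, not part of this patch.

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._  // pair-RDD implicits (reduceByKey, etc.)

object InterruptibleSketch {
  def main(args: Array[String]) {
    // Local context purely for illustration.
    val sc = new SparkContext("local[2]", "interruptible-sketch")

    // reduceByKey goes through combineByKey, whose result is now wrapped with
    // .interruptible(); if the running task is killed, the InterruptibleIterator
    // stops feeding downstream consumers instead of draining the whole partition.
    val counts = sc.parallelize(Seq("a", "b", "a", "c"))
      .map(word => (word, 1))
      .reduceByKey(_ + _)

    println(counts.collect().toSeq)
    sc.stop()
  }
}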