Made output of CoGroup and aggregations interruptible.

Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/1d87616b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/1d87616b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/1d87616b

Branch: refs/heads/master
Commit: 1d87616b61cb89e0713c1e375f656a2e2eebc7c4
Parents: c5e4095
Author: Reynold Xin <reyno...@gmail.com>
Authored: Thu Sep 19 23:31:36 2013 -0700
Committer: Reynold Xin <reyno...@gmail.com>
Committed: Thu Sep 19 23:31:36 2013 -0700

----------------------------------------------------------------------
 .../org/apache/spark/rdd/CoGroupedRDD.scala     |  4 +--
 .../org/apache/spark/rdd/InterruptibleRDD.scala | 36 ++++++++++++++++++++
 .../org/apache/spark/rdd/PairRDDFunctions.scala |  3 ++
 .../main/scala/org/apache/spark/rdd/RDD.scala   |  5 +++
 4 files changed, 46 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/1d87616b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index bd4eba5..2015c33 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -23,7 +23,7 @@ import java.util.{HashMap => JHashMap}
 import scala.collection.JavaConversions
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
 import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
 
 
@@ -134,7 +134,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
         }
       }
     }
-    JavaConversions.mapAsScalaMap(map).iterator
+    new InterruptibleIterator(context, JavaConversions.mapAsScalaMap(map).iterator)
   }
 
   override def clearDependencies() {

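For reference, the wrapping above relies on InterruptibleIterator from org.apache.spark, which delegates to an existing iterator but stops producing elements once the task has been flagged as killed. A minimal sketch of that idea (illustrative only, not the actual class; it assumes TaskContext exposes an `interrupted` flag):

    import org.apache.spark.TaskContext

    // Illustrative sketch of the wrapping idea; the real InterruptibleIterator
    // lives in org.apache.spark. Assumes TaskContext has an `interrupted` flag.
    class InterruptibleIteratorSketch[+T](context: TaskContext, delegate: Iterator[T])
      extends Iterator[T] {

      // Report exhaustion as soon as the task has been interrupted.
      def hasNext: Boolean = !context.interrupted && delegate.hasNext

      def next(): T = delegate.next()
    }
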
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/1d87616b/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala b/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala
new file mode 100644
index 0000000..e731deb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{InterruptibleIterator, Partition, TaskContext}
+
+
+/**
+ * Wraps around an existing RDD to make it interruptible (can be killed).
+ */
+private[spark]
+class InterruptibleRDD[T: ClassManifest](prev: RDD[T]) extends RDD[T](prev) {
+
+  override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+  override val partitioner = prev.partitioner
+
+  override def compute(split: Partition, context: TaskContext) = {
+    new InterruptibleIterator(context, firstParent[T].iterator(split, context))
+  }
+}

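The new InterruptibleRDD is a thin pass-through: it reuses the parent's partitions and partitioner and only wraps the parent's iterator in compute, so the data and the partitioning are unchanged and only the behaviour under a kill differs. A hypothetical illustration (assumes a local SparkContext; since InterruptibleRDD is private[spark], user code would go through the interruptible() method added in RDD.scala below):

    import org.apache.spark.SparkContext

    // Hypothetical illustration: wrapping preserves partition count and contents.
    val sc = new SparkContext("local", "interruptible-rdd-example")
    val base    = sc.parallelize(1 to 100, 4)
    val wrapped = base.interruptible()

    assert(wrapped.partitions.length == base.partitions.length)
    assert(wrapped.collect().toSeq == base.collect().toSeq)
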
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/1d87616b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index a47c512..ee17794 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -85,17 +85,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
     val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
     if (self.partitioner == Some(partitioner)) {
       self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+        .interruptible()
     } else if (mapSideCombine) {
       val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
       val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
         .setSerializer(serializerClass)
       partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
+        .interruptible()
     } else {
       // Don't apply map-side combiner.
       // A sanity check to make sure mergeCombiners is not defined.
       assert(mergeCombiners == null)
       val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
       values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+        .interruptible()
     }
   }
 

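Since combineByKey is the common code path behind the keyed aggregations in PairRDDFunctions (reduceByKey, groupByKey, and so on), appending .interruptible() in each of the three branches above should make the output of those higher-level operations killable as well. A hypothetical example (reuses the SparkContext sc from the previous sketch):

    import org.apache.spark.SparkContext._   // implicit conversion to PairRDDFunctions

    // Hypothetical example: reduceByKey is built on combineByKey, so with this
    // change its output partitions are wrapped and can be cut short on a kill.
    val counts = sc.parallelize(Seq("a", "b", "a", "c"))
      .map(word => (word, 1))
      .reduceByKey(_ + _)
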
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/1d87616b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 72a5a20..841fd61 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -852,6 +852,11 @@ abstract class RDD[T: ClassManifest](
     map(x => (f(x), x))
   }
 
+  /**
+   * Creates an interruptible version of this RDD.
+   */
+  def interruptible(): RDD[T] = new InterruptibleRDD(this)
+
   /** A private method for tests, to look at the contents of each partition */
   private[spark] def collectPartitions(): Array[Array[T]] = {
     sc.runJob(this, (iter: Iterator[T]) => iter.toArray)

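Because interruptible() returns an ordinary RDD[T], it can also be dropped into an existing lineage wherever a long-running scan should be killable. A hypothetical chaining example (expensiveTransform is a made-up stand-in; reuses the SparkContext sc from above):

    // Hypothetical chaining example: interruptible() slots into a lineage like any
    // other transformation and only changes how iteration reacts to a kill.
    def expensiveTransform(i: Int): Int = { Thread.sleep(1); i * 2 }   // made-up stand-in

    val total = sc.parallelize(1 to 1000000, 8)
      .map(expensiveTransform)
      .interruptible()
      .filter(_ % 3 == 0)
      .count()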