Repository: spark
Updated Branches:
  refs/heads/master 7d52777ef -> 9909efc10


SPARK-1839: PySpark RDD#take() shouldn't always read from driver

This patch simply ports over the Scala implementation of RDD#take(), which 
reads the first partition at the driver, then decides how many more partitions 
it needs to read and will possibly start a real job if it's more than 1. (Note 
that SparkContext#runJob(allowLocal=true) only runs the job locally if there's 
1 partition selected and no parent stages.)

Author: Aaron Davidson <[email protected]>

Closes #922 from aarondav/take and squashes the following commits:

fa06df9 [Aaron Davidson] SPARK-1839: PySpark RDD#take() shouldn't always read from driver


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9909efc1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9909efc1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9909efc1

Branch: refs/heads/master
Commit: 9909efc10aaa62c47fd7c4c9da73ac8c56a454d5
Parents: 7d52777
Author: Aaron Davidson <[email protected]>
Authored: Sat May 31 13:04:57 2014 -0700
Committer: Reynold Xin <[email protected]>
Committed: Sat May 31 13:04:57 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 20 +++++++
 python/pyspark/context.py                       | 26 +++++++++
 python/pyspark/rdd.py                           | 59 +++++++++++++-------
 3 files changed, 84 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9909efc1/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 57b28b9..d1df993 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -269,6 +269,26 @@ private object SpecialLengths {
 private[spark] object PythonRDD {
   val UTF8 = Charset.forName("UTF-8")
 
+  /**
+   * Adapter for calling SparkContext#runJob from Python.
+   *
+   * This method will return an iterator of an array that contains all elements in the RDD
+   * (effectively a collect()), but allows you to run on a certain subset of partitions,
+   * or to enable local execution.
+   */
+  def runJob(
+      sc: SparkContext,
+      rdd: JavaRDD[Array[Byte]],
+      partitions: JArrayList[Int],
+      allowLocal: Boolean): Iterator[Array[Byte]] = {
+    type ByteArray = Array[Byte]
+    type UnrolledPartition = Array[ByteArray]
+    val allPartitions: Array[UnrolledPartition] =
+      sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
+    val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
+    flattenedPartition.iterator
+  }
+
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))

http://git-wip-us.apache.org/repos/asf/spark/blob/9909efc1/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 56746cb..9ae9305 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -537,6 +537,32 @@ class SparkContext(object):
         """
         self._jsc.sc().cancelAllJobs()
 
+    def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False):
+        """
+        Executes the given partitionFunc on the specified set of partitions,
+        returning the result as an array of elements.
+
+        If 'partitions' is not specified, this will run over all partitions.
+
+        >>> myRDD = sc.parallelize(range(6), 3)
+        >>> sc.runJob(myRDD, lambda part: [x * x for x in part])
+        [0, 1, 4, 9, 16, 25]
+
+        >>> myRDD = sc.parallelize(range(6), 3)
+        >>> sc.runJob(myRDD, lambda part: [x * x for x in part], [0, 2], True)
+        [0, 1, 16, 25]
+        """
+        if partitions == None:
+            partitions = range(rdd._jrdd.splits().size())
+        javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)
+
+        # Implementation note: This is implemented as a mapPartitions followed
+        # by runJob() in order to avoid having to pass a Python lambda into
+        # SparkContext#runJob.
+        mappedRDD = rdd.mapPartitions(partitionFunc)
+        it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
+        return list(mappedRDD._collect_iterator_through_file(it))
+
 def _test():
     import atexit
     import doctest

http://git-wip-us.apache.org/repos/asf/spark/blob/9909efc1/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 07578b8..f3b1f1a 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -841,34 +841,51 @@ class RDD(object):
         """
         Take the first num elements of the RDD.
 
-        This currently scans the partitions *one by one*, so it will be slow if
-        a lot of partitions are required. In that case, use L{collect} to get
-        the whole RDD instead.
+        It works by first scanning one partition, and use the results from
+        that partition to estimate the number of additional partitions needed
+        to satisfy the limit.
+
+        Translated from the Scala implementation in RDD#take().
 
         >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
         [2, 3]
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
+        >>> sc.parallelize(range(100), 100).filter(lambda x: x > 90).take(3)
+        [91, 92, 93]
         """
-        def takeUpToNum(iterator):
-            taken = 0
-            while taken < num:
-                yield next(iterator)
-                taken += 1
-        # Take only up to num elements from each partition we try
-        mapped = self.mapPartitions(takeUpToNum)
         items = []
-        # TODO(shivaram): Similar to the scala implementation, update the take 
-        # method to scan multiple splits based on an estimate of how many elements 
-        # we have per-split.
-        with _JavaStackTrace(self.context) as st:
-            for partition in range(mapped._jrdd.splits().size()):
-                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-                partitionsToTake[0] = partition
-                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-                items.extend(mapped._collect_iterator_through_file(iterator))
-                if len(items) >= num:
-                    break
+        totalParts = self._jrdd.splits().size()
+        partsScanned = 0
+
+        while len(items) < num and partsScanned < totalParts:
+            # The number of partitions to try in this iteration.
+            # It is ok for this number to be greater than totalParts because
+            # we actually cap it at totalParts in runJob.
+            numPartsToTry = 1
+            if partsScanned > 0:
+                # If we didn't find any rows after the first iteration, just
+                # try all partitions next. Otherwise, interpolate the number
+                # of partitions we need to try, but overestimate it by 50%.
+                if len(items) == 0:
+                    numPartsToTry = totalParts - 1
+                else:
+                    numPartsToTry = int(1.5 * num * partsScanned / len(items))
+
+            left = num - len(items)
+
+            def takeUpToNumLeft(iterator):
+                taken = 0
+                while taken < left:
+                    yield next(iterator)
+                    taken += 1
+
+            p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
+            res = self.context.runJob(self, takeUpToNumLeft, p, True)
+
+            items += res
+            partsScanned += numPartsToTry
+
         return items[:num]
 
     def first(self):

Reply via email to