SPARK-705: implement sortByKey() in PySpark

Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/fdbae41e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/fdbae41e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/fdbae41e

Branch: refs/heads/scala-2.10
Commit: fdbae41e88af0994e97ac8741f86547184a05de9
Parents: d585613
Author: Andre Schumacher <schum...@icsi.berkeley.edu>
Authored: Mon Oct 7 10:42:39 2013 -0700
Committer: Andre Schumacher <schum...@icsi.berkeley.edu>
Committed: Mon Oct 7 12:16:33 2013 -0700

----------------------------------------------------------------------
 python/pyspark/rdd.py | 48 +++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/fdbae41e/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 39c402b..7dfabb0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -263,7 +263,53 @@ class RDD(object):
             raise TypeError
         return self.union(other)
 
-    # TODO: sort
+    def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x):
+        """
+        Sorts this RDD, which is assumed to consist of (key, value) pairs.
+
+        >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
+        >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
+        >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)])
+        >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect()
+        [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)]
+        """
+        if numPartitions is None:
+            numPartitions = self.ctx.defaultParallelism
+
+        bounds = list()
+
+        # first compute the boundary of each part via sampling: we want to partition
+        # the key-space into bins such that the bins have roughly the same
+        # number of (key, value) pairs falling into them
+        if numPartitions > 1:
+            rddSize = self.count()
+            maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
+            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+
+            samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+            samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+            # we have numPartitions many parts but one of them has
+            # an implicit boundary
+            for i in range(0, numPartitions - 1):
+                index = (len(samples) - 1) * (i + 1) / numPartitions
+                bounds.append(samples[index])
+
+        def rangePartitionFunc(k):
+            p = 0
+            while p < len(bounds) and keyfunc(k) > bounds[p]:
+                p += 1
+            if ascending:
+                return p
+            else:
+                return numPartitions-1-p
+
+        def mapFunc(iterator):
+            yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+        return self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc).mapPartitions(mapFunc, preservesPartitioning=True).flatMap(lambda x: x, preservesPartitioning=True)
 
     def glom(self):
         """

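----------------------------------------------------------------------

For readers outside the diff context: the patch sorts by sampling the keys to
pick range-partition boundaries, routing every (key, value) pair to a bin by
comparing its key against those boundaries, and sorting each bin locally, so
that concatenating the bins in order gives a totally ordered result. Below is a
minimal standalone sketch of that idea in plain Python; it is not part of the
commit, and the helper name sort_by_key_sketch and its parameters are made up
for illustration.

    import random

    def sort_by_key_sketch(pairs, num_partitions=2, ascending=True,
                           keyfunc=lambda k: k):
        # sample some keys and pick num_partitions - 1 boundary values,
        # mirroring the maxSampleSize = numPartitions * 20 heuristic above
        sample_size = min(len(pairs), num_partitions * 20)
        samples = sorted((keyfunc(k) for k, _ in random.sample(pairs, sample_size)),
                         reverse=not ascending)
        bounds = []
        if num_partitions > 1 and samples:
            bounds = [samples[(len(samples) - 1) * (i + 1) // num_partitions]
                      for i in range(num_partitions - 1)]

        # route each (key, value) pair to its range partition
        bins = [[] for _ in range(num_partitions)]
        for k, v in pairs:
            p = 0
            while p < len(bounds) and keyfunc(k) > bounds[p]:
                p += 1
            bins[p if ascending else num_partitions - 1 - p].append((k, v))

        # sort within each partition; concatenating the bins is then a total sort
        out = []
        for b in bins:
            out.extend(sorted(b, key=lambda kv: keyfunc(kv[0]),
                              reverse=not ascending))
        return out

    print(sort_by_key_sketch([('d', 4), ('a', 1), ('c', 3), ('b', 2), ('e', 5)]))
    # [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)]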