Repository: spark
Updated Branches:
  refs/heads/master d484ddeff -> 6a224c31e


SPARK-1868: Users should be allowed to cogroup at least 4 RDDs

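Adds cogroup/groupWith overloads that take three other RDDs (four RDDs in total) to the Scala and Java pair-RDD APIs, and makes PySpark's groupWith accept any number of RDDs.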

Author: Allan Douglas R. de Oliveira <[email protected]>

Closes #813 from douglaz/more_cogroups and squashes the following commits:

f8d6273 [Allan Douglas R. de Oliveira] Test python groupWith for one more case
0e9009c [Allan Douglas R. de Oliveira] Added scala tests
c3ffcdd [Allan Douglas R. de Oliveira] Added java tests
517a67f [Allan Douglas R. de Oliveira] Added tests for python groupWith
2f402d5 [Allan Douglas R. de Oliveira] Removed TODO
17474f4 [Allan Douglas R. de Oliveira] Use new cogroup function
7877a2a [Allan Douglas R. de Oliveira] Fixed code
ba02414 [Allan Douglas R. de Oliveira] Added varargs cogroup to pyspark
c4a8a51 [Allan Douglas R. de Oliveira] Added java cogroup 4
e94963c [Allan Douglas R. de Oliveira] Fixed spacing
f1ee57b [Allan Douglas R. de Oliveira] Fixed scala style issues
d7196f1 [Allan Douglas R. de Oliveira] Allow the cogroup of 4 RDDs


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6a224c31
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6a224c31
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6a224c31

Branch: refs/heads/master
Commit: 6a224c31e8563156ad5732a23667e73076984ae1
Parents: d484dde
Author: Allan Douglas R. de Oliveira <[email protected]>
Authored: Fri Jun 20 11:03:03 2014 -0700
Committer: Patrick Wendell <[email protected]>
Committed: Fri Jun 20 11:03:03 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/api/java/JavaPairRDD.scala | 51 ++++++++++++++++
 .../org/apache/spark/rdd/PairRDDFunctions.scala | 51 ++++++++++++++++
 .../java/org/apache/spark/JavaAPISuite.java     | 63 ++++++++++++++++++++
 .../spark/rdd/PairRDDFunctionsSuite.scala       | 33 ++++++++++
 python/pyspark/join.py                          | 20 +++----
 python/pyspark/rdd.py                           | 22 ++++---
 6 files changed, 223 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6a224c31/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 14fa9d8..4f30814 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -544,6 +544,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
     fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner)))
 
   /**
+   * For each key k in `this` or `other1` or `other2` or `other3`,
+   * return a resulting RDD that contains a tuple with the list of values
+   * for that key in `this`, `other1`, `other2` and `other3`.
+   */
+  def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
+      other2: JavaPairRDD[K, W2],
+      other3: JavaPairRDD[K, W3],
+      partitioner: Partitioner)
+  : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
+    fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, partitioner)))
+
+  /**
    * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
    * list of values for that key in `this` as well as `other`.
    */
@@ -559,6 +571,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
     fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2)))
 
   /**
+   * For each key k in `this` or `other1` or `other2` or `other3`,
+   * return a resulting RDD that contains a tuple with the list of values
+   * for that key in `this`, `other1`, `other2` and `other3`.
+   */
+  def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
+      other2: JavaPairRDD[K, W2],
+      other3: JavaPairRDD[K, W3])
+  : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
+    fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3)))
+
+  /**
    * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
    * list of values for that key in `this` as well as `other`.
    */
@@ -574,6 +597,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
     fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions)))
 
+  /**
+   * For each key k in `this` or `other1` or `other2` or `other3`,
+   * return a resulting RDD that contains a tuple with the list of values
+   * for that key in `this`, `other1`, `other2` and `other3`.
+   */
+  def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
+      other2: JavaPairRDD[K, W2],
+      other3: JavaPairRDD[K, W3],
+      numPartitions: Int)
+  : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
+    fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, numPartitions)))
+
   /** Alias for cogroup. */
   def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] =
     fromRDD(cogroupResultToJava(rdd.groupWith(other)))
@@ -583,6 +618,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
     fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2)))
 
+  /** Alias for cogroup. */
+  def groupWith[W1, W2, W3](other1: JavaPairRDD[K, W1],
+      other2: JavaPairRDD[K, W2],
+      other3: JavaPairRDD[K, W3])
+  : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
+    fromRDD(cogroupResult3ToJava(rdd.groupWith(other1, other2, other3)))
+
   /**
    * Return the list of values in the RDD for key `key`. This operation is done efficiently if the
    * RDD has a known partitioner by only searching the partition that the key maps to.
@@ -786,6 +828,15 @@ object JavaPairRDD {
       .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3)))
   }
 
+  private[spark]
+  def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3](
+      rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))])
+  : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = {
+    rddToPairRDDFunctions(rdd)
+      .mapValues(x =>
+        (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4)))
+  }
+
   def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = {
     new JavaPairRDD[K, V](rdd)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/6a224c31/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index fe36c80..443d1c5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -568,6 +568,28 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
   }
 
   /**
+   * For each key k in `this` or `other1` or `other2` or `other3`,
+   * return a resulting RDD that contains a tuple with the list of values
+   * for that key in `this`, `other1`, `other2` and `other3`.
+   */
+  def cogroup[W1, W2, W3](other1: RDD[(K, W1)],
+      other2: RDD[(K, W2)],
+      other3: RDD[(K, W3)],
+      partitioner: Partitioner)
+      : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
+    if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
+      throw new SparkException("Default partitioner cannot partition array keys.")
+    }
+    val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner)
+    cg.mapValues { case Seq(vs, w1s, w2s, w3s) =>
+      (vs.asInstanceOf[Seq[V]],
+        w1s.asInstanceOf[Seq[W1]],
+        w2s.asInstanceOf[Seq[W2]],
+        w3s.asInstanceOf[Seq[W3]])
+    }
+  }
+
+  /**
    * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
    * list of values for that key in `this` as well as `other`.
    */
@@ -600,6 +622,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
   }
 
   /**
+   * For each key k in `this` or `other1` or `other2` or `other3`,
+   * return a resulting RDD that contains a tuple with the list of values
+   * for that key in `this`, `other1`, `other2` and `other3`.
+   */
+  def cogroup[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)])
+      : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
+    cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3))
+  }
+
+  /**
    * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
    * list of values for that key in `this` as well as `other`.
    */
@@ -633,6 +665,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     cogroup(other1, other2, new HashPartitioner(numPartitions))
   }
 
+  /**
+   * For each key k in `this` or `other1` or `other2` or `other3`,
+   * return a resulting RDD that contains a tuple with the list of values
+   * for that key in `this`, `other1`, `other2` and `other3`.
+   */
+  def cogroup[W1, W2, W3](other1: RDD[(K, W1)],
+      other2: RDD[(K, W2)],
+      other3: RDD[(K, W3)],
+      numPartitions: Int)
+      : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
+    cogroup(other1, other2, other3, new HashPartitioner(numPartitions))
+  }
+
   /** Alias for cogroup. */
   def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = {
     cogroup(other, defaultPartitioner(self, other))
@@ -644,6 +689,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     cogroup(other1, other2, defaultPartitioner(self, other1, other2))
   }
 
+  /** Alias for cogroup. */
+  def groupWith[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)])
+      : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
+    cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3))
+  }
+
   /**
    * Return an RDD with the pairs from `this` whose keys are not in `other`.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/6a224c31/core/src/test/java/org/apache/spark/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e46298c..761f2d6 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -21,6 +21,9 @@ import java.io.*;
 import java.util.*;
 
 import scala.Tuple2;
+import scala.Tuple3;
+import scala.Tuple4;
+
 
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Iterators;
@@ -306,6 +309,66 @@ public class JavaAPISuite implements Serializable {
 
   @SuppressWarnings("unchecked")
   @Test
+  public void cogroup3() {
+    JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, String>("Apples", "Fruit"),
+      new Tuple2<String, String>("Oranges", "Fruit"),
+      new Tuple2<String, String>("Oranges", "Citrus")
+      ));
+    JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, Integer>("Oranges", 2),
+      new Tuple2<String, Integer>("Apples", 3)
+    ));
+    JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, Integer>("Oranges", 21),
+      new Tuple2<String, Integer>("Apples", 42)
+    ));
+
+    JavaPairRDD<String, Tuple3<Iterable<String>, Iterable<Integer>, Iterable<Integer>>> cogrouped =
+        categories.cogroup(prices, quantities);
+    Assert.assertEquals("[Fruit, Citrus]",
+                        Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
+    Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
+    Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
+
+
+    cogrouped.collect();
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  public void cogroup4() {
+    JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, String>("Apples", "Fruit"),
+      new Tuple2<String, String>("Oranges", "Fruit"),
+      new Tuple2<String, String>("Oranges", "Citrus")
+      ));
+    JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, Integer>("Oranges", 2),
+      new Tuple2<String, Integer>("Apples", 3)
+    ));
+    JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, Integer>("Oranges", 21),
+      new Tuple2<String, Integer>("Apples", 42)
+    ));
+    JavaPairRDD<String, String> countries = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, String>("Oranges", "BR"),
+      new Tuple2<String, String>("Apples", "US")
+    ));
+
+    JavaPairRDD<String, Tuple4<Iterable<String>, Iterable<Integer>, Iterable<Integer>, Iterable<String>>> cogrouped =
+        categories.cogroup(prices, quantities, countries);
+    Assert.assertEquals("[Fruit, Citrus]",
+                        Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
+    Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
+    Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
+    Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4()));
+
+    cogrouped.collect();
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
   public void leftOuterJoin() {
     JavaPairRDD<Integer, Integer> rdd1 = sc.parallelizePairs(Arrays.asList(
       new Tuple2<Integer, Integer>(1, 1),

http://git-wip-us.apache.org/repos/asf/spark/blob/6a224c31/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 0b90044..447e38e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -249,6 +249,39 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     ))
   }
 
+  test("groupWith3") {
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+    val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+    val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd')))
+    val joined = rdd1.groupWith(rdd2, rdd3).collect()
+    assert(joined.size === 4)
+    val joinedSet = joined.map(x => (x._1,
+      (x._2._1.toList, x._2._2.toList, x._2._3.toList))).toSet
+    assert(joinedSet === Set(
+      (1, (List(1, 2), List('x'), List('a'))),
+      (2, (List(1), List('y', 'z'), List())),
+      (3, (List(1), List(), List('b'))),
+      (4, (List(), List('w'), List('c', 'd')))
+    ))
+  }
+
+  test("groupWith4") {
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+    val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+    val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd')))
+    val rdd4 = sc.parallelize(Array((2, '@')))
+    val joined = rdd1.groupWith(rdd2, rdd3, rdd4).collect()
+    assert(joined.size === 4)
+    val joinedSet = joined.map(x => (x._1,
+      (x._2._1.toList, x._2._2.toList, x._2._3.toList, x._2._4.toList))).toSet
+    assert(joinedSet === Set(
+      (1, (List(1, 2), List('x'), List('a'), List())),
+      (2, (List(1), List('y', 'z'), List(), List('@'))),
+      (3, (List(1), List(), List('b'), List())),
+      (4, (List(), List('w'), List('c', 'd'), List()))
+    ))
+  }
+
   test("zero-partition RDD") {
     val emptyDir = Files.createTempDir()
     emptyDir.deleteOnExit()

http://git-wip-us.apache.org/repos/asf/spark/blob/6a224c31/python/pyspark/join.py
----------------------------------------------------------------------
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index 6f94d26..5f3a7e7 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -79,15 +79,15 @@ def python_left_outer_join(rdd, other, numPartitions):
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
-def python_cogroup(rdd, other, numPartitions):
-    vs = rdd.map(lambda (k, v): (k, (1, v)))
-    ws = other.map(lambda (k, v): (k, (2, v)))
+def python_cogroup(rdds, numPartitions):
+    def make_mapper(i):  # factory: each lambda captures its own index (no late binding)
+        return lambda (k, v): (k, (i, v))
+    vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)]
+    union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
+    rdd_len = len(vrdds)
     def dispatch(seq):
-        vbuf, wbuf = [], []
+        bufs = [[] for _ in range(rdd_len)]  # one buffer per input RDD, indexed by tag n
         for (n, v) in seq:
-            if n == 1:
-                vbuf.append(v)
-            elif n == 2:
-                wbuf.append(v)
-        return (ResultIterable(vbuf), ResultIterable(wbuf))
-    return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch)
+            bufs[n].append(v)
+        return tuple(map(ResultIterable, bufs))
+    return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)

http://git-wip-us.apache.org/repos/asf/spark/blob/6a224c31/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 62a95c8..1d55c35 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1233,7 +1233,7 @@ class RDD(object):
                     combiners[k] = mergeCombiners(combiners[k], v)
             return combiners.iteritems()
         return shuffled.mapPartitions(_mergeCombiners)
-   
+
     def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
         """
         Aggregate the values of each key, using given combine functions and a neutral "zero value".
@@ -1245,7 +1245,7 @@ class RDD(object):
         """
         def createZero():
           return copy.deepcopy(zeroValue)
-        
+
         return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)
 
     def foldByKey(self, zeroValue, func, numPartitions=None):
@@ -1323,12 +1323,20 @@ class RDD(object):
         map_values_fn = lambda (k, v): (k, f(v))
         return self.map(map_values_fn, preservesPartitioning=True)
 
-    # TODO: support varargs cogroup of several RDDs.
-    def groupWith(self, other):
+    def groupWith(self, other, *others):
         """
-        Alias for cogroup.
+        Alias for cogroup but with support for multiple RDDs.
+
+        >>> w = sc.parallelize([("a", 5), ("b", 6)])
+        >>> x = sc.parallelize([("a", 1), ("b", 4)])
+        >>> y = sc.parallelize([("a", 2)])
+        >>> z = sc.parallelize([("b", 42)])
+        >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \
+                sorted(list(w.groupWith(x, y, z).collect())))
+        [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))]
+
         """
-        return self.cogroup(other)
+        return python_cogroup((self, other) + others, numPartitions=None)
 
     # TODO: add variant with custom parittioner
     def cogroup(self, other, numPartitions=None):
@@ -1342,7 +1350,7 @@ class RDD(object):
         >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect())))
         [('a', ([1], [2])), ('b', ([4], []))]
         """
-        return python_cogroup(self, other, numPartitions)
+        return python_cogroup((self, other), numPartitions)
 
     def subtractByKey(self, other, numPartitions=None):
         """

Reply via email to