Repository: spark
Updated Branches:
  refs/heads/branch-2.1 726217eb7 -> e0173f14e


[SPARK-16589] [PYTHON] Chained cartesian produces incorrect number of records

## What changes were proposed in this pull request?

Fixes a bug in the Python implementation of the RDD cartesian product related to 
batching that showed up in repeated cartesian products with seemingly random 
results. The root cause was multiple iterators pulling from the same stream 
in the wrong order because of logic that ignored batching.

`CartesianDeserializer` and `PairDeserializer` were changed to implement 
`_load_stream_without_unbatching` and borrow the one-line implementation of 
`load_stream` from `BatchedSerializer`. The default implementation of 
`_load_stream_without_unbatching` was changed to give consistent results 
(every item it yields is now an iterable batch, even for non-batching 
serializers) so that it could be used without additional checks.

`PairDeserializer` no longer extends `CartesianDeserializer`, as that 
inheritance was not really appropriate. If desired, a new common superclass 
could be added.

Both `CartesianDeserializer` and `PairDeserializer` now only extend 
`Serializer` (which has no `dump_stream` implementation) since they are only 
meant for *de*serialization.

## How was this patch tested?

Additional unit tests (sourced from #14248) plus one for testing a cartesian 
with zip.

Author: Andrew Ray <[email protected]>

Closes #16121 from aray/fix-cartesian.

(cherry picked from commit 3c68944b229aaaeeaee3efcbae3e3be9a2914855)
Signed-off-by: Davies Liu <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e0173f14
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e0173f14
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e0173f14

Branch: refs/heads/branch-2.1
Commit: e0173f14e3ea28d83c1c46bf97f7d3755960a8fc
Parents: 726217e
Author: Andrew Ray <[email protected]>
Authored: Thu Dec 8 11:08:12 2016 -0800
Committer: Davies Liu <[email protected]>
Committed: Thu Dec 8 11:08:27 2016 -0800

----------------------------------------------------------------------
 python/pyspark/serializers.py | 58 +++++++++++++++++++++++---------------
 python/pyspark/tests.py       | 18 ++++++++++++
 2 files changed, 53 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e0173f14/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 2a13269..c4f2f08 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -61,7 +61,7 @@ import itertools
 if sys.version < '3':
     import cPickle as pickle
     protocol = 2
-    from itertools import izip as zip
+    from itertools import izip as zip, imap as map
 else:
     import pickle
     protocol = 3
@@ -96,7 +96,12 @@ class Serializer(object):
         raise NotImplementedError
 
     def _load_stream_without_unbatching(self, stream):
-        return self.load_stream(stream)
+        """
+        Return an iterator of deserialized batches (lists) of objects from the 
input stream.
+        if the serializer does not operate on batches the default 
implementation returns an
+        iterator of single element lists.
+        """
+        return map(lambda x: [x], self.load_stream(stream))
 
     # Note: our notion of "equality" is that output generated by
     # equal serializers can be deserialized using the same serializer.
@@ -278,50 +283,57 @@ class AutoBatchedSerializer(BatchedSerializer):
         return "AutoBatchedSerializer(%s)" % self.serializer
 
 
-class CartesianDeserializer(FramedSerializer):
+class CartesianDeserializer(Serializer):
 
     """
     Deserializes the JavaRDD cartesian() of two PythonRDDs.
+    Due to pyspark batching we cannot simply use the result of the Java RDD 
cartesian,
+    we additionally need to do the cartesian within each pair of batches.
     """
 
     def __init__(self, key_ser, val_ser):
-        FramedSerializer.__init__(self)
         self.key_ser = key_ser
         self.val_ser = val_ser
 
-    def prepare_keys_values(self, stream):
-        key_stream = self.key_ser._load_stream_without_unbatching(stream)
-        val_stream = self.val_ser._load_stream_without_unbatching(stream)
-        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
-        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
-        for (keys, vals) in zip(key_stream, val_stream):
-            keys = keys if key_is_batched else [keys]
-            vals = vals if val_is_batched else [vals]
-            yield (keys, vals)
+    def _load_stream_without_unbatching(self, stream):
+        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
+        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+            # for correctness with repeated cartesian/zip this must be 
returned as one batch
+            yield product(key_batch, val_batch)
 
     def load_stream(self, stream):
-        for (keys, vals) in self.prepare_keys_values(stream):
-            for pair in product(keys, vals):
-                yield pair
+        return 
chain.from_iterable(self._load_stream_without_unbatching(stream))
 
     def __repr__(self):
         return "CartesianDeserializer(%s, %s)" % \
                (str(self.key_ser), str(self.val_ser))
 
 
-class PairDeserializer(CartesianDeserializer):
+class PairDeserializer(Serializer):
 
     """
     Deserializes the JavaRDD zip() of two PythonRDDs.
+    Due to pyspark batching we cannot simply use the result of the Java RDD 
zip,
+    we additionally need to do the zip within each pair of batches.
     """
 
+    def __init__(self, key_ser, val_ser):
+        self.key_ser = key_ser
+        self.val_ser = val_ser
+
+    def _load_stream_without_unbatching(self, stream):
+        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
+        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+            if len(key_batch) != len(val_batch):
+                raise ValueError("Can not deserialize PairRDD with different 
number of items"
+                                 " in batches: (%d, %d)" % (len(key_batch), 
len(val_batch)))
+            # for correctness with repeated cartesian/zip this must be 
returned as one batch
+            yield zip(key_batch, val_batch)
+
     def load_stream(self, stream):
-        for (keys, vals) in self.prepare_keys_values(stream):
-            if len(keys) != len(vals):
-                raise ValueError("Can not deserialize RDD with different 
number of items"
-                                 " in pair: (%d, %d)" % (len(keys), len(vals)))
-            for pair in zip(keys, vals):
-                yield pair
+        return 
chain.from_iterable(self._load_stream_without_unbatching(stream))
 
     def __repr__(self):
         return "PairDeserializer(%s, %s)" % (str(self.key_ser), 
str(self.val_ser))

http://git-wip-us.apache.org/repos/asf/spark/blob/e0173f14/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index ab4bef8..89fce8a 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -548,6 +548,24 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertEqual(u"Hello World!", x.strip())
         self.assertEqual(u"Hello World!", y.strip())
 
+    def test_cartesian_chaining(self):
+        # Tests for SPARK-16589
+        rdd = self.sc.parallelize(range(10), 2)
+        self.assertSetEqual(
+            set(rdd.cartesian(rdd).cartesian(rdd).collect()),
+            set([((x, y), z) for x in range(10) for y in range(10) for z in 
range(10)])
+        )
+
+        self.assertSetEqual(
+            set(rdd.cartesian(rdd.cartesian(rdd)).collect()),
+            set([(x, (y, z)) for x in range(10) for y in range(10) for z in 
range(10)])
+        )
+
+        self.assertSetEqual(
+            set(rdd.cartesian(rdd.zip(rdd)).collect()),
+            set([(x, (y, y)) for x in range(10) for y in range(10)])
+        )
+
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
         tempFile = tempfile.NamedTemporaryFile(delete=False)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to