This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new bd7f5da  [SPARK-31788][CORE][DSTREAM][PYTHON] Recover the support of union for different types of RDD and DStreams
bd7f5da is described below

commit bd7f5da3dfa0ce3edda0c9864cd0f89db744277f
Author: HyukjinKwon <gurwls...@apache.org>
AuthorDate: Mon Jun 1 09:43:03 2020 +0900

    [SPARK-31788][CORE][DSTREAM][PYTHON] Recover the support of union for different types of RDD and DStreams

    ### What changes were proposed in this pull request?

    This PR manually specifies the class for the input array being used in `(SparkContext|StreamingContext).union`. It fixes a regression introduced from SPARK-25737.

    ```python
    rdd1 = sc.parallelize([1, 2, 3, 4, 5])
    rdd2 = sc.parallelize([6, 7, 8, 9, 10])
    pairRDD1 = rdd1.zip(rdd2)
    sc.union([pairRDD1, pairRDD1]).collect()
    ```

    In the current master and `branch-3.0`, this throws:

    ```
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/.../spark/python/pyspark/context.py", line 870, in union
        jrdds[i] = rdds[i]._jrdd
      File "/.../spark/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 238, in __setitem__
      File "/.../spark/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 221, in __set_item
      File "/.../spark/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py", line 332, in get_return_value
    py4j.protocol.Py4JError: An error occurred while calling None.None. Trace:
    py4j.Py4JException: Cannot convert org.apache.spark.api.java.JavaPairRDD to org.apache.spark.api.java.JavaRDD
        at py4j.commands.ArrayCommand.convertArgument(ArrayCommand.java:166)
        at py4j.commands.ArrayCommand.setArray(ArrayCommand.java:144)
        at py4j.commands.ArrayCommand.execute(ArrayCommand.java:97)
        at py4j.GatewayConnection.run(GatewayConnection.java:238)
        at java.lang.Thread.run(Thread.java:748)
    ```

    whereas it works in Spark 2.4.5:

    ```
    [(1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10)]
    ```

    The code previously assumed the elements of the input array are always `JavaRDD` or `JavaDStream`; however, they can be of a different class, such as `JavaPairRDD`.

    This fix is based on redsanket's initial approach and will be co-authored.

    ### Why are the changes needed?

    To fix a regression from Spark 2.4.5.

    ### Does this PR introduce _any_ user-facing change?

    No, it's only in unreleased branches. This is to fix a regression.

    ### How was this patch tested?

    Manually tested, and a unit test was added.

    Closes #28648 from HyukjinKwon/SPARK-31788.
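The heart of the patch below is that a py4j typed array only accepts elements of its exact declared Java class, so `union` must resolve the element class from the first RDD instead of hard-coding `JavaRDD`. A minimal sketch of that dispatch, factored into a hypothetical helper `_resolve_java_rdd_class` (the actual patch inlines these checks):

```python
from py4j.java_gateway import is_instance_of


def _resolve_java_rdd_class(gw, jvm, jrdd):
    # A py4j array typed as JavaRDD rejects JavaPairRDD elements, so probe
    # the concrete Java class of the first element and build the array
    # with whichever class it actually is.
    for cls in [jvm.org.apache.spark.api.java.JavaRDD,
                jvm.org.apache.spark.api.java.JavaPairRDD,
                jvm.org.apache.spark.api.java.JavaDoubleRDD]:
        if is_instance_of(gw, jrdd, cls):
            return cls
    raise TypeError("Unsupported Java RDD class %s"
                    % jrdd.getClass().getCanonicalName())
```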
Authored-by: HyukjinKwon <gurwls...@apache.org>
Signed-off-by: HyukjinKwon <gurwls...@apache.org>
(cherry picked from commit 29c51d682b3735123f78cf9cb8610522a9bb86fd)
Signed-off-by: HyukjinKwon <gurwls...@apache.org>
---
 python/pyspark/context.py           | 18 ++++++++++++++++--
 python/pyspark/streaming/context.py | 15 ++++++++++++---
 python/pyspark/tests/test_rdd.py    | 11 +++++++++++
 3 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index d5f1506..81b6caa 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -25,6 +25,7 @@ from threading import RLock
 from tempfile import NamedTemporaryFile
 
 from py4j.protocol import Py4JError
+from py4j.java_gateway import is_instance_of
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -864,8 +865,21 @@ class SparkContext(object):
         first_jrdd_deserializer = rdds[0]._jrdd_deserializer
         if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
             rdds = [x._reserialize() for x in rdds]
-        cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
-        jrdds = SparkContext._gateway.new_array(cls, len(rdds))
+        gw = SparkContext._gateway
+        jvm = SparkContext._jvm
+        jrdd_cls = jvm.org.apache.spark.api.java.JavaRDD
+        jpair_rdd_cls = jvm.org.apache.spark.api.java.JavaPairRDD
+        jdouble_rdd_cls = jvm.org.apache.spark.api.java.JavaDoubleRDD
+        if is_instance_of(gw, rdds[0]._jrdd, jrdd_cls):
+            cls = jrdd_cls
+        elif is_instance_of(gw, rdds[0]._jrdd, jpair_rdd_cls):
+            cls = jpair_rdd_cls
+        elif is_instance_of(gw, rdds[0]._jrdd, jdouble_rdd_cls):
+            cls = jdouble_rdd_cls
+        else:
+            cls_name = rdds[0]._jrdd.getClass().getCanonicalName()
+            raise TypeError("Unsupported Java RDD class %s" % cls_name)
+        jrdds = gw.new_array(cls, len(rdds))
         for i in range(0, len(rdds)):
             jrdds[i] = rdds[i]._jrdd
         return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 769121c..6199611 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -17,7 +17,7 @@
 
 from __future__ import print_function
 
-from py4j.java_gateway import java_import
+from py4j.java_gateway import java_import, is_instance_of
 
 from pyspark import RDD, SparkConf
 from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer
@@ -341,8 +341,17 @@ class StreamingContext(object):
             raise ValueError("All DStreams should have same serializer")
         if len(set(s._slideDuration for s in dstreams)) > 1:
             raise ValueError("All DStreams should have same slide duration")
-        cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
-        jdstreams = SparkContext._gateway.new_array(cls, len(dstreams))
+        jdstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
+        jpair_dstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaPairDStream
+        gw = SparkContext._gateway
+        if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):
+            cls = jdstream_cls
+        elif is_instance_of(gw, dstreams[0]._jdstream, jpair_dstream_cls):
+            cls = jpair_dstream_cls
+        else:
+            cls_name = dstreams[0]._jdstream.getClass().getCanonicalName()
+            raise TypeError("Unsupported Java DStream class %s" % cls_name)
+        jdstreams = gw.new_array(cls, len(dstreams))
         for i in range(0, len(dstreams)):
             jdstreams[i] = dstreams[i]._jdstream
         return DStream(self._jssc.union(jdstreams), self, dstreams[0]._jrdd_deserializer)
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index e2d910c..cf58220 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -166,6 +166,17 @@ class RDDTests(ReusedPySparkTestCase):
             set([(x, (x, x)) for x in 'abc'])
         )
 
+    def test_union_pair_rdd(self):
+        # SPARK-31788: test if pair RDDs can be combined by union.
+        rdd = self.sc.parallelize([1, 2])
+        pair_rdd = rdd.zip(rdd)
+        unionRDD = self.sc.union([pair_rdd, pair_rdd])
+        self.assertEqual(
+            set(unionRDD.collect()),
+            set([(1, 1), (2, 2), (1, 1), (2, 2)])
+        )
+        self.assertEqual(unionRDD.count(), 4)
+
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
         tempFile = tempfile.NamedTemporaryFile(delete=False)
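With the patch applied, the reproducer from the description behaves as it did in Spark 2.4.5 again. A self-contained session for reference; this is a sketch that assumes a local Spark build containing this commit:

```python
from pyspark import SparkContext

sc = SparkContext("local[2]", "SPARK-31788-demo")

rdd1 = sc.parallelize([1, 2, 3, 4, 5])
rdd2 = sc.parallelize([6, 7, 8, 9, 10])
# zip() is backed by a JavaPairRDD on the JVM side, which is what
# previously broke the py4j array creation in SparkContext.union.
pairRDD1 = rdd1.zip(rdd2)

print(sc.union([pairRDD1, pairRDD1]).collect())
# [(1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10)]

sc.stop()
```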