HyukjinKwon commented on a change in pull request #28648:
URL: https://github.com/apache/spark/pull/28648#discussion_r430825553



##########
File path: python/pyspark/streaming/context.py
##########
@@ -341,8 +341,17 @@ def union(self, *dstreams):
             raise ValueError("All DStreams should have same serializer")
         if len(set(s._slideDuration for s in dstreams)) > 1:
             raise ValueError("All DStreams should have same slide duration")
-        cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
-        jdstreams = SparkContext._gateway.new_array(cls, len(dstreams))
+        jdstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
+        jpair_dstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaPairDStream
+        gw = SparkContext._gateway
+        if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):
+            cls = jdstream_cls
+        elif is_instance_of(gw, dstreams[0]._jdstream, jpair_dstream_cls):

Review comment:
       I don't think `JavaPairDStream` is actually used on the PySpark side, if I am not mistaken. This is just to restore the previous behaviour and to be conservatively safe.

##########
File path: python/pyspark/context.py
##########
@@ -864,8 +865,21 @@ def union(self, rdds):
         first_jrdd_deserializer = rdds[0]._jrdd_deserializer
         if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
             rdds = [x._reserialize() for x in rdds]
-        cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
-        jrdds = SparkContext._gateway.new_array(cls, len(rdds))
+        gw = SparkContext._gateway
+        jvm = SparkContext._jvm
+        jrdd_cls = jvm.org.apache.spark.api.java.JavaRDD
+        jpair_jrdd_cls = jvm.org.apache.spark.api.java.JavaPairRDD
+        jdouble_jrdd_cls = jvm.org.apache.spark.api.java.JavaDoubleRDD

Review comment:
       Ideally we should do something just like:
   
   ```python
   cls = JavaClass(rdds[0]._jrdd.getClass().getCanonicalName(), SparkContext._gateway)
   jrdds = SparkContext._gateway.new_array(cls, len(rdds))
   ```

   so that we can forget about the signature matching on the Python side, but this doesn't seem to work on the Py4J side yet - Py4J doesn't seem to respect parent classes and inheritance, e.g., if we have a `JavaLongRDD` that extends `JavaRDD`, it doesn't work.
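
   Until Py4J supports that, a conservative workaround is to dispatch on the concrete class explicitly, mirroring the checks this diff introduces. Below is a minimal sketch of that idea; `_rdd_array_class` is a hypothetical helper name, and the `TypeError` fallback is an assumption, not necessarily what this PR ends up doing.

   ```python
   from py4j.java_gateway import is_instance_of

   def _rdd_array_class(gw, jvm, jrdd):
       # Map the first RDD's concrete JVM class onto the matching Py4J array
       # element type so SparkContext.union() resolves to the right overload.
       candidates = (jvm.org.apache.spark.api.java.JavaRDD,
                     jvm.org.apache.spark.api.java.JavaPairRDD,
                     jvm.org.apache.spark.api.java.JavaDoubleRDD)
       for cls in candidates:
           if is_instance_of(gw, jrdd, cls):
               return cls
       # Assumption: fail loudly instead of picking an arbitrary class.
       raise TypeError("Unsupported Java RDD class %s"
                       % jrdd.getClass().getCanonicalName())
   ```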



