Posted to commits@spark.apache.org by gu...@apache.org on 2020/06/01 00:47:38 UTC

[spark] branch branch-3.0 updated: [SPARK-31788][CORE][DSTREAM][PYTHON] Recover the support of union for different types of RDD and DStreams

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new bd7f5da  [SPARK-31788][CORE][DSTREAM][PYTHON] Recover the support of union for different types of RDD and DStreams
bd7f5da is described below

commit bd7f5da3dfa0ce3edda0c9864cd0f89db744277f
Author: HyukjinKwon <gu...@apache.org>
AuthorDate: Mon Jun 1 09:43:03 2020 +0900

    [SPARK-31788][CORE][DSTREAM][PYTHON] Recover the support of union for different types of RDD and DStreams
    
    ### What changes were proposed in this pull request?
    
    This PR explicitly specifies the element class of the input array used in `(SparkContext|StreamingContext).union`, fixing a regression introduced by SPARK-25737. For example, the following snippet:
    
    ```python
    rdd1 = sc.parallelize([1, 2, 3, 4, 5])
    rdd2 = sc.parallelize([6, 7, 8, 9, 10])
    pairRDD1 = rdd1.zip(rdd2)
    sc.union([pairRDD1, pairRDD1]).collect()
    ```
    
    fails in the current master and `branch-3.0` with:
    
    ```
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/.../spark/python/pyspark/context.py", line 870, in union
        jrdds[i] = rdds[i]._jrdd
      File "/.../spark/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 238, in __setitem__
      File "/.../spark/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 221, in __set_item
      File "/.../spark/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py", line 332, in get_return_value
    py4j.protocol.Py4JError: An error occurred while calling None.None. Trace:
    py4j.Py4JException: Cannot convert org.apache.spark.api.java.JavaPairRDD to org.apache.spark.api.java.JavaRDD
    	at py4j.commands.ArrayCommand.convertArgument(ArrayCommand.java:166)
    	at py4j.commands.ArrayCommand.setArray(ArrayCommand.java:144)
    	at py4j.commands.ArrayCommand.execute(ArrayCommand.java:97)
    	at py4j.GatewayConnection.run(GatewayConnection.java:238)
    	at java.lang.Thread.run(Thread.java:748)
    ```
    
    whereas it works in Spark 2.4.5 and returns:
    
    ```
    [(1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10)]
    ```
    
    The previous code assumed that the element class of the input array was always `JavaRDD` or `JavaDStream`; however, it can be a different class, such as `JavaPairRDD`.
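    
    A condensed sketch of the approach taken here (the helper name `_resolve_array_class` is illustrative and not part of the patch; `gw` and `jvm` stand for the active py4j gateway and JVM view):
    
    ```python
    from py4j.java_gateway import is_instance_of
    
    def _resolve_array_class(gw, jvm, jrdd):
        # Probe the concrete Java class of the first RDD so that the
        # py4j array is typed correctly for each RDD flavour.
        java_pkg = jvm.org.apache.spark.api.java
        for cls in (java_pkg.JavaRDD, java_pkg.JavaPairRDD,
                    java_pkg.JavaDoubleRDD):
            if is_instance_of(gw, jrdd, cls):
                return cls
        raise TypeError("Unsupported Java RDD class %s"
                        % jrdd.getClass().getCanonicalName())
    ```
    
    `gw.new_array(cls, len(rdds))` then allocates an array of the matching element type, so assigning `JavaPairRDD` instances no longer trips py4j's conversion check.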
    
    This fix is based on redsanket's initial approach, and the commit will be co-authored.
    
    ### Why are the changes needed?
    
    To fix a regression from Spark 2.4.5.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No; the regression exists only in unreleased branches, and this change fixes it.
    
    ### How was this patch tested?
    
    Manually tested, and a unit test was added.
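    
    As a quick manual check (mirroring the new `test_union_pair_rdd` test in the diff below), the original reproduction now succeeds on a patched build:
    
    ```python
    rdd = sc.parallelize([1, 2])
    pair_rdd = rdd.zip(rdd)  # backed by a JavaPairRDD on the JVM side
    sc.union([pair_rdd, pair_rdd]).collect()
    # [(1, 1), (2, 2), (1, 1), (2, 2)]
    ```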
    
    Closes #28648 from HyukjinKwon/SPARK-31788.
    
    Authored-by: HyukjinKwon <gu...@apache.org>
    Signed-off-by: HyukjinKwon <gu...@apache.org>
    (cherry picked from commit 29c51d682b3735123f78cf9cb8610522a9bb86fd)
    Signed-off-by: HyukjinKwon <gu...@apache.org>
---
 python/pyspark/context.py           | 18 ++++++++++++++++--
 python/pyspark/streaming/context.py | 15 ++++++++++++---
 python/pyspark/tests/test_rdd.py    | 11 +++++++++++
 3 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index d5f1506..81b6caa 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -25,6 +25,7 @@ from threading import RLock
 from tempfile import NamedTemporaryFile
 
 from py4j.protocol import Py4JError
+from py4j.java_gateway import is_instance_of
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -864,8 +865,21 @@ class SparkContext(object):
         first_jrdd_deserializer = rdds[0]._jrdd_deserializer
         if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
             rdds = [x._reserialize() for x in rdds]
-        cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
-        jrdds = SparkContext._gateway.new_array(cls, len(rdds))
+        gw = SparkContext._gateway
+        jvm = SparkContext._jvm
+        jrdd_cls = jvm.org.apache.spark.api.java.JavaRDD
+        jpair_rdd_cls = jvm.org.apache.spark.api.java.JavaPairRDD
+        jdouble_rdd_cls = jvm.org.apache.spark.api.java.JavaDoubleRDD
+        if is_instance_of(gw, rdds[0]._jrdd, jrdd_cls):
+            cls = jrdd_cls
+        elif is_instance_of(gw, rdds[0]._jrdd, jpair_rdd_cls):
+            cls = jpair_rdd_cls
+        elif is_instance_of(gw, rdds[0]._jrdd, jdouble_rdd_cls):
+            cls = jdouble_rdd_cls
+        else:
+            cls_name = rdds[0]._jrdd.getClass().getCanonicalName()
+            raise TypeError("Unsupported Java RDD class %s" % cls_name)
+        jrdds = gw.new_array(cls, len(rdds))
         for i in range(0, len(rdds)):
             jrdds[i] = rdds[i]._jrdd
         return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 769121c..6199611 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -17,7 +17,7 @@
 
 from __future__ import print_function
 
-from py4j.java_gateway import java_import
+from py4j.java_gateway import java_import, is_instance_of
 
 from pyspark import RDD, SparkConf
 from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer
@@ -341,8 +341,17 @@ class StreamingContext(object):
             raise ValueError("All DStreams should have same serializer")
         if len(set(s._slideDuration for s in dstreams)) > 1:
             raise ValueError("All DStreams should have same slide duration")
-        cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
-        jdstreams = SparkContext._gateway.new_array(cls, len(dstreams))
+        jdstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
+        jpair_dstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaPairDStream
+        gw = SparkContext._gateway
+        if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):
+            cls = jdstream_cls
+        elif is_instance_of(gw, dstreams[0]._jdstream, jpair_dstream_cls):
+            cls = jpair_dstream_cls
+        else:
+            cls_name = dstreams[0]._jdstream.getClass().getCanonicalName()
+            raise TypeError("Unsupported Java DStream class %s" % cls_name)
+        jdstreams = gw.new_array(cls, len(dstreams))
         for i in range(0, len(dstreams)):
             jdstreams[i] = dstreams[i]._jdstream
         return DStream(self._jssc.union(jdstreams), self, dstreams[0]._jrdd_deserializer)
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index e2d910c..cf58220 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -166,6 +166,17 @@ class RDDTests(ReusedPySparkTestCase):
             set([(x, (x, x)) for x in 'abc'])
         )
 
+    def test_union_pair_rdd(self):
+        # SPARK-31788: test if pair RDDs can be combined by union.
+        rdd = self.sc.parallelize([1, 2])
+        pair_rdd = rdd.zip(rdd)
+        unionRDD = self.sc.union([pair_rdd, pair_rdd])
+        self.assertEqual(
+            set(unionRDD.collect()),
+            set([(1, 1), (2, 2), (1, 1), (2, 2)])
+        )
+        self.assertEqual(unionRDD.count(), 4)
+
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
         tempFile = tempfile.NamedTemporaryFile(delete=False)

