You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2020/06/01 00:47:38 UTC
[spark] branch branch-3.0 updated:
[SPARK-31788][CORE][DSTREAM][PYTHON] Recover the support of union for
different types of RDD and DStreams
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new bd7f5da [SPARK-31788][CORE][DSTREAM][PYTHON] Recover the support of union for different types of RDD and DStreams
bd7f5da is described below
commit bd7f5da3dfa0ce3edda0c9864cd0f89db744277f
Author: HyukjinKwon <gu...@apache.org>
AuthorDate: Mon Jun 1 09:43:03 2020 +0900
[SPARK-31788][CORE][DSTREAM][PYTHON] Recover the support of union for different types of RDD and DStreams
### What changes were proposed in this pull request?
This PR manually specifies the class for the input array being used in `(SparkContext|StreamingContext).union`. It fixes a regression introduced by SPARK-25737.
```python
rdd1 = sc.parallelize([1,2,3,4,5])
rdd2 = sc.parallelize([6,7,8,9,10])
pairRDD1 = rdd1.zip(rdd2)
sc.union([pairRDD1, pairRDD1]).collect()
```
The snippet above fails in the current master and `branch-3.0` with:
```
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/.../spark/python/pyspark/context.py", line 870, in union
jrdds[i] = rdds[i]._jrdd
File "/.../spark/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 238, in __setitem__
File "/.../spark/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 221, in __set_item
File "/.../spark/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py", line 332, in get_return_value
py4j.protocol.Py4JError: An error occurred while calling None.None. Trace:
py4j.Py4JException: Cannot convert org.apache.spark.api.java.JavaPairRDD to org.apache.spark.api.java.JavaRDD
at py4j.commands.ArrayCommand.convertArgument(ArrayCommand.java:166)
at py4j.commands.ArrayCommand.setArray(ArrayCommand.java:144)
at py4j.commands.ArrayCommand.execute(ArrayCommand.java:97)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
```
whereas the same snippet works in Spark 2.4.5 and returns:
```
[(1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10)]
```
The previous code assumed that the class of the input array is always `JavaRDD` or `JavaDStream`; however, it can be a different class, such as `JavaPairRDD`.
This fix is based on redsanket's initial approach, and will be co-authored.
### Why are the changes needed?
To fix a regression from Spark 2.4.5.
### Does this PR introduce _any_ user-facing change?
No, it's only in unreleased branches. This is to fix a regression.
### How was this patch tested?
Manually tested, and a unittest was added.
Closes #28648 from HyukjinKwon/SPARK-31788.
Authored-by: HyukjinKwon <gu...@apache.org>
Signed-off-by: HyukjinKwon <gu...@apache.org>
(cherry picked from commit 29c51d682b3735123f78cf9cb8610522a9bb86fd)
Signed-off-by: HyukjinKwon <gu...@apache.org>
---
python/pyspark/context.py | 18 ++++++++++++++++--
python/pyspark/streaming/context.py | 15 ++++++++++++---
python/pyspark/tests/test_rdd.py | 11 +++++++++++
3 files changed, 39 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index d5f1506..81b6caa 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -25,6 +25,7 @@ from threading import RLock
from tempfile import NamedTemporaryFile
from py4j.protocol import Py4JError
+from py4j.java_gateway import is_instance_of
from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -864,8 +865,21 @@ class SparkContext(object):
first_jrdd_deserializer = rdds[0]._jrdd_deserializer
if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
rdds = [x._reserialize() for x in rdds]
- cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
- jrdds = SparkContext._gateway.new_array(cls, len(rdds))
+ gw = SparkContext._gateway
+ jvm = SparkContext._jvm
+ jrdd_cls = jvm.org.apache.spark.api.java.JavaRDD
+ jpair_rdd_cls = jvm.org.apache.spark.api.java.JavaPairRDD
+ jdouble_rdd_cls = jvm.org.apache.spark.api.java.JavaDoubleRDD
+ if is_instance_of(gw, rdds[0]._jrdd, jrdd_cls):
+ cls = jrdd_cls
+ elif is_instance_of(gw, rdds[0]._jrdd, jpair_rdd_cls):
+ cls = jpair_rdd_cls
+ elif is_instance_of(gw, rdds[0]._jrdd, jdouble_rdd_cls):
+ cls = jdouble_rdd_cls
+ else:
+ cls_name = rdds[0]._jrdd.getClass().getCanonicalName()
+ raise TypeError("Unsupported Java RDD class %s" % cls_name)
+ jrdds = gw.new_array(cls, len(rdds))
for i in range(0, len(rdds)):
jrdds[i] = rdds[i]._jrdd
return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 769121c..6199611 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -17,7 +17,7 @@
from __future__ import print_function
-from py4j.java_gateway import java_import
+from py4j.java_gateway import java_import, is_instance_of
from pyspark import RDD, SparkConf
from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer
@@ -341,8 +341,17 @@ class StreamingContext(object):
raise ValueError("All DStreams should have same serializer")
if len(set(s._slideDuration for s in dstreams)) > 1:
raise ValueError("All DStreams should have same slide duration")
- cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
- jdstreams = SparkContext._gateway.new_array(cls, len(dstreams))
+ jdstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaDStream
+ jpair_dstream_cls = SparkContext._jvm.org.apache.spark.streaming.api.java.JavaPairDStream
+ gw = SparkContext._gateway
+ if is_instance_of(gw, dstreams[0]._jdstream, jdstream_cls):
+ cls = jdstream_cls
+ elif is_instance_of(gw, dstreams[0]._jdstream, jpair_dstream_cls):
+ cls = jpair_dstream_cls
+ else:
+ cls_name = dstreams[0]._jdstream.getClass().getCanonicalName()
+ raise TypeError("Unsupported Java DStream class %s" % cls_name)
+ jdstreams = gw.new_array(cls, len(dstreams))
for i in range(0, len(dstreams)):
jdstreams[i] = dstreams[i]._jdstream
return DStream(self._jssc.union(jdstreams), self, dstreams[0]._jrdd_deserializer)
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index e2d910c..cf58220 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -166,6 +166,17 @@ class RDDTests(ReusedPySparkTestCase):
set([(x, (x, x)) for x in 'abc'])
)
+ def test_union_pair_rdd(self):
+ # SPARK-31788: test if pair RDDs can be combined by union.
+ rdd = self.sc.parallelize([1, 2])
+ pair_rdd = rdd.zip(rdd)
+ unionRDD = self.sc.union([pair_rdd, pair_rdd])
+ self.assertEqual(
+ set(unionRDD.collect()),
+ set([(1, 1), (2, 2), (1, 1), (2, 2)])
+ )
+ self.assertEqual(unionRDD.count(), 4)
+
def test_deleting_input_files(self):
# Regression test for SPARK-1025
tempFile = tempfile.NamedTemporaryFile(delete=False)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org