Posted to commits@spark.apache.org by jo...@apache.org on 2014/09/19 03:11:54 UTC

git commit: [SPARK-3554] [PySpark] use broadcast automatically for large closure

Repository: spark
Updated Branches:
  refs/heads/master 9306297d1 -> e77fa81a6


[SPARK-3554] [PySpark] use broadcast automatically for large closure

Py4J cannot handle large strings efficiently, so we should automatically use broadcast for large closures. (Broadcast uses the local filesystem to pass the data through.)
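In both code paths below the guard is the same: pickle the command tuple, and if the payload is larger than 1 MB, store it in a broadcast variable and send only the small pickled Broadcast handle through Py4J. A rough driver-side sketch of that logic (ship_command is a hypothetical helper, plain pickle stands in for PySpark's CloudPickleSerializer, and sc is assumed to be a live SparkContext):

    import pickle

    ONE_MB = 1 << 20  # same threshold as the patch: 1 << 20 bytes

    def ship_command(sc, command):
        # hypothetical helper sketching the driver-side guard added below;
        # plain pickle stands in for PySpark's CloudPickleSerializer
        payload = pickle.dumps(command)
        if len(payload) > ONE_MB:
            # too big to hand to Py4J as a string: store the payload in a
            # broadcast variable and send only the small pickled handle
            handle = sc.broadcast(payload)
            payload = pickle.dumps(handle)
        return payload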

Author: Davies Liu <da...@gmail.com>

Closes #2417 from davies/command and squashes the following commits:

fbf4e97 [Davies Liu] bugfix
aefd508 [Davies Liu] use broadcast automatically for large closure


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e77fa81a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e77fa81a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e77fa81a

Branch: refs/heads/master
Commit: e77fa81a61798c89d5a9b6c9dc067d11785254b7
Parents: 9306297
Author: Davies Liu <da...@gmail.com>
Authored: Thu Sep 18 18:11:48 2014 -0700
Committer: Josh Rosen <jo...@apache.org>
Committed: Thu Sep 18 18:11:48 2014 -0700

----------------------------------------------------------------------
 python/pyspark/rdd.py    | 4 ++++
 python/pyspark/sql.py    | 8 ++++++--
 python/pyspark/tests.py  | 6 ++++++
 python/pyspark/worker.py | 4 +++-
 4 files changed, 19 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e77fa81a/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index cb09c19..b43606b 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2061,8 +2061,12 @@ class PipelinedRDD(RDD):
             self._jrdd_deserializer = NoOpSerializer()
         command = (self.func, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
+        # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
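For context, any transformation whose closure captures a large driver-side object now takes this branch automatically; a hedged usage example (sc is assumed to be a live SparkContext, and the dictionary size is only approximate):

    # pickles to well over the 1 << 20 byte threshold
    big_lookup = dict((i, str(i) * 20) for i in xrange(100000))

    # the pickled closure for this map() carries big_lookup, so PipelinedRDD
    # wraps it in a broadcast instead of pushing it through Py4J
    counts = sc.parallelize(range(10)).map(lambda x: len(big_lookup)).collect()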

http://git-wip-us.apache.org/repos/asf/spark/blob/e77fa81a/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 8f6dbab..42a9920 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -27,7 +27,7 @@ import warnings
 from array import array
 from operator import itemgetter
 
-from pyspark.rdd import RDD, PipelinedRDD
+from pyspark.rdd import RDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
@@ -975,7 +975,11 @@ class SQLContext(object):
         command = (func,
                    BatchedSerializer(PickleSerializer(), 1024),
                    BatchedSerializer(PickleSerializer(), 1024))
-        pickled_command = CloudPickleSerializer().dumps(command)
+        ser = CloudPickleSerializer()
+        pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self._sc.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
             self._sc._gateway._gateway_client)

http://git-wip-us.apache.org/repos/asf/spark/blob/e77fa81a/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 0b38543..7301966 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -434,6 +434,12 @@ class TestRDDFunctions(PySparkTestCase):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEquals(N, m)
 
+    def test_large_closure(self):
+        N = 1000000
+        data = [float(i) for i in xrange(N)]
+        m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
+        self.assertEquals(N, m)
+
     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
         b = self.sc.parallelize(range(100, 105))
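For reference, the closure captured by the lambda in test_large_closure carries the full one-million-float list, which pickles to several megabytes, well past the 1 << 20 byte cutoff, so the test exercises the new broadcast path end to end.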

http://git-wip-us.apache.org/repos/asf/spark/blob/e77fa81a/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 252176a..d6c06e2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -77,10 +77,12 @@ def main(infile, outfile):
                 _broadcastRegistry[bid] = Broadcast(bid, value)
             else:
                 bid = - bid - 1
-                _broadcastRegistry.remove(bid)
+                _broadcastRegistry.pop(bid)
 
         _accumulatorRegistry.clear()
         command = pickleSer._read_with_length(infile)
+        if isinstance(command, Broadcast):
+            command = pickleSer.loads(command.value)
         (func, deserializer, serializer) = command
         init_time = time.time()
         iterator = deserializer.load_stream(infile)
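On the worker the round trip is symmetric: the value read from the control stream may now be a pickled Broadcast handle rather than the command tuple itself, in which case the real command is loaded from the broadcast's value. A rough standalone sketch of that unwrap step, using plain pickle and a stand-in Broadcast class instead of PySpark's serializers:

    import pickle

    class Broadcast(object):
        # stand-in for pyspark.broadcast.Broadcast; value holds the payload
        def __init__(self, value):
            self.value = value

    def read_command(raw_bytes):
        command = pickle.loads(raw_bytes)
        if isinstance(command, Broadcast):
            # the driver shipped the real command inside a broadcast variable,
            # so load the actual (func, deserializer, serializer) tuple from it
            command = pickle.loads(command.value)
        return command

    # round trip: a command wrapped in a Broadcast unpickles back to the tuple
    wrapped = pickle.dumps(Broadcast(pickle.dumps(("func", "deser", "ser"))))
    assert read_command(wrapped) == ("func", "deser", "ser")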


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org