Posted to commits@spark.apache.org by gu...@apache.org on 2017/08/03 01:36:07 UTC

spark git commit: [SPARK-12717][PYTHON][BRANCH-2.1] Adding thread-safe broadcast pickle registry

Repository: spark
Updated Branches:
  refs/heads/branch-2.1 b31b30209 -> d93e45b8b


[SPARK-12717][PYTHON][BRANCH-2.1] Adding thread-safe broadcast pickle registry

## What changes were proposed in this pull request?

When using PySpark broadcast variables in a multi-threaded environment, `SparkContext._pickled_broadcast_vars` becomes a shared resource. A race condition can occur when broadcast variables pickled in one thread are added to the shared `_pickled_broadcast_vars` and end up as part of the Python command built in another thread. This PR introduces a thread-safe pickle registry backed by thread-local storage, so that when a Python command is pickled (which pickles the broadcast variables it uses and adds them to the registry), each thread has its own view of the registry from which to retrieve and clear the broadcast variables it used.
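
For reference alongside the diff below, here is a minimal standalone sketch of the approach (strings stand in for `Broadcast` objects; the class mirrors the `BroadcastPickleRegistry` added to `python/pyspark/broadcast.py`). Because the registry subclasses `threading.local`, each thread gets its own underlying set:

```python
import threading

class BroadcastPickleRegistry(threading.local):
    """Thread-local registry; each thread sees only the entries it added."""
    def __init__(self):
        # threading.local runs the subclass __init__ in each thread on first
        # access, so every thread lazily gets its own backing set
        self.__dict__.setdefault("_registry", set())

    def add(self, bcast):
        self._registry.add(bcast)

    def clear(self):
        self._registry.clear()

    def __iter__(self):
        return iter(self._registry)

registry = BroadcastPickleRegistry()
registry.add("b1")  # stands in for pickling broadcast b1 in the main thread

def worker():
    registry.add("b2")       # pickled from a worker thread
    print(sorted(registry))  # ['b2'], only the worker's own entry
    registry.clear()

t = threading.Thread(target=worker)
t.start()
t.join()

print(sorted(registry))  # ['b1'], the main thread's view is untouched
```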

## How was this patch tested?

Added a unit test that triggers this race condition by pickling a broadcast variable from another thread.
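
As a simplified, hedged sketch of the failure mode the test guards against (not the test itself), a plain shared `set()`, the pre-fix type of `SparkContext._pickled_broadcast_vars`, lets an entry added from a worker thread leak into the main thread's view:

```python
import threading

shared_registry = set()  # pre-fix behavior: one set shared by all threads

def worker():
    # stands in for pickling broadcast b2 inside a worker thread
    shared_registry.add("b2")

t = threading.Thread(target=worker)
t.start()
t.join()

shared_registry.add("b1")       # the main thread pickled only b1 ...
print(sorted(shared_registry))  # ... yet sees ['b1', 'b2']: b2 leaked across threads
```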

Author: Bryan Cutler <cu...@gmail.com>

Closes #18825 from BryanCutler/pyspark-bcast-threadsafe-SPARK-12717-2_1.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d93e45b8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d93e45b8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d93e45b8

Branch: refs/heads/branch-2.1
Commit: d93e45b8bad6efd34ed7c03b2602df35788961a4
Parents: b31b302
Author: Bryan Cutler <cu...@gmail.com>
Authored: Thu Aug 3 10:35:56 2017 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Thu Aug 3 10:35:56 2017 +0900

----------------------------------------------------------------------
 python/pyspark/broadcast.py | 19 +++++++++++++++++
 python/pyspark/context.py   |  4 ++--
 python/pyspark/tests.py     | 44 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d93e45b8/python/pyspark/broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 74dee14..8f9b42e 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -19,6 +19,7 @@ import os
 import sys
 import gc
 from tempfile import NamedTemporaryFile
+import threading
 
 from pyspark.cloudpickle import print_exec
 
@@ -137,6 +138,24 @@ class Broadcast(object):
         return _from_id, (self._jbroadcast.id(),)
 
 
+class BroadcastPickleRegistry(threading.local):
+    """ Thread-local registry for broadcast variables that have been pickled
+    """
+
+    def __init__(self):
+        self.__dict__.setdefault("_registry", set())
+
+    def __iter__(self):
+        for bcast in self._registry:
+            yield bcast
+
+    def add(self, bcast):
+        self._registry.add(bcast)
+
+    def clear(self):
+        self._registry.clear()
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()

http://git-wip-us.apache.org/repos/asf/spark/blob/d93e45b8/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ac4b2b0..5a4c2fa 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -30,7 +30,7 @@ from py4j.protocol import Py4JError
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
-from pyspark.broadcast import Broadcast
+from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
@@ -200,7 +200,7 @@ class SparkContext(object):
         # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
         # send.
-        self._pickled_broadcast_vars = set()
+        self._pickled_broadcast_vars = BroadcastPickleRegistry()
 
         SparkFiles._sc = self
         root_dir = SparkFiles.getRootDirectory()

http://git-wip-us.apache.org/repos/asf/spark/blob/d93e45b8/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 8d227ea..25ed127 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -793,6 +793,50 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertEqual(N, size)
         self.assertEqual(checksum, csum)
 
+    def test_multithread_broadcast_pickle(self):
+        import threading
+
+        b1 = self.sc.broadcast(list(range(3)))
+        b2 = self.sc.broadcast(list(range(3)))
+
+        def f1():
+            return b1.value
+
+        def f2():
+            return b2.value
+
+        funcs_num_pickled = {f1: None, f2: None}
+
+        def do_pickle(f, sc):
+            command = (f, None, sc.serializer, sc.serializer)
+            ser = CloudPickleSerializer()
+            ser.dumps(command)
+
+        def process_vars(sc):
+            broadcast_vars = list(sc._pickled_broadcast_vars)
+            num_pickled = len(broadcast_vars)
+            sc._pickled_broadcast_vars.clear()
+            return num_pickled
+
+        def run(f, sc):
+            do_pickle(f, sc)
+            funcs_num_pickled[f] = process_vars(sc)
+
+        # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage
+        do_pickle(f1, self.sc)
+
+        # run all for f2, should only add/count/clear b2 from worker thread local storage
+        t = threading.Thread(target=run, args=(f2, self.sc))
+        t.start()
+        t.join()
+
+        # count number of vars pickled in main thread, only b1 should be counted and cleared
+        funcs_num_pickled[f1] = process_vars(self.sc)
+
+        self.assertEqual(funcs_num_pickled[f1], 1)
+        self.assertEqual(funcs_num_pickled[f2], 1)
+        self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)
+
     def test_large_closure(self):
         N = 200000
         data = [float(i) for i in xrange(N)]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org