Posted to commits@spark.apache.org by gu...@apache.org on 2019/02/27 17:33:41 UTC

[spark] branch master updated: [SPARK-27000][PYTHON] Upgrades cloudpickle to v0.8.0

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a67e842  [SPARK-27000][PYTHON] Upgrades cloudpickle to v0.8.0
a67e842 is described below

commit a67e8426e37af3a19ff44fb1203c94a85c5f23c4
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Thu Feb 28 02:33:10 2019 +0900

    [SPARK-27000][PYTHON] Upgrades cloudpickle to v0.8.0
    
    ## What changes were proposed in this pull request?
    
    After upgrading cloudpickle to 0.6.1 at https://github.com/apache/spark/pull/20691, one regression was found. Cloudpickle has a critical fix for it at https://github.com/cloudpipe/cloudpickle/pull/240.
    
    Basically, it currently looks like existing globals override the globals shipped with a function, meaning:
    
    **Before:**
    
    ```python
    >>> def hey():
    ...     return "Hi"
    ...
    >>> spark.range(1).rdd.map(lambda _: hey()).collect()
    ['Hi']
    >>> def hey():
    ...     return "Yeah"
    ...
    >>> spark.range(1).rdd.map(lambda _: hey()).collect()
    ['Hi']
    ```
    
    **After:**
    
    ```python
    >>> def hey():
    ...     return "Hi"
    ...
    >>> spark.range(1).rdd.map(lambda _: hey()).collect()
    ['Hi']
    >>>
    >>> def hey():
    ...     return "Yeah"
    ...
    >>> spark.range(1).rdd.map(lambda _: hey()).collect()
    ['Yeah']
    ```
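
    For reference, the same override can be reproduced with cloudpickle alone, outside Spark. This is a minimal sketch of an assumed REPL session: before this fix, the last call returned 'Yeah', because the live globals in `__main__` won over the globals shipped with the lambda.

    ```python
    >>> import pickle
    >>> from pyspark import cloudpickle
    >>> def hey():
    ...     return "Hi"
    ...
    >>> payload = cloudpickle.dumps(lambda: hey())  # ships hey() by value
    >>> def hey():
    ...     return "Yeah"
    ...
    >>> pickle.loads(payload)()  # the shipped "Hi" definition now wins
    'Hi'
    ```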
    
    Therefore, this PR upgrades cloudpickle to 0.8.0.
    
    Note that cloudpickle's release cycle is quite short.
    
    Between 0.6.1 and 0.7.0, there are only minor bug fixes. I don't see notable changes to double-check and/or avoid.
    
    There is virtually only this fix between 0.7.0 and 0.8.0; the other changes are about testing.
    
    ## How was this patch tested?
    
    Manually tested; tests were added. Verified that unit tests were added in cloudpickle as well.
    
    Closes #23904 from HyukjinKwon/SPARK-27000.
    
    Authored-by: Hyukjin Kwon <gu...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/cloudpickle.py    | 163 +++++++++++++++++++++------------------
 python/pyspark/tests/test_rdd.py |  10 +++
 2 files changed, 100 insertions(+), 73 deletions(-)

diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index bf92569..54d745c 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -42,21 +42,20 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """
 from __future__ import print_function
 
-import io
 import dis
-import sys
-import types
+from functools import partial
+import importlib
+import io
+import itertools
+import logging
 import opcode
+import operator
 import pickle
 import struct
-import logging
-import weakref
-import operator
-import importlib
-import itertools
+import sys
 import traceback
-from functools import partial
-
+import types
+import weakref
 
 # cloudpickle is meant for inter process communication: we expect all
 # communicating processes to run the same Python version hence we favor
@@ -64,36 +63,22 @@ from functools import partial
 DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL
 
 
-if sys.version < '3':
+if sys.version_info[0] < 3:  # pragma: no branch
     from pickle import Pickler
     try:
         from cStringIO import StringIO
     except ImportError:
         from StringIO import StringIO
+    string_types = (basestring,)  # noqa
     PY3 = False
 else:
     types.ClassType = type
     from pickle import _Pickler as Pickler
     from io import BytesIO as StringIO
+    string_types = (str,)
     PY3 = True
 
 
-# Container for the global namespace to ensure consistent unpickling of
-# functions defined in dynamic modules (modules not registed in sys.modules).
-_dynamic_modules_globals = weakref.WeakValueDictionary()
-
-
-class _DynamicModuleFuncGlobals(dict):
-    """Global variables referenced by a function defined in a dynamic module
-
-    To avoid leaking references we store such context in a WeakValueDictionary
-    instance.  However instances of python builtin types such as dict cannot
-    be used directly as values in such a construct, hence the need for a
-    derived class.
-    """
-    pass
-
-
 def _make_cell_set_template_code():
     """Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF
 
@@ -112,7 +97,7 @@ def _make_cell_set_template_code():
 
            return _stub
 
-        _cell_set_template_code = f()
+        _cell_set_template_code = f().__code__
 
     This function is _only_ a LOAD_FAST(arg); STORE_DEREF, but that is
     invalid syntax on Python 2. If we use this function we also don't need
@@ -127,7 +112,7 @@ def _make_cell_set_template_code():
     # NOTE: we are marking the cell variable as a free variable intentionally
     # so that we simulate an inner function instead of the outer function. This
     # is what gives us the ``nonlocal`` behavior in a Python 2 compatible way.
-    if not PY3:
+    if not PY3:  # pragma: no branch
         return types.CodeType(
             co.co_argcount,
             co.co_nlocals,
@@ -228,14 +213,14 @@ _BUILTIN_TYPE_CONSTRUCTORS = {
 }
 
 
-if sys.version_info < (3, 4):
+if sys.version_info < (3, 4):  # pragma: no branch
     def _walk_global_ops(code):
         """
         Yield (opcode, argument number) tuples for all
         global-referencing instructions in *code*.
         """
         code = getattr(code, 'co_code', b'')
-        if not PY3:
+        if not PY3:  # pragma: no branch
             code = map(ord, code)
 
         n = len(code)
@@ -273,8 +258,6 @@ class CloudPickler(Pickler):
         if protocol is None:
             protocol = DEFAULT_PROTOCOL
         Pickler.__init__(self, file, protocol=protocol)
-        # set of modules to unpickle
-        self.modules = set()
         # map ids to dictionary. used to ensure that functions can share global env
         self.globals_ref = {}
 
@@ -294,7 +277,7 @@ class CloudPickler(Pickler):
 
     dispatch[memoryview] = save_memoryview
 
-    if not PY3:
+    if not PY3:  # pragma: no branch
         def save_buffer(self, obj):
             self.save(str(obj))
 
@@ -304,7 +287,6 @@ class CloudPickler(Pickler):
         """
         Save a module as an import
         """
-        self.modules.add(obj)
         if _is_dynamic(obj):
             self.save_reduce(dynamic_subimport, (obj.__name__, vars(obj)),
                              obj=obj)
@@ -317,7 +299,7 @@ class CloudPickler(Pickler):
         """
         Save a code object
         """
-        if PY3:
+        if PY3:  # pragma: no branch
             args = (
                 obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
                 obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames,
@@ -384,7 +366,6 @@ class CloudPickler(Pickler):
             lookedup_by_name = None
 
         if themodule:
-            self.modules.add(themodule)
             if lookedup_by_name is obj:
                 return self.save_global(obj, name)
 
@@ -396,7 +377,7 @@ class CloudPickler(Pickler):
         # So we pickle them here using save_reduce; have to do it differently
         # for different python versions.
         if not hasattr(obj, '__code__'):
-            if PY3:
+            if PY3:  # pragma: no branch
                 rv = obj.__reduce_ex__(self.proto)
             else:
                 if hasattr(obj, '__self__'):
@@ -434,8 +415,31 @@ class CloudPickler(Pickler):
 
     def _save_subimports(self, code, top_level_dependencies):
         """
-        Ensure de-pickler imports any package child-modules that
-        are needed by the function
+        Save submodules used by a function but not listed in its globals.
+
+        In the example below:
+
+        ```
+        import concurrent.futures
+        import cloudpickle
+
+
+        def func():
+            x = concurrent.futures.ThreadPoolExecutor
+
+
+        if __name__ == '__main__':
+            cloudpickle.dumps(func)
+        ```
+
+        the globals extracted by cloudpickle in the function's state include
+        the concurrent module, but not its submodule (here,
+        concurrent.futures), which is the module used by func.
+
+        To ensure that calling the depickled function does not raise an
+        AttributeError, this function looks for any currently loaded submodule
+        that the function uses and whose parent is present in the function
+        globals, and saves it before saving the function.
         """
 
         # check if any known dependency is an imported package
@@ -481,6 +485,17 @@ class CloudPickler(Pickler):
         # doc can't participate in a cycle with the original class.
         type_kwargs = {'__doc__': clsdict.pop('__doc__', None)}
 
+        if hasattr(obj, "__slots__"):
+            type_kwargs['__slots__'] = obj.__slots__
+            # pickle string length optimization: member descriptors of obj are
+            # created automatically from obj's __slots__ attribute, no need to
+            # save them in obj's state
+            if isinstance(obj.__slots__, string_types):
+                clsdict.pop(obj.__slots__)
+            else:
+                for k in obj.__slots__:
+                    clsdict.pop(k, None)
+
         # If type overrides __dict__ as a property, include it in the type kwargs.
         # In Python 2, we can't set this attribute after construction.
         __dict__ = clsdict.pop('__dict__', None)
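
For illustration, here is a minimal, assumed sketch of the new `__slots__` support added above; `Point` is a made-up class that must be defined in `__main__` (or another module pickled by value) for cloudpickle to serialize the class itself:

```python
import pickle

from pyspark import cloudpickle


class Point(object):
    # Member descriptors for 'x' and 'y' live in the class dict; the change
    # above drops them before pickling because passing __slots__ to type()
    # recreates them automatically at unpickling time.
    __slots__ = ('x', 'y')

    def __init__(self, x, y):
        self.x = x
        self.y = y


# Round-trip the class by value and use the rebuilt slots.
NewPoint = pickle.loads(cloudpickle.dumps(Point))
p = NewPoint(3, 4)
assert (p.x, p.y) == (3, 4)
```
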
@@ -639,17 +654,17 @@ class CloudPickler(Pickler):
         # save the dict
         dct = func.__dict__
 
-        base_globals = self.globals_ref.get(id(func.__globals__), None)
-        if base_globals is None:
-            # For functions defined in a well behaved module use
-            # vars(func.__module__) for base_globals. This is necessary to
-            # share the global variables across multiple pickled functions from
-            # this module.
-            if hasattr(func, '__module__') and func.__module__ is not None:
-                base_globals = func.__module__
-            else:
-                base_globals = {}
-        self.globals_ref[id(func.__globals__)] = base_globals
+        # base_globals represents the future global namespace of func at
+        # unpickling time. Looking it up and storing it in globals_ref allow
+        # functions sharing the same globals at pickling time to also
+        share them once unpickled, on one condition: since globals_ref is
+        an attribute of a CloudPickler instance, and a new CloudPickler is
+        created each time pickle.dump or pickle.dumps is called, functions
+        also need to be saved within the same invocation of
+        cloudpickle.dump/cloudpickle.dumps (for example: cloudpickle.dumps([f1, f2])). There
+        is no such limitation when using CloudPickler.dump, as long as the
+        multiple invocations are bound to the same CloudPickler.
+        base_globals = self.globals_ref.setdefault(id(func.__globals__), {})
 
         return (code, f_globals, defaults, closure, dct, base_globals)
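
A minimal, assumed sketch of the sharing behavior described in the comment above; `bump` and `peek` are made-up functions defined in `__main__`:

```python
import pickle

from pyspark import cloudpickle

counter = 0


def bump():
    global counter
    counter += 1
    return counter


def peek():
    return counter


# Pickled by the same dumps() call, both functions are rebuilt around one
# shared, isolated namespace, so a write through bump() is visible to peek().
f1, f2 = pickle.loads(cloudpickle.dumps([bump, peek]))
f1()
assert f2() == 1
```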
 
@@ -699,7 +714,7 @@ class CloudPickler(Pickler):
         if obj.__self__ is None:
             self.save_reduce(getattr, (obj.im_class, obj.__name__))
         else:
-            if PY3:
+            if PY3:  # pragma: no branch
                 self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj)
             else:
                 self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__),
@@ -752,7 +767,7 @@ class CloudPickler(Pickler):
         save(stuff)
         write(pickle.BUILD)
 
-    if not PY3:
+    if not PY3:  # pragma: no branch
         dispatch[types.InstanceType] = save_inst
 
     def save_property(self, obj):
@@ -852,7 +867,7 @@ class CloudPickler(Pickler):
 
     try:               # Python 2
         dispatch[file] = save_file
-    except NameError:  # Python 3
+    except NameError:  # Python 3  # pragma: no branch
         dispatch[io.TextIOWrapper] = save_file
 
     dispatch[type(Ellipsis)] = save_ellipsis
@@ -873,6 +888,12 @@ class CloudPickler(Pickler):
 
     dispatch[logging.RootLogger] = save_root_logger
 
+    if hasattr(types, "MappingProxyType"):  # pragma: no branch
+        def save_mappingproxy(self, obj):
+            self.save_reduce(types.MappingProxyType, (dict(obj),), obj=obj)
+
+        dispatch[types.MappingProxyType] = save_mappingproxy
+
     """Special functions for Add-on libraries"""
     def inject_addons(self):
         """Plug in system. Register additional pickling functions if modules already loaded"""
@@ -1059,10 +1080,16 @@ def _fill_function(*args):
     else:
         raise ValueError('Unexpected _fill_value arguments: %r' % (args,))
 
-    # Only set global variables that do not exist.
-    for k, v in state['globals'].items():
-        if k not in func.__globals__:
-            func.__globals__[k] = v
+    # - At pickling time, any dynamic global variable used by func is
+    #   serialized by value (in state['globals']).
+    # - At unpickling time, func's __globals__ attribute is initialized by
+    #   first retrieving an empty isolated namespace that will be shared
+    #   with other functions pickled from the same original module
+    #   by the same CloudPickler instance and then updated with the
+    #   content of state['globals'] to populate the shared isolated
+    #   namespace with all the global variables that are specifically
+    #   referenced for this function.
+    func.__globals__.update(state['globals'])
 
     func.__defaults__ = state['defaults']
     func.__dict__ = state['dict']
@@ -1100,21 +1127,11 @@ def _make_skel_func(code, cell_count, base_globals=None):
         code and the correct number of cells in func_closure.  All other
         func attributes (e.g. func_globals) are empty.
     """
-    if base_globals is None:
+    # This is backward-compatibility code: for cloudpickle versions between
+    # 0.5.4 and 0.7, base_globals could be a string or None. base_globals
+    # should now always be a dictionary.
+    if base_globals is None or isinstance(base_globals, str):
         base_globals = {}
-    elif isinstance(base_globals, str):
-        base_globals_name = base_globals
-        try:
-            # First try to reuse the globals from the module containing the
-            # function. If it is not possible to retrieve it, fallback to an
-            # empty dictionary.
-            base_globals = vars(importlib.import_module(base_globals))
-        except ImportError:
-            base_globals = _dynamic_modules_globals.get(
-                    base_globals_name, None)
-            if base_globals is None:
-                base_globals = _DynamicModuleFuncGlobals()
-            _dynamic_modules_globals[base_globals_name] = base_globals
 
     base_globals['__builtins__'] = __builtins__
 
@@ -1182,7 +1199,7 @@ def _getobject(modname, attribute):
 
 """ Use copy_reg to extend global pickle definitions """
 
-if sys.version_info < (3, 4):
+if sys.version_info < (3, 4):  # pragma: no branch
     method_descriptor = type(str.upper)
 
     def _reduce_method_descriptor(obj):
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index 9e4ac6c..e789cbe 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -32,6 +32,9 @@ if sys.version_info[0] >= 3:
     xrange = range
 
 
+global_func = lambda: "Hi"
+
+
 class RDDTests(ReusedPySparkTestCase):
 
     def test_range(self):
@@ -726,6 +729,13 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
                                 seq_rdd.aggregate, 0, lambda *x: 1, stopit)
 
+    def test_overwritten_global_func(self):
+        # Regression test for SPARK-27000
+        global global_func
+        self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Hi")
+        global_func = lambda: "Yeah"
+        self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Yeah")
+
 
 if __name__ == "__main__":
     import unittest

