You are viewing a plain-text version of this content. The canonical link for it (a hyperlink in the original page) is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/11 18:06:54 UTC

[GitHub] piiswrong closed pull request #10833: Change class variables to thread local variables

piiswrong closed pull request #10833: Change class variables to thread local variables
URL: https://github.com/apache/incubator-mxnet/pull/10833
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not otherwise display it:

diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py
index 15d38f81f2e..17044ddaef0 100644
--- a/python/mxnet/attribute.py
+++ b/python/mxnet/attribute.py
@@ -18,10 +18,12 @@
 # coding: utf-8
 """Attribute scoping support for symbolic API."""
 from __future__ import absolute_import
+import threading
+import warnings
 
-from .base import string_types
+from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass
 
-class AttrScope(object):
+class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)):
     """Attribute manager for scoping.
 
     User can also inherit this object to change naming behavior.
@@ -31,7 +33,7 @@ class AttrScope(object):
     kwargs
         The attributes to set for all symbol creations in the scope.
     """
-    current = None
+    _current = threading.local()
 
     def __init__(self, **kwargs):
         self._old_scope = None
@@ -64,15 +66,35 @@ def get(self, attr):
 
     def __enter__(self):
         # pylint: disable=protected-access
-        self._old_scope = AttrScope.current
-        attr = AttrScope.current._attr.copy()
+        if not hasattr(AttrScope._current, "value"):
+            AttrScope._current.value = AttrScope()
+        self._old_scope = AttrScope._current.value
+        attr = AttrScope._current.value._attr.copy()
         attr.update(self._attr)
         self._attr = attr
-        AttrScope.current = self
+        AttrScope._current.value = self
         return self
 
     def __exit__(self, ptype, value, trace):
         assert self._old_scope
-        AttrScope.current = self._old_scope
+        AttrScope._current.value = self._old_scope
 
-AttrScope.current = AttrScope()
+    #pylint: disable=no-self-argument
+    @classproperty
+    def current(cls):
+        warnings.warn("AttrScope.current has been deprecated. "
+                      "It is advised to use the `with` statement with AttrScope.",
+                      DeprecationWarning)
+        if not hasattr(AttrScope._current, "value"):
+            cls._current.value = AttrScope()
+        return cls._current.value
+
+    @current.setter
+    def current(cls, val):
+        warnings.warn("AttrScope.current has been deprecated. "
+                      "It is advised to use the `with` statement with AttrScope.",
+                      DeprecationWarning)
+        cls._current.value = val
+    #pylint: enable=no-self-argument
+
+AttrScope._current.value = AttrScope()
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 9790e090e38..0fb73b3c7dd 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=invalid-name, no-member, trailing-comma-tuple
+# pylint: disable=invalid-name, no-member, trailing-comma-tuple, bad-mcs-classmethod-argument
 """ctypes library of mxnet and helper functions."""
 from __future__ import absolute_import
 
@@ -98,6 +98,67 @@ class MXCallbackList(ctypes.Structure):
         ('contexts', ctypes.POINTER(ctypes.c_void_p))
         ]
 
+# Please see: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property
+class _MXClassPropertyDescriptor(object):
+    def __init__(self, fget, fset=None):
+        self.fget = fget
+        self.fset = fset
+
+    def __get__(self, obj, clas=None):
+        if clas is None:
+            clas = type(obj)
+        return self.fget.__get__(obj, clas)()
+
+    def __set__(self, obj, value):
+        if not self.fset:
+            raise MXNetError("cannot use the setter: %s to set attribute" % obj.__name__)
+        if inspect.isclass(obj):
+            type_ = obj
+            obj = None
+        else:
+            type_ = type(obj)
+        return self.fset.__get__(obj, type_)(value)
+
+    def setter(self, func):
+        if not isinstance(func, (classmethod, staticmethod)):
+            func = classmethod(func)
+        self.fset = func
+        return self
+
+class _MXClassPropertyMetaClass(type):
+    def __setattr__(cls, key, value):
+        if key in cls.__dict__:
+            obj = cls.__dict__.get(key)
+        if obj and isinstance(obj, _MXClassPropertyDescriptor):
+            return obj.__set__(cls, value)
+
+        return super(_MXClassPropertyMetaClass, cls).__setattr__(key, value)
+
+# with_metaclass function obtained from: https://github.com/benjaminp/six/blob/master/six.py
+#pylint: disable=unused-argument
+def with_metaclass(meta, *bases):
+    """Create a base class with a metaclass."""
+    # This requires a bit of explanation: the basic idea is to make a dummy
+    # metaclass for one level of class instantiation that replaces itself with
+    # the actual metaclass.
+    class metaclass(type):
+
+        def __new__(cls, name, this_bases, d):
+            return meta(name, bases, d)
+
+        @classmethod
+        def __prepare__(cls, name, this_bases):
+            return meta.__prepare__(name, bases)
+    return type.__new__(metaclass, 'temporary_class', (), {})
+#pylint: enable=unused-argument
+
+def classproperty(func):
+    if not isinstance(func, (classmethod, staticmethod)):
+        func = classmethod(func)
+
+    return _MXClassPropertyDescriptor(func)
+
+
 
 def _load_lib():
     """Load library by searching possible path."""
@@ -227,6 +288,7 @@ def c_str_array(strings):
         arr[:] = [s.encode('utf-8') for s in strings]
         return arr
 
+
 def c_array(ctype, values):
     """Create ctypes array from a Python array.
 
diff --git a/python/mxnet/context.py b/python/mxnet/context.py
index eb47614e333..5861890f40c 100644
--- a/python/mxnet/context.py
+++ b/python/mxnet/context.py
@@ -18,8 +18,11 @@
 # coding: utf-8
 """Context management API of mxnet."""
 from __future__ import absolute_import
+import threading
+import warnings
+from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass
 
-class Context(object):
+class Context(with_metaclass(_MXClassPropertyMetaClass, object)):
     """Constructs a context.
 
     MXNet can run operations on CPU and different GPUs.
@@ -61,7 +64,7 @@ class Context(object):
     gpu(1)
     """
     # static class variable
-    default_ctx = None
+    _default_ctx = threading.local()
     devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'}
     devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5}
     def __init__(self, device_type, device_id=0):
@@ -109,15 +112,37 @@ def __repr__(self):
         return self.__str__()
 
     def __enter__(self):
-        self._old_ctx = Context.default_ctx
-        Context.default_ctx = self
+        if not hasattr(Context._default_ctx, "value"):
+            Context._default_ctx.value = Context('cpu', 0)
+        self._old_ctx = Context._default_ctx.value
+        Context._default_ctx.value = self
         return self
 
     def __exit__(self, ptype, value, trace):
-        Context.default_ctx = self._old_ctx
+        Context._default_ctx.value = self._old_ctx
+
+    #pylint: disable=no-self-argument
+    @classproperty
+    def default_ctx(cls):
+        warnings.warn("Context.default_ctx has been deprecated. "
+                      "Please use Context.current_context() instead. "
+                      "Please use test_utils.set_default_context to set a default context",
+                      DeprecationWarning)
+        if not hasattr(Context._default_ctx, "value"):
+            cls._default_ctx.value = Context('cpu', 0)
+        return cls._default_ctx.value
+
+    @default_ctx.setter
+    def default_ctx(cls, val):
+        warnings.warn("Context.default_ctx has been deprecated. "
+                      "Please use Context.current_context() instead. "
+                      "Please use test_utils.set_default_context to set a default context",
+                      DeprecationWarning)
+        cls._default_ctx.value = val
+    #pylint: enable=no-self-argument
 
 # initialize the default context in Context
-Context.default_ctx = Context('cpu', 0)
+Context._default_ctx.value = Context('cpu', 0)
 
 
 def cpu(device_id=0):
@@ -234,4 +259,6 @@ def current_context():
     -------
     default_ctx : Context
     """
-    return Context.default_ctx
+    if not hasattr(Context._default_ctx, "value"):
+        Context._default_ctx.value = Context('cpu', 0)
+    return Context._default_ctx.value
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index abc474850f2..7e4127250a0 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -20,6 +20,7 @@
 """Base container class for all neural network models."""
 __all__ = ['Block', 'HybridBlock', 'SymbolBlock']
 
+import threading
 import copy
 import warnings
 import re
@@ -35,7 +36,7 @@
 
 class _BlockScope(object):
     """Scope for collecting child `Block` s."""
-    _current = None
+    _current = threading.local()
 
     def __init__(self, block):
         self._block = block
@@ -46,10 +47,10 @@ def __init__(self, block):
     @staticmethod
     def create(prefix, params, hint):
         """Creates prefix and params for new `Block`."""
-        current = _BlockScope._current
+        current = getattr(_BlockScope._current, "value", None)
         if current is None:
             if prefix is None:
-                prefix = _name.NameManager.current.get(None, hint) + '_'
+                prefix = _name.NameManager._current.value.get(None, hint) + '_'
             if params is None:
                 params = ParameterDict(prefix)
             else:
@@ -70,8 +71,8 @@ def create(prefix, params, hint):
     def __enter__(self):
         if self._block._empty_prefix:
             return self
-        self._old_scope = _BlockScope._current
-        _BlockScope._current = self
+        self._old_scope = getattr(_BlockScope._current, "value", None)
+        _BlockScope._current.value = self
         self._name_scope = _name.Prefix(self._block.prefix)
         self._name_scope.__enter__()
         return self
@@ -81,7 +82,7 @@ def __exit__(self, ptype, value, trace):
             return
         self._name_scope.__exit__(ptype, value, trace)
         self._name_scope = None
-        _BlockScope._current = self._old_scope
+        _BlockScope._current.value = self._old_scope
 
 
 def _flatten(args, inout_str):
diff --git a/python/mxnet/name.py b/python/mxnet/name.py
index 966d38280ef..4149d1db273 100644
--- a/python/mxnet/name.py
+++ b/python/mxnet/name.py
@@ -18,13 +18,16 @@
 # coding: utf-8
 """Automatic naming support for symbolic API."""
 from __future__ import absolute_import
+import threading
+import warnings
+from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass
 
-class NameManager(object):
+class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)):
     """NameManager to do automatic naming.
 
     Developers can also inherit from this class to change naming behavior.
     """
-    current = None
+    _current = threading.local()
 
     def __init__(self):
         self._counter = {}
@@ -62,14 +65,30 @@ def get(self, name, hint):
         return name
 
     def __enter__(self):
-        self._old_manager = NameManager.current
-        NameManager.current = self
+        if not hasattr(NameManager._current, "value"):
+            NameManager._current.value = NameManager()
+        self._old_manager = NameManager._current.value
+        NameManager._current.value = self
         return self
 
     def __exit__(self, ptype, value, trace):
         assert self._old_manager
-        NameManager.current = self._old_manager
-
+        NameManager._current.value = self._old_manager
+
+    #pylint: disable=no-self-argument
+    @classproperty
+    def current(cls):
+        warnings.warn("NameManager.current has been deprecated. "
+                      "It is advised to use the `with` statement with NameManager.",
+                      DeprecationWarning)
+        if not hasattr(NameManager._current, "value"):
+            cls._current.value = NameManager()
+        return cls._current.value
+
+    @current.setter
+    def current(cls, val):
+        cls._current.value = val
+    #pylint: enable=no-self-argument
 
 class Prefix(NameManager):
     """A name manager that attaches a prefix to all names.
@@ -92,4 +111,4 @@ def get(self, name, hint):
         return self._prefix + name
 
 # initialize the default name manager
-NameManager.current = NameManager()
+NameManager._current.value = NameManager()
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 7bfb3c79b35..007b3c82def 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -37,7 +37,7 @@
 from ..base import c_array, c_array_buf, c_handle_array, mx_real_t
 from ..base import mx_uint, NDArrayHandle, check_call
 from ..base import ctypes2buffer
-from ..context import Context
+from ..context import Context, current_context
 from . import _internal
 from . import op
 from ._internal import NDArrayBase
@@ -2261,7 +2261,7 @@ def ones(shape, ctx=None, dtype=None, **kwargs):
         The shape of the empty array.
     ctx : Context, optional
         An optional device context.
-        Defaults to the current default context (``mxnet.Context.default_ctx``).
+        Defaults to the current default context (``mxnet.context.current_context()``).
     dtype : str or numpy.dtype, optional
         An optional value type (default is `float32`).
     out : NDArray, optional
@@ -2283,7 +2283,7 @@ def ones(shape, ctx=None, dtype=None, **kwargs):
     """
     # pylint: disable= unused-argument
     if ctx is None:
-        ctx = Context.default_ctx
+        ctx = current_context()
     dtype = mx_real_t if dtype is None else dtype
     # pylint: disable= no-member, protected-access
     return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
@@ -2439,7 +2439,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
     array([2, 2, 2, 4, 4, 4], dtype=int32)
     """
     if ctx is None:
-        ctx = Context.default_ctx
+        ctx = current_context()
     return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
                              dtype=dtype, ctx=str(ctx))
 # pylint: enable= no-member, protected-access, too-many-arguments
@@ -3666,7 +3666,7 @@ def zeros(shape, ctx=None, dtype=None, **kwargs):
     """
     # pylint: disable= unused-argument
     if ctx is None:
-        ctx = Context.default_ctx
+        ctx = current_context()
     dtype = mx_real_t if dtype is None else dtype
     # pylint: disable= no-member, protected-access
     return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
@@ -3705,7 +3705,7 @@ def eye(N, M=0, k=0, ctx=None, dtype=None, **kwargs):
     """
     # pylint: disable= unused-argument
     if ctx is None:
-        ctx = Context.default_ctx
+        ctx = current_context()
     dtype = mx_real_t if dtype is None else dtype
     # pylint: disable= no-member, protected-access
     return _internal._eye(N=N, M=M, k=k, ctx=ctx, dtype=dtype, **kwargs)
@@ -3733,7 +3733,7 @@ def empty(shape, ctx=None, dtype=None):
     if isinstance(shape, int):
         shape = (shape, )
     if ctx is None:
-        ctx = Context.default_ctx
+        ctx = current_context()
     if dtype is None:
         dtype = mx_real_t
     return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype))
diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
index c7355c2e46d..9c02b8e2cf2 100644
--- a/python/mxnet/ndarray/sparse.py
+++ b/python/mxnet/ndarray/sparse.py
@@ -42,7 +42,7 @@
 from ..base import _LIB, numeric_types
 from ..base import c_array_buf, mx_real_t, integer_types
 from ..base import mx_uint, NDArrayHandle, check_call
-from ..context import Context
+from ..context import Context, current_context
 from . import _internal
 from . import op
 try:
@@ -977,7 +977,7 @@ def _csr_matrix_from_definition(data, indices, indptr, shape=None, ctx=None,
     # pylint: disable= no-member, protected-access
     storage_type = 'csr'
     # context
-    ctx = Context.default_ctx if ctx is None else ctx
+    ctx = current_context() if ctx is None else ctx
     # types
     dtype = _prepare_default_dtype(data, dtype)
     indptr_type = _STORAGE_AUX_TYPES[storage_type][0] if indptr_type is None else indptr_type
@@ -1140,7 +1140,7 @@ def _row_sparse_ndarray_from_definition(data, indices, shape=None, ctx=None,
     """Create a `RowSparseNDArray` based on data and indices"""
     storage_type = 'row_sparse'
     # context
-    ctx = Context.default_ctx if ctx is None else ctx
+    ctx = current_context() if ctx is None else ctx
     # types
     dtype = _prepare_default_dtype(data, dtype)
     indices_type = _STORAGE_AUX_TYPES[storage_type][0] if indices_type is None else indices_type
@@ -1529,7 +1529,7 @@ def zeros(stype, shape, ctx=None, dtype=None, **kwargs):
     if stype == 'default':
         return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs)
     if ctx is None:
-        ctx = Context.default_ctx
+        ctx = current_context()
     dtype = mx_real_t if dtype is None else dtype
     if stype == 'row_sparse' or stype == 'csr':
         aux_types = _STORAGE_AUX_TYPES[stype]
@@ -1562,7 +1562,7 @@ def empty(stype, shape, ctx=None, dtype=None):
     if isinstance(shape, int):
         shape = (shape, )
     if ctx is None:
-        ctx = Context.default_ctx
+        ctx = current_context()
     if dtype is None:
         dtype = mx_real_t
     assert(stype is not None)
@@ -1603,7 +1603,7 @@ def array(source_array, ctx=None, dtype=None):
     >>> mx.nd.sparse.array(mx.nd.sparse.zeros('row_sparse', (3, 2)))
     <RowSparseNDArray 3x2 @cpu(0)>
     """
-    ctx = Context.default_ctx if ctx is None else ctx
+    ctx = current_context() if ctx is None else ctx
     if isinstance(source_array, NDArray):
         assert(source_array.stype != 'default'), \
                "Please use `tostype` to create RowSparseNDArray or CSRNDArray from an NDArray"
diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py
index 6f9e868e232..3e81dcf3a6c 100644
--- a/python/mxnet/symbol/register.py
+++ b/python/mxnet/symbol/register.py
@@ -113,9 +113,9 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
             dtype_name, dtype_name, dtype_name))
             code.append("""
     attr = kwargs.pop('attr', None)
-    kwargs.update(AttrScope.current.get(attr))
+    kwargs.update(AttrScope._current.value.get(attr))
     name = kwargs.pop('name', None)
-    name = NameManager.current.get(name, '%s')
+    name = NameManager._current.value.get(name, '%s')
     _ = kwargs.pop('out', None)
     keys = []
     vals = []
@@ -141,7 +141,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
 def %s(%s):"""%(func_name, ', '.join(signature)))
         if not signature_only:
             code.append("""
-    kwargs.update(AttrScope.current.get(attr))
+    kwargs.update(AttrScope._current.value.get(attr))
     sym_kwargs = dict()
     _keys = []
     _vals = []
@@ -172,7 +172,7 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
         _vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
 
             code.append("""
-    name = NameManager.current.get(name, '%s')
+    name = NameManager._current.value.get(name, '%s')
     return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name)"""%(
         func_name.lower(), handle.value))
 
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 1ab7cf87bf5..49023db2fe0 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -37,7 +37,7 @@
 from ..base import mx_uint, py_str, string_types
 from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle
 from ..base import check_call, MXNetError, NotImplementedForSymbol
-from ..context import Context
+from ..context import Context, current_context
 from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP
 from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
 from ..ndarray import _ndarray_cls
@@ -1767,7 +1767,7 @@ def eval(self, ctx=None, **kwargs):
         the result will be a list with one element.
         """
         if ctx is None:
-            ctx = Context.default_ctx
+            ctx = current_context()
         return self.bind(ctx, kwargs).forward()
 
     def reshape(self, *args, **kwargs):
@@ -2448,7 +2448,7 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
     handle = SymbolHandle()
     check_call(_LIB.MXSymbolCreateVariable(c_str(name), ctypes.byref(handle)))
     ret = Symbol(handle)
-    attr = AttrScope.current.get(attr)
+    attr = AttrScope._current.value.get(attr)
     attr = {} if attr is None else attr
     if shape is not None:
         attr['__shape__'] = str(shape)
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index aa388c14ea1..bcdcc9c6408 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -44,7 +44,7 @@
     # in rare cases requests may be not installed
     pass
 import mxnet as mx
-from .context import Context
+from .context import Context, current_context
 from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
 from .ndarray import array
 from .symbol import Symbol
@@ -54,12 +54,12 @@ def default_context():
     """Get default context for regression test."""
     # _TODO: get context from environment variable to support
     # testing with GPUs
-    return Context.default_ctx
+    return current_context()
 
 
 def set_default_context(ctx):
     """Set default context."""
-    Context.default_ctx = ctx
+    Context._default_ctx.value = ctx
 
 
 def default_dtype():
diff --git a/tests/nightly/test_tlocal_racecondition.py b/tests/nightly/test_tlocal_racecondition.py
new file mode 100644
index 00000000000..d43c45937c0
--- /dev/null
+++ b/tests/nightly/test_tlocal_racecondition.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet import gluon
+from mxnet import image
+from mxnet import nd
+import numpy as np
+import logging
+
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+
+root_url = ('https://apache-mxnet.s3-accelerate.amazonaws.com/'
+            'gluon/dataset/pikachu/')
+data_dir = './data/pikachu/'
+dataset = {'train.rec': 'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8',
+          'train.idx': 'dcf7318b2602c06428b9988470c731621716c393',
+          'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520'}
+for k, v in dataset.items():
+    gluon.utils.download(root_url+k, data_dir+k, sha1_hash=v)
+
+T = 1
+devs = [mx.gpu(i) for i in range(4)]
+data_shape = 224 * T
+batch_size = 20 * len(devs)
+rgb_mean = np.array([1,2,3])
+
+class_names = ['pikachu']
+num_class = len(class_names)
+
+def get_iterators(data_shape, batch_size):
+    train_iter = image.ImageDetIter(
+        batch_size=batch_size,
+        data_shape=(3, data_shape, data_shape),
+        path_imgrec=data_dir+'train.rec',
+        path_imgidx=data_dir+'train.idx',
+        shuffle=True,
+        mean=True,
+        rand_crop=1,
+        min_object_covered=0.95,
+        max_attempts=200)
+    val_iter = image.ImageDetIter(
+        batch_size=batch_size,
+        data_shape=(3, data_shape, data_shape),
+        path_imgrec=data_dir+'val.rec',
+        shuffle=False,
+        mean=True)
+    return train_iter, val_iter, class_names, num_class
+
+train_data, test_data, class_names, num_class = get_iterators(
+    data_shape, batch_size)
+
+
+class MyCustom(mx.operator.CustomOp):
+    def __init__(self):
+        super(MyCustom, self).__init__()
+    def forward(self, is_train, req, in_data, out_data, aux):
+        self.assign(out_data[0], req[0], 0)
+    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+        self.assign(in_grad[0], req[0], 0)
+        self.assign(in_grad[1], req[1], 0)
+
+@mx.operator.register("MyCustom")
+class MyCustomProp(mx.operator.CustomOpProp):
+    def __init__(self):
+        super(MyCustomProp, self).__init__(need_top_grad = False)
+    def list_arguments(self):
+        return ["data", "label"]
+    def list_outputs(self):
+        return ["loss"]
+    def infer_shape(self, in_shape):
+        return [in_shape[0], in_shape[1]], [(1, )], []
+    def infer_type(self, in_type):
+        dtype = in_type[0]
+        return [dtype, dtype], [dtype], []
+    def create_operator(self, ctx, shapes, dtypes):
+        return MyCustom()
+
+class MyMetric(mx.metric.EvalMetric):
+    def __init__(self):
+        super(MyMetric, self).__init__("MyMetric")
+        self.name = ['empty']
+    def update(self, labels, preds):
+        pass
+    def get(self):
+        return self.name, [0]
+
+if __name__ == '__main__':
+    x = mx.sym.Variable("data")
+    label = mx.sym.Variable("label")
+    x = mx.sym.FullyConnected(data = x, num_hidden = 100)
+    label = mx.sym.Reshape(data = label, shape = (0, -1))
+    sym = mx.sym.Custom(data = x, label = label, op_type = "MyCustom")
+    model = mx.module.Module(context = devs, symbol = sym, data_names = ('data',), label_names = ('label',))
+    model.fit(train_data = train_data, begin_epoch = 0, num_epoch = 20, allow_missing = True, batch_end_callback = mx.callback.Speedometer(batch_size, 5), eval_metric = MyMetric())
diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py
index 800426c035b..5618e11a040 100644
--- a/tests/python/unittest/test_contrib_operator.py
+++ b/tests/python/unittest/test_contrib_operator.py
@@ -42,7 +42,7 @@ def test_box_nms_backward(data, grad, expected, thresh=0.5, topk=-1, coord=2, sc
         op = mx.contrib.sym.box_nms(in_var, overlap_thresh=thresh, topk=topk,
                                 coord_start=coord, score_index=score, id_index=cid,
                                 force_suppress=force, in_format=in_format, out_format=out_format)
-        exe = op.bind(ctx=mx.context.Context.default_ctx, args=[arr_data], args_grad=[arr_grad])
+        exe = op.bind(ctx=default_context(), args=[arr_data], args_grad=[arr_grad])
         exe.forward(is_train=True)
         exe.backward(mx.nd.array(grad))
         assert_almost_equal(arr_grad.asnumpy(), expected)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1db836b0918..b7c5e49cda0 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3714,7 +3714,7 @@ def test_tile_backward():
         reps2 = 2
         reps = (reps1, reps2)
         test = mx.sym.tile(data, reps=reps)
-        exe = test.bind(ctx=mx.context.Context.default_ctx, args=[arr_data], args_grad=[arr_grad])
+        exe = test.bind(ctx=default_context(), args=[arr_data], args_grad=[arr_grad])
         npout_grad = np.random.randint(0, 10, n1 * n2 * reps1 * reps2).reshape(n1 * reps1, n2 * reps2)
         out_grad = mx.nd.array(npout_grad)
         exe.backward(out_grad)
@@ -4421,7 +4421,7 @@ def test_psroipooling():
                                                      output_dim=num_classes, name='test_op')
                     rtol, atol = 1e-2, 1e-3
                     # By now we only have gpu implementation
-                    if mx.Context.default_ctx.device_type == 'gpu':
+                    if default_context().device_type == 'gpu':
                         check_numeric_gradient(op, [im_data, rois_data], rtol=rtol, atol=atol,
                                                grad_nodes=grad_nodes, ctx=mx.gpu(0))
 
@@ -4459,7 +4459,7 @@ def test_deformable_convolution():
                         else:
                             rtol, atol = 0.05, 1e-3
                         # By now we only have gpu implementation
-                        if mx.Context.default_ctx.device_type == 'gpu':
+                        if default_context().device_type == 'gpu':
                             check_numeric_gradient(op, [im_data, offset_data, weight, bias], rtol=rtol, atol=atol,
                                                    grad_nodes=grad_nodes, ctx=mx.gpu(0))
 
@@ -4495,7 +4495,7 @@ def test_deformable_psroipooling():
                     else:
                         rtol, atol = 1e-2, 1e-3
                     # By now we only have gpu implementation
-                    if mx.Context.default_ctx.device_type == 'gpu':
+                    if default_context().device_type == 'gpu':
                         check_numeric_gradient(op, [im_data, rois_data, offset_data], rtol=rtol, atol=atol,
                                                grad_nodes=grad_nodes, ctx=mx.gpu(0))
 
diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py
new file mode 100644
index 00000000000..a571a25ab2a
--- /dev/null
+++ b/tests/python/unittest/test_thread_local.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import threading
+import mxnet as mx
+from mxnet import context, attribute, name
+from mxnet.gluon import block
+from mxnet.context import Context
+from mxnet.attribute import AttrScope
+from mxnet.name import NameManager
+from mxnet.test_utils import set_default_context
+
+def test_context():
+    ctx_list = []
+    ctx_list.append(Context.default_ctx)
+    def f():
+        set_default_context(mx.gpu(11))
+        ctx_list.append(Context.default_ctx)
+    thread = threading.Thread(target=f)
+    thread.start()
+    thread.join()
+    assert Context.devtype2str[ctx_list[0].device_typeid] == "cpu"
+    assert ctx_list[0].device_id == 0
+    assert Context.devtype2str[ctx_list[1].device_typeid] == "gpu"
+    assert ctx_list[1].device_id == 11
+
+    event = threading.Event()
+    status = [False]
+    def g():
+        with mx.cpu(10):
+            event.wait()
+            if Context.default_ctx.device_id == 10:
+                status[0] = True
+    thread = threading.Thread(target=g)
+    thread.start()
+    Context.default_ctx = Context("cpu", 11)
+    event.set()
+    thread.join()
+    event.clear()
+    assert status[0], "Spawned thread didn't set the correct context"
+
+def test_attrscope():
+    attrscope_list = []
+    AttrScope.current = AttrScope(y="hi", z="hey")
+    attrscope_list.append(AttrScope.current)
+    def f():
+        AttrScope.current = AttrScope(x="hello")
+        attrscope_list.append(AttrScope.current)
+    thread = threading.Thread(target=f)
+    thread.start()
+    thread.join()
+    assert len(attrscope_list[0]._attr) == 2
+    assert attrscope_list[1]._attr["x"] == "hello"
+
+    event = threading.Event()
+    status = [False]
+    def g():
+        with mx.AttrScope(x="hello"):
+            event.wait()
+            if "hello" in AttrScope.current._attr.values():
+                status[0] = True
+    thread = threading.Thread(target=g)
+    thread.start()
+    AttrScope.current = AttrScope(x="hi")
+    event.set()
+    thread.join()
+    AttrScope.current = AttrScope()
+    event.clear()
+    assert status[0], "Spawned thread didn't set the correct attr key values"
+
+def test_name():
+    name_list = []
+    NameManager.current = NameManager()
+    NameManager.current.get(None, "main_thread")
+    name_list.append(NameManager.current)
+    def f():
+        NameManager.current = NameManager()
+        NameManager.current.get(None, "spawned_thread")
+        name_list.append(NameManager.current)
+    thread = threading.Thread(target=f)
+    thread.start()
+    thread.join()
+    assert "main_thread" in name_list[0]._counter, "cannot find the string `main thread` in name_list[0]._counter"
+    assert "spawned_thread" in name_list[1]._counter, "cannot find the string `spawned thread` in name_list[1]._counter"
+
+    event = threading.Event()
+    status = [False]
+    def g():
+        with NameManager():
+            if "main_thread" not in NameManager.current._counter:
+                status[0] = True
+    thread = threading.Thread(target=g)
+    thread.start()
+    NameManager.current = NameManager()
+    NameManager.current.get(None, "main_thread")
+    event.set()
+    thread.join()
+    event.clear()
+    assert status[0], "Spawned thread isn't using thread local NameManager"
+
+def test_blockscope():
+    class dummy_block(object):
+        def __init__(self, prefix):
+            self.prefix = prefix
+            self._empty_prefix = False
+    blockscope_list = []
+    status = [False]
+    event = threading.Event()
+    def f():
+        with block._BlockScope(dummy_block("spawned_")):
+            x= NameManager.current.get(None, "hello")
+            event.wait()
+            if x == "spawned_hello0":
+                status[0] = True
+    thread = threading.Thread(target=f)
+    thread.start()
+    block._BlockScope.create("main_thread", None, "hi")
+    event.set()
+    thread.join()
+    event.clear()
+    assert status[0], "Spawned thread isn't using the correct blockscope namemanager"
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services