Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/07/31 02:55:58 UTC

[incubator-mxnet] branch master updated: add reset_ctx (#7221)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 42544ed  add reset_ctx (#7221)
42544ed is described below

commit 42544eda02d5e7ff527704b42d62c0f81faa17f6
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Sun Jul 30 19:55:55 2017 -0700

    add reset_ctx (#7221)
    
    * add reset_ctx
    
    * add paramdict
    
    * add symbolBlock
    
    * rename blocks
    
    * rename
    
    * fix
---
 nnvm                                           |   2 +-
 python/mxnet/gluon/block.py                    | 156 +++++++++++++++----
 python/mxnet/gluon/model_zoo/vision/alexnet.py |  22 +--
 python/mxnet/gluon/nn/basic_layers.py          |  25 ++--
 python/mxnet/gluon/nn/conv_layers.py           |  11 +-
 python/mxnet/gluon/parameter.py                | 199 ++++++++++++++++---------
 python/mxnet/gluon/rnn/rnn_cell.py             |  81 +++++-----
 python/mxnet/symbol.py                         |  16 +-
 src/ndarray/autograd.cc                        |   3 +-
 tests/python/unittest/test_nn.py               |  24 +++
 10 files changed, 361 insertions(+), 178 deletions(-)
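
In short, a Parameter (or a whole ParameterDict) that has already been initialized
can now be moved to other contexts with reset_ctx, and re-initialization requires an
explicit force_reinit=True. A minimal usage sketch, mirroring the new docstrings and
the unit test in the diff below (the layer size and CPU context ids are illustrative):

    >>> import mxnet as mx
    >>> from mxnet import gluon
    >>> # Move a single Parameter after it has been initialized
    >>> weight = gluon.Parameter('weight', shape=(2, 2))
    >>> weight.initialize(ctx=mx.cpu(0))
    >>> weight.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
    >>> assert weight.list_ctx() == [mx.cpu(1), mx.cpu(2)]
    >>> # ParameterDict.reset_ctx moves every parameter it manages the same way
    >>> net = gluon.nn.Dense(16, in_units=8)
    >>> net.collect_params().initialize(ctx=mx.cpu(0))
    >>> net.collect_params().reset_ctx(mx.cpu(1))
    >>> # Re-initializing an already initialized Parameter now needs force_reinit=True
    >>> weight.initialize(ctx=mx.cpu(0), force_reinit=True)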

diff --git a/nnvm b/nnvm
index c96dd0e..0a45136 160000
--- a/nnvm
+++ b/nnvm
@@ -1 +1 @@
-Subproject commit c96dd0e126a788089fe700cf6effe4e87bc40e05
+Subproject commit 0a45136fae475a8313dc66b6bebd87a722f20e7f
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index e8ec12b..cfc5e57 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -2,6 +2,8 @@
 # pylint: disable= arguments-differ
 """Base container class for all neural network models."""
 
+import copy
+
 from .. import symbol, ndarray, initializer
 from ..symbol import Symbol
 from ..ndarray import NDArray
@@ -18,6 +20,7 @@ class _BlockScope(object):
         self._block = block
         self._counter = {}
         self._old_scope = None
+        self._name_scope = None
 
     @staticmethod
     def create(prefix, params, hint):
@@ -46,9 +49,13 @@ class _BlockScope(object):
     def __enter__(self):
         self._old_scope = _BlockScope._current
         _BlockScope._current = self
+        self._name_scope = _name.Prefix(self._block.prefix)
+        self._name_scope.__enter__()
         return self
 
     def __exit__(self, ptype, value, trace):
+        self._name_scope.__exit__(ptype, value, trace)
+        self._name_scope = None
         _BlockScope._current = self._old_scope
 
 
@@ -134,6 +141,7 @@ class Block(object):
     """
     def __init__(self, prefix=None, params=None):
         self._prefix, self._params = _BlockScope.create(prefix, params, self._alias())
+        self._name = self._prefix[:-1] if self._prefix.endswith('_') else self._prefix
         self._scope = _BlockScope(self)
         self._children = []
 
@@ -162,9 +170,7 @@ class Block(object):
     @property
     def name(self):
         """Name of this `Block`, without '_' in the end."""
-        if self.prefix.endswith('_'):
-            return self.prefix[:-1]
-        return self.prefix
+        return self._name
 
     def name_scope(self):
         """Returns a name space object managing a child `Block` and parameter
@@ -309,26 +315,26 @@ class HybridBlock(Block):
         super(HybridBlock, self).hybridize(active)
 
     def _get_graph(self, *args):
-        if self._cached_graph:
-            return self._cached_graph
+        if not self._cached_graph:
+            args, self._in_format = _flatten(args)
+            inputs = [symbol.var('input_%d'%i) for i in range(len(args))]
+            grouped_inputs = _regroup(inputs, self._in_format)[0]
 
-        args, self._in_format = _flatten(args)
-        syms = [symbol.var(str(i)) for i in range(len(args))]
-        sym_args = _regroup(syms, self._in_format)[0]
+            params = {i: j.var() for i, j in self._reg_params.items()}
+            with self.name_scope():
+                out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
+            out, self._out_format = _flatten(out)
 
-        params = {i: j.var() for i, j in self._reg_params.items()}
-        out = self.hybrid_forward(symbol, *sym_args, **params)  # pylint: disable=no-value-for-parameter
-        out, self._out_format = _flatten(out)
+            self._cached_graph = inputs, symbol.Group(out)
 
-        self._cached_graph = syms, symbol.Group(out)
         return self._cached_graph
 
     def infer_shape(self, *args):
         """Infers shape of Parameters from inputs."""
-        syms, out = self._get_graph(*args)
-        args, _, = _flatten(args)
+        inputs, out = self._get_graph(*args)
+        args, _ = _flatten(args)
         arg_shapes, _, aux_shapes = out.infer_shape(
-            **{i.name: j.shape for i, j in zip(syms, args)})
+            **{i.name: j.shape for i, j in zip(inputs, args)})
         sdict = {i: j for i, j in zip(out.list_arguments(), arg_shapes)}
         sdict.update({name : shape for name, shape in \
                       zip(out.list_auxiliary_states(), aux_shapes)})
@@ -336,21 +342,33 @@ class HybridBlock(Block):
             i.shape = sdict[i.name]
 
     def _build_cache(self, *args):
-        self.infer_shape(*args)
-        for i in self.collect_params().values():
-            i._finish_deferred_init()
-
-        _, out = self._get_graph(*args)
+        inputs, out = self._get_graph(*args)
         self._cached_op = ndarray.CachedOp(out)
+
         params = dict(self.collect_params().items())
         self._cached_params = [params.get(name, None) for name in out.list_inputs()]
-        self._in_idx = [(i, int(name)) for i, name in enumerate(out.list_inputs())
+        assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \
+            "Wrong number of inputs."
+
+        name2pos = {var.name: i for i, var in enumerate(inputs)}
+        self._in_idx = [(i, name2pos[name]) for i, name in enumerate(out.list_inputs())
                         if name not in params]
 
     def _call_cached_op(self, *args):
+        if self._cached_op is None:
+            self._build_cache(*args)
+
+        try:
+            cargs = [i.data() if i else None for i in self._cached_params]
+        except DeferredInitializationError:
+            self.infer_shape(*args)
+            for i in self._cached_params:
+                if i is not None:
+                    i._finish_deferred_init()
+            cargs = [i.data() if i else None for i in self._cached_params]
+
         args, fmt = _flatten(args)
         assert fmt == self._in_format, "Invalid input format"
-        cargs = [i.data() if i else None for i in self._cached_params]
         for i, j in self._in_idx:
             cargs[i] = args[j]
         out = self._cached_op(*cargs)
@@ -362,9 +380,6 @@ class HybridBlock(Block):
         """Defines the forward computation. Arguments can be either
         `NDArray` or `Symbol`."""
         if isinstance(x, NDArray):
-            if self._active and self._cached_op is None:
-                self._build_cache(x, *args)
-
             with x.context as ctx:
                 if self._active:
                     return self._call_cached_op(x, *args)
@@ -376,11 +391,12 @@ class HybridBlock(Block):
                         i._finish_deferred_init()
                     params = {i: j.data(ctx) for i, j in self._reg_params.items()}
                 return self.hybrid_forward(ndarray, x, *args, **params)
-        else:
-            assert isinstance(x, Symbol), \
-                "HybridBlock requires the first argument to forward be either " \
-                "Symbol or NDArray, but got %s"%type(x)
-            params = {i: j.var() for i, j in self._reg_params.items()}
+
+        assert isinstance(x, Symbol), \
+            "HybridBlock requires the first argument to forward be either " \
+            "Symbol or NDArray, but got %s"%type(x)
+        params = {i: j.var() for i, j in self._reg_params.items()}
+        with self.name_scope():
             return self.hybrid_forward(symbol, x, *args, **params)
 
     def hybrid_forward(self, F, x, *args, **kwargs):
@@ -395,3 +411,83 @@ class HybridBlock(Block):
         """
         # pylint: disable= invalid-name
         raise NotImplementedError
+
+
+class SymbolBlock(HybridBlock):
+    """Construct block from symbol. This is useful for using pre-trained models
+    as feature extractors. For example, you may want to extract the output
+    from the fc2 layer in AlexNet.
+
+    Parameters
+    ----------
+    outputs : Symbol or list of Symbol
+        The desired output for SymbolBlock.
+    inputs : Symbol or list of Symbol
+        The Variables in outputs' arguments that should be used as inputs.
+    params : ParameterDict
+        Parameter dictionary for arguments and auxiliary states of outputs
+        that are not inputs.
+
+    Examples
+    --------
+    >>> # To extract the feature from fc1 and fc2 layers of AlexNet:
+    >>> alexnet = gluon.model_zoo.vision.alexnet(pretrained=True, ctx=mx.cpu(),
+                                                 prefix='model_')
+    >>> inputs = mx.sym.var('data')
+    >>> out = alexnet(inputs)
+    >>> internals = out.get_internals()
+    >>> print(internals.list_outputs())
+    ['data', ..., 'model_dense0_relu_fwd_output', ..., 'model_dense1_relu_fwd_output', ...]
+    >>> outputs = [internals['model_dense0_relu_fwd_output'],
+                   internals['model_dense1_relu_fwd_output']]
+    >>> # Create SymbolBlock that shares parameters with alexnet
+    >>> feat_model = gluon.SymbolBlock(outputs, inputs, params=alexnet.collect_params())
+    >>> x = mx.nd.random_normal(shape=(16, 3, 224, 224))
+    >>> print(feat_model(x))
+    """
+    def __init__(self, outputs, inputs, params=None):
+        super(SymbolBlock, self).__init__(prefix=None, params=None)
+        self._prefix = ''
+        self._params = ParameterDict('', params)
+        if isinstance(inputs, symbol.Symbol) and len(inputs.list_outputs()) == 1:
+            inputs = [inputs]
+        if isinstance(outputs, symbol.Symbol) and len(outputs.list_outputs()) == 1:
+            outputs = [outputs]
+
+        syms, self._in_format = _flatten(inputs)
+        out, self._out_format = _flatten(outputs)
+        out = symbol.Group(out)
+
+        input_names = set()
+        for i in syms:
+            assert len(i.get_internals().list_outputs()) == 1, \
+                "Input symbols must be variable, but %s is an output of operators"%str(i)
+            input_names.add(i.name)
+
+        for i in out.list_arguments():
+            if i not in input_names:
+                self.params.get(i, allow_deferred_init=True)
+
+        for i in out.list_auxiliary_states():
+            if i not in input_names:
+                self.params.get(i, grad_req='null', allow_deferred_init=True)
+
+        self._cached_graph = syms, out
+        self._build_cache()
+
+    def forward(self, x, *args):
+        if isinstance(x, NDArray):
+            with x.context:
+                return self._call_cached_op(x, *args)
+
+        assert isinstance(x, Symbol), \
+            "HybridBlock requires the first argument to forward be either " \
+            "Symbol or NDArray, but got %s"%type(x)
+        args, in_fmt = _flatten([x] + list(args))
+        assert in_fmt == self._in_format, "Invalid input format"
+        ret = copy.copy(self._cached_graph[1])
+        ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
+        return _regroup(ret, self._out_format)[0]
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        raise NotImplementedError
diff --git a/python/mxnet/gluon/model_zoo/vision/alexnet.py b/python/mxnet/gluon/model_zoo/vision/alexnet.py
index dd5104d..86ff932 100644
--- a/python/mxnet/gluon/model_zoo/vision/alexnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/alexnet.py
@@ -21,27 +21,27 @@ class AlexNet(HybridBlock):
         with self.name_scope():
             self.features = nn.HybridSequential(prefix='')
             with self.features.name_scope():
-                self.features.add(nn.Conv2D(64, kernel_size=11, strides=4, padding=2))
-                self.features.add(nn.Activation('relu'))
+                self.features.add(nn.Conv2D(64, kernel_size=11, strides=4,
+                                            padding=2, activation='relu'))
                 self.features.add(nn.MaxPool2D(pool_size=3, strides=2))
-                self.features.add(nn.Conv2D(192, kernel_size=5, padding=2))
-                self.features.add(nn.Activation('relu'))
+                self.features.add(nn.Conv2D(192, kernel_size=5, padding=2,
+                                            activation='relu'))
                 self.features.add(nn.MaxPool2D(pool_size=3, strides=2))
-                self.features.add(nn.Conv2D(384, kernel_size=3, padding=1))
-                self.features.add(nn.Activation('relu'))
-                self.features.add(nn.Conv2D(256, kernel_size=3, padding=1))
-                self.features.add(nn.Activation('relu'))
-                self.features.add(nn.Conv2D(256, kernel_size=3, padding=1))
-                self.features.add(nn.Activation('relu'))
+                self.features.add(nn.Conv2D(384, kernel_size=3, padding=1,
+                                            activation='relu'))
+                self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,
+                                            activation='relu'))
+                self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,
+                                            activation='relu'))
                 self.features.add(nn.MaxPool2D(pool_size=3, strides=2))
                 self.features.add(nn.Flatten())
 
             self.classifier = nn.HybridSequential(prefix='')
             with self.classifier.name_scope():
-                self.classifier.add(nn.Dropout(0.5))
                 self.classifier.add(nn.Dense(4096, activation='relu'))
                 self.classifier.add(nn.Dropout(0.5))
                 self.classifier.add(nn.Dense(4096, activation='relu'))
+                self.classifier.add(nn.Dropout(0.5))
                 self.classifier.add(nn.Dense(classes))
 
     def hybrid_forward(self, F, x):
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 6a2000e..063deb4 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -132,15 +132,17 @@ class Dense(HybridBlock):
             else:
                 self.bias = None
             if activation is not None:
-                self.act = Activation(activation)
+                self.act = Activation(activation, prefix=activation+'_')
             else:
                 self.act = None
 
     def hybrid_forward(self, F, x, weight, bias=None):
         if bias is None:
-            act = F.FullyConnected(x, weight, no_bias=True, num_hidden=self._units)
+            act = F.FullyConnected(x, weight, no_bias=True, num_hidden=self._units,
+                                   name='fwd')
         else:
-            act = F.FullyConnected(x, weight, bias, num_hidden=self._units)
+            act = F.FullyConnected(x, weight, bias, num_hidden=self._units,
+                                   name='fwd')
         if self.act is not None:
             act = self.act(act)
         return act
@@ -177,7 +179,7 @@ class Activation(HybridBlock):
         return self._act_type
 
     def hybrid_forward(self, F, x):
-        return F.Activation(x, act_type=self._act_type)
+        return F.Activation(x, act_type=self._act_type, name='fwd')
 
     def __repr__(self):
         s = '{name}({_act_type})'
@@ -213,7 +215,7 @@ class Dropout(HybridBlock):
         self._rate = rate
 
     def hybrid_forward(self, F, x):
-        return F.Dropout(x, p=self._rate)
+        return F.Dropout(x, p=self._rate, name='fwd')
 
     def __repr__(self):
         s = '{name}(p = {_rate})'
@@ -271,7 +273,7 @@ class BatchNorm(HybridBlock):
                  in_channels=0, **kwargs):
         super(BatchNorm, self).__init__(**kwargs)
         self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
-                        'fix_gamma': not center}
+                        'fix_gamma': not scale}
         if in_channels != 0:
             self.in_channels = in_channels
 
@@ -291,7 +293,8 @@ class BatchNorm(HybridBlock):
                                            allow_deferred_init=True)
 
     def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
-        return F.BatchNorm(x, gamma, beta, running_mean, running_var, **self._kwargs)
+        return F.BatchNorm(x, gamma, beta, running_mean, running_var,
+                           name='fwd', **self._kwargs)
 
     def __repr__(self):
         s = '{name}({content}'
@@ -328,7 +331,7 @@ class LeakyReLU(HybridBlock):
         self._alpha = alpha
 
     def hybrid_forward(self, F, x):
-        return F.LeakyReLU(x, act_type='leaky', slope=self._alpha)
+        return F.LeakyReLU(x, act_type='leaky', slope=self._alpha, name='fwd')
 
     def __repr__(self):
         s = '{name}({alpha})'
@@ -369,11 +372,11 @@ class Embedding(HybridBlock):
                                       allow_deferred_init=True)
 
     def hybrid_forward(self, F, x, weight):
-        return F.Embedding(x, weight, **self._kwargs)
+        return F.Embedding(x, weight, name='fwd', **self._kwargs)
 
     def __repr__(self):
-        s = '{name}({input_dim} -> {output_dim}, {dtype})'
-        return s.format(name=self.__class__.__name__,
+        s = '{block_name}({input_dim} -> {output_dim}, {dtype})'
+        return s.format(block_name=self.__class__.__name__,
                         **self._kwargs)
 
 
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index bb2ffea..d9608a1 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -98,15 +98,15 @@ class _Conv(HybridBlock):
                 self.bias = None
 
             if activation is not None:
-                self.act = Activation(activation)
+                self.act = Activation(activation, prefix=activation+'_')
             else:
                 self.act = None
 
     def hybrid_forward(self, F, x, weight, bias=None):
         if bias is None:
-            act = getattr(F, self._op_name)(x, weight, **self._kwargs)
+            act = getattr(F, self._op_name)(x, weight, name='fwd', **self._kwargs)
         else:
-            act = getattr(F, self._op_name)(x, weight, bias, **self._kwargs)
+            act = getattr(F, self._op_name)(x, weight, bias, name='fwd', **self._kwargs)
         if self.act is not None:
             act = self.act(act)
         return act
@@ -644,8 +644,11 @@ class _Pooling(HybridBlock):
             'global_pool': global_pool, 'pool_type': pool_type,
             'pooling_convention': 'full' if ceil_mode else 'valid'}
 
+    def _alias(self):
+        return 'pool'
+
     def hybrid_forward(self, F, x):
-        return F.Pooling(x, **self._kwargs)
+        return F.Pooling(x, name='fwd', **self._kwargs)
 
     def __repr__(self):
         s = '{name}(size={kernel}, stride={stride}, padding={pad}, ceil_mode={ceil_mode})'
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 981b78b..657981c 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -3,6 +3,7 @@
 """Neural network parameter."""
 
 from collections import OrderedDict
+import warnings
 import numpy as np
 
 
@@ -82,59 +83,22 @@ class Parameter(object):
         s = 'Parameter {name} (shape={shape}, dtype={dtype})'
         return s.format(**self.__dict__)
 
-    def initialize(self, init=None, ctx=None, default_init=initializer.Uniform()):
-        """Initializes parameter and gradient arrays. Only used for `NDArray` API.
-
-        Parameters
-        ----------
-        init : Initializer
-            The initializer to use. Overrides `Parameter.init` and default_init.
-        ctx : Context or list of Context, defaults to `context.current_context()`.
-            Initialize Parameter on given context. If ctx is a list of Context, a
-            copy will be made for each context.
-
-            .. note:: Copies are independent arrays. User is responsible for keeping
-            their values consistent when updating. Normally `gluon.Trainer` does this for you.
-        default_init : Initializer
-            Default initializer is used when both `init` and `Parameter.init` are `None`.
-
-        Examples
-        --------
-        >>> weight = mx.gluon.Parameter('weight', shape=(2, 2))
-        >>> weight.initialize(ctx=mx.cpu(0))
-        >>> weight.data()
-        [[-0.01068833  0.01729892]
-         [ 0.02042518 -0.01618656]]
-        <NDArray 2x2 @cpu(0)>
-        >>> weight.grad()
-        [[ 0.  0.]
-         [ 0.  0.]]
-        <NDArray 2x2 @cpu(0)>
-        >>> weight.initialize(ctx=[mx.gpu(0), mx.gpu(1)])
-        >>> weight.data(mx.gpu(0))
-        [[-0.00873779 -0.02834515]
-         [ 0.05484822 -0.06206018]]
-        <NDArray 2x2 @gpu(0)>
-        >>> weight.data(mx.gpu(1))
-        [[-0.00873779 -0.02834515]
-         [ 0.05484822 -0.06206018]]
-        <NDArray 2x2 @gpu(1)>
-        """
-        if ctx is None:
-            ctx = [context.current_context()]
-        if isinstance(ctx, Context):
-            ctx = [ctx]
-        if init is None:
-            init = default_init if self.init is None else self.init
-        if not self.shape or np.prod(self.shape) <= 0:
-            if self.allow_deferred_init:
-                self._defered_init = (init, ctx, default_init)
-                return
-            raise ValueError("Cannot initialize Parameter %s because it has " \
-                             "invalid shape: %s."%(self.name, str(self.shape)))
-
-        self._defered_init = (init, ctx, default_init)
-        self._finish_deferred_init()
+    def _check_initialized(self, ctx=None):
+        if self._data is not None:
+            if ctx is not None and ctx not in self._data:
+                raise RuntimeError(
+                    "Parameter %s was not initialized on context %s. "
+                    "It was only initialized on %s."%(
+                        self.name, str(ctx), str(self.list_ctx())))
+            return
+        if self._defered_init:
+            raise DeferredInitializationError
+        raise RuntimeError(
+            "Parameter %s has not been initialized. Note that " \
+            "you should initialize parameters and create Trainer " \
+            "with Block.collect_params() instead of Block.params " \
+            "because the later does not include Parameters of " \
+            "nested child Blocks"%(self.name))
 
     def _load_init(self, data, ctx):
         """(Re)initializes by loading from data."""
@@ -202,6 +166,98 @@ class Parameter(object):
 
         autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req)
 
+    def _reduce(self):
+        """Reduce data from multiple context."""
+        block = self.list_data()
+        data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
+        return data
+
+    def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
+                   force_reinit=False):
+        """Initializes parameter and gradient arrays. Only used for `NDArray` API.
+
+        Parameters
+        ----------
+        init : Initializer
+            The initializer to use. Overrides `Parameter.init` and default_init.
+        ctx : Context or list of Context, defaults to `context.current_context()`.
+            Initialize Parameter on given context. If ctx is a list of Context, a
+            copy will be made for each context.
+
+            .. note:: Copies are independent arrays. User is responsible for keeping
+            their values consistent when updating. Normally `gluon.Trainer` does this for you.
+        default_init : Initializer
+            Default initializer is used when both `init` and `Parameter.init` are `None`.
+        force_reinit : bool, default False
+            Whether to force re-initialization if parameter is already initialized.
+
+        Examples
+        --------
+        >>> weight = mx.gluon.Parameter('weight', shape=(2, 2))
+        >>> weight.initialize(ctx=mx.cpu(0))
+        >>> weight.data()
+        [[-0.01068833  0.01729892]
+         [ 0.02042518 -0.01618656]]
+        <NDArray 2x2 @cpu(0)>
+        >>> weight.grad()
+        [[ 0.  0.]
+         [ 0.  0.]]
+        <NDArray 2x2 @cpu(0)>
+        >>> weight.initialize(ctx=[mx.gpu(0), mx.gpu(1)])
+        >>> weight.data(mx.gpu(0))
+        [[-0.00873779 -0.02834515]
+         [ 0.05484822 -0.06206018]]
+        <NDArray 2x2 @gpu(0)>
+        >>> weight.data(mx.gpu(1))
+        [[-0.00873779 -0.02834515]
+         [ 0.05484822 -0.06206018]]
+        <NDArray 2x2 @gpu(1)>
+        """
+        if self._data is not None and not force_reinit:
+            warnings.warn("Parameter %s is already initialized, ignoring. " \
+                          "Set force_reinit=True to re-initialize."%self.name)
+            return
+        self._data = self._grad = None
+
+        if ctx is None:
+            ctx = [context.current_context()]
+        if isinstance(ctx, Context):
+            ctx = [ctx]
+        if init is None:
+            init = default_init if self.init is None else self.init
+        if not self.shape or np.prod(self.shape) <= 0:
+            if self.allow_deferred_init:
+                self._defered_init = (init, ctx, default_init)
+                return
+            raise ValueError("Cannot initialize Parameter %s because it has " \
+                             "invalid shape: %s."%(self.name, str(self.shape)))
+
+        self._defered_init = (init, ctx, default_init)
+        self._finish_deferred_init()
+
+    def reset_ctx(self, ctx):
+        """Re-assign Parameter to other contexts.
+
+        ctx : Context or list of Context, default `context.current_context()`.
+            Assign Parameter to given context. If ctx is a list of Context, a
+            copy will be made for each context.
+        """
+        if ctx is None:
+            ctx = [context.current_context()]
+        if isinstance(ctx, Context):
+            ctx = [ctx]
+        if self._data:
+            data = self._reduce()
+            with autograd.pause():
+                self._init_impl(data, ctx)
+        elif self._defered_init:
+            init, _, default_init = self._defered_init
+            self._defered_init = (init, ctx, default_init)
+        else:
+            raise ValueError("Cannot reset context for Parameter %s because it "
+                             "has not been initialized."%self.name)
+
+
     def set_data(self, data):
         """Sets this parameter's value on all contexts to data."""
         assert self._data is not None, \
@@ -209,23 +265,6 @@ class Parameter(object):
         for arr in self.list_data():
             arr[:] = data
 
-    def _check_initialized(self, ctx=None):
-        if self._data is not None:
-            if ctx is not None and ctx not in self._data:
-                raise RuntimeError(
-                    "Parameter %s was not initialized on context %s. "
-                    "It was only initialized on %s."%(
-                        self.name, str(ctx), str(self.list_ctx())))
-            return
-        if self._defered_init:
-            raise DeferredInitializationError
-        raise RuntimeError(
-            "Parameter %s has not been initialized. Note that " \
-            "you should initialize parameters and create Trainer " \
-            "with Block.collect_params() instead of Block.params " \
-            "because the later does not include Parameters of " \
-            "nested child Blocks"%(self.name))
-
     def data(self, ctx=None):
         """Returns a copy of this parameter on one context. Must have been
         initialized on this context before.
@@ -404,7 +443,8 @@ class ParameterDict(object):
             else:
                 self._params[k] = v
 
-    def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False):
+    def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
+                   force_reinit=False):
         """Initializes all Parameters managed by this dictionary to be used for `NDArray`
         API. It has no effect when using `Symbol` API.
 
@@ -415,17 +455,29 @@ class ParameterDict(object):
             Otherwise, `Parameter.init` takes precedence.
         ctx : Context or list of Context
             Keeps a copy of Parameters on one or many context(s).
+        force_reinit : bool, default False
+            Whether to force re-initialization if parameter is already initialized.
         """
         if verbose:
             init.set_verbosity(verbose=verbose)
         for _, v in self.items():
-            v.initialize(None, ctx, init)
+            v.initialize(None, ctx, init, force_reinit=force_reinit)
 
     def zero_grad(self):
         """Sets all Parameters' gradient buffer to 0."""
         for i in self.values():
             i.zero_grad()
 
+    def reset_ctx(self, ctx):
+        """Re-assign all Parameters to other contexts.
+
+        ctx : Context or list of Context, default `context.current_context()`.
+            Assign Parameter to given context. If ctx is a list of Context, a
+            copy will be made for each context.
+        """
+        for i in self.values():
+            i.reset_ctx(ctx)
+
     def save(self, filename, strip_prefix=''):
         """Save parameters to file.
 
@@ -436,8 +488,7 @@ class ParameterDict(object):
         """
         arg_dict = {}
         for param in self.values():
-            block = param.list_data()
-            weight = sum(w.copyto(context.cpu()) for w in block) / len(block)
+            weight = param._reduce()
             if not param.name.startswith(strip_prefix):
                 raise ValueError(
                     "Prefix %s is to be striped before saving, but Parameter " \
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index e06599c..7315a27 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -108,10 +108,6 @@ class RecurrentCell(Block):
         """shape and layout information of states"""
         raise NotImplementedError()
 
-    @property
-    def _curr_prefix(self):
-        return '%st%d_'%(self.prefix, self._counter)
-
     def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
         """Initial state for this cell.
 
@@ -313,15 +309,15 @@ class RNNCell(HybridRecurrentCell):
 
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
-        name = self._curr_prefix
+        prefix = 't%d_'%self._counter
         i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
                                num_hidden=self._hidden_size,
-                               name='%si2h'%name)
+                               name=prefix+'i2h')
         h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                                num_hidden=self._hidden_size,
-                               name='%sh2h'%name)
+                               name=prefix+'h2h')
         output = self._get_activation(F, i2h + h2h, self._activation,
-                                      name='%sout'%name)
+                                      name=prefix+'out')
 
         return output, [output]
 
@@ -382,28 +378,21 @@ class LSTMCell(HybridRecurrentCell):
 
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
-        name = self._curr_prefix
+        prefix = 't%d_'%self._counter
         i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
-                               num_hidden=self._hidden_size*4,
-                               name='%si2h'%name)
+                               num_hidden=self._hidden_size*4, name=prefix+'i2h')
         h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
-                               num_hidden=self._hidden_size*4,
-                               name='%sh2h'%name)
+                               num_hidden=self._hidden_size*4, name=prefix+'h2h')
         gates = i2h + h2h
-        slice_gates = F.SliceChannel(gates, num_outputs=4,
-                                     name="%sslice"%name)
-        in_gate = F.Activation(slice_gates[0], act_type="sigmoid",
-                               name='%si'%name)
-        forget_gate = F.Activation(slice_gates[1], act_type="sigmoid",
-                                   name='%sf'%name)
-        in_transform = F.Activation(slice_gates[2], act_type="tanh",
-                                    name='%sc'%name)
-        out_gate = F.Activation(slice_gates[3], act_type="sigmoid",
-                                name='%so'%name)
+        slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
+        in_gate = F.Activation(slice_gates[0], act_type="sigmoid", name=prefix+'i')
+        forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f')
+        in_transform = F.Activation(slice_gates[2], act_type="tanh", name=prefix+'c')
+        out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o')
         next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
-                                   name='%sstate'%name)
+                                   name=prefix+'state')
         next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type="tanh"),
-                                  name='%sout'%name)
+                                  name=prefix+'out')
 
         return next_h, [next_h, next_c]
 
@@ -463,32 +452,34 @@ class GRUCell(HybridRecurrentCell):
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
         # pylint: disable=too-many-locals
-        name = self._curr_prefix
+        prefix = 't%d_'%self._counter
         prev_state_h = states[0]
         i2h = F.FullyConnected(data=inputs,
                                weight=i2h_weight,
                                bias=i2h_bias,
                                num_hidden=self._hidden_size * 3,
-                               name="%si2h" % name)
+                               name=prefix+'i2h')
         h2h = F.FullyConnected(data=prev_state_h,
                                weight=h2h_weight,
                                bias=h2h_bias,
                                num_hidden=self._hidden_size * 3,
-                               name="%sh2h" % name)
+                               name=prefix+'h2h')
 
-        i2h_r, i2h_z, i2h = F.SliceChannel(i2h, num_outputs=3, name="%si2h_slice" % name)
-        h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3, name="%sh2h_slice" % name)
+        i2h_r, i2h_z, i2h = F.SliceChannel(i2h, num_outputs=3,
+                                           name=prefix+'i2h_slice')
+        h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3,
+                                           name=prefix+'h2h_slice')
 
         reset_gate = F.Activation(i2h_r + h2h_r, act_type="sigmoid",
-                                  name="%sr_act" % name)
+                                  name=prefix+'r_act')
         update_gate = F.Activation(i2h_z + h2h_z, act_type="sigmoid",
-                                   name="%sz_act" % name)
+                                   name=prefix+'z_act')
 
         next_h_tmp = F.Activation(i2h + reset_gate * h2h, act_type="tanh",
-                                  name="%sh_act" % name)
+                                  name=prefix+'h_act')
 
         next_h = F._internal._plus((1. - update_gate) * next_h_tmp, update_gate * prev_state_h,
-                                   name='%sout' % name)
+                                   name=prefix+'out')
 
         return next_h, [next_h]
 
@@ -563,17 +554,17 @@ class DropoutCell(HybridRecurrentCell):
 
     Parameters
     ----------
-    dropout : float
+    rate : float
         Percentage of elements to drop out, which
         is 1 - percentage to retain.
     """
-    def __init__(self, dropout, prefix=None, params=None):
+    def __init__(self, rate, prefix=None, params=None):
         super(DropoutCell, self).__init__(prefix, params)
-        assert isinstance(dropout, numeric_types), "dropout probability must be a number"
-        self.dropout = dropout
+        assert isinstance(rate, numeric_types), "rate must be a number"
+        self.rate = rate
 
     def __repr__(self):
-        s = '{name}(p = {dropout})'
+        s = '{name}(rate = {rate})'
         return s.format(name=self.__class__.__name__,
                         **self.__dict__)
 
@@ -584,8 +575,8 @@ class DropoutCell(HybridRecurrentCell):
         return 'dropout'
 
     def hybrid_forward(self, F, inputs, states):
-        if self.dropout > 0:
-            inputs = F.Dropout(data=inputs, p=self.dropout)
+        if self.rate > 0:
+            inputs = F.Dropout(data=inputs, p=self.rate, name='t%d_fwd'%self._counter)
         return inputs, states
 
     def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
@@ -610,13 +601,15 @@ class ModifierCell(HybridRecurrentCell):
     should be used instead.
     """
     def __init__(self, base_cell):
-        super(ModifierCell, self).__init__(prefix=None, params=None)
+        assert not base_cell._modified, \
+            "Cell %s is already modified. One cell cannot be modified twice"%base_cell.name
         base_cell._modified = True
+        super(ModifierCell, self).__init__(prefix=base_cell.prefix+self._alias(),
+                                           params=None)
         self.base_cell = base_cell
 
     @property
     def params(self):
-        self._own_params = False
         return self.base_cell.params
 
     def state_info(self, batch_size=0):
@@ -697,7 +690,7 @@ class ResidualCell(ModifierCell):
 
     def hybrid_forward(self, F, inputs, states):
         output, states = self.base_cell(inputs, states)
-        output = F.elemwise_add(output, inputs, name="%s_plus_residual" % output.name)
+        output = F.elemwise_add(output, inputs, name='t%d_fwd'%self._counter)
         return output, states
 
     def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py
index 56ab27b..f467f9c 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol.py
@@ -3,6 +3,10 @@
 # pylint: disable=import-error, no-name-in-module
 """Symbolic configuration API of MXNet."""
 from __future__ import absolute_import as _abs
+try:
+    from __builtin__ import slice as py_slice
+except ImportError:
+    from builtins import slice as py_slice
 
 import ctypes
 import warnings
@@ -484,9 +488,16 @@ class Symbol(SymbolBase):
             Indexing key
 
         """
+        output_names = self.list_outputs()
+        if isinstance(index, py_slice):
+            start = 0 if index.start is None else index.start
+            stop = len(output_names) if index.stop is None else index.stop
+            step = 1 if index.step is None else index.step
+            return Group([self[i] for i in range(start, stop, step)])
+
         if isinstance(index, string_types):
             idx = None
-            for i, name in enumerate(self.list_outputs()):
+            for i, name in enumerate(output_names):
                 if name == index:
                     if idx is not None:
                         raise ValueError('There are multiple outputs with name \"%s\"' % index)
@@ -494,9 +505,10 @@ class Symbol(SymbolBase):
             if idx is None:
                 raise ValueError('Cannot find output that matches name \"%s\"' % index)
             index = idx
+
         if not isinstance(index, int):
             raise TypeError('Symbol only support integer index to fetch i-th output')
-        if index >= (len(self.list_outputs())):
+        if index >= len(output_names):
             # Important, python determines the end by this exception
             raise IndexError
         handle = SymbolHandle()
diff --git a/src/ndarray/autograd.cc b/src/ndarray/autograd.cc
index b606a4d..f990ee2 100644
--- a/src/ndarray/autograd.cc
+++ b/src/ndarray/autograd.cc
@@ -133,7 +133,8 @@ AGNodePtr AutogradRuntime::RecordOp(const nnvm::Op* op,
 
   for (uint32_t i = 0; i < outputs.size(); ++i) {
     CHECK(outputs[i].entry_.is_none())
-      << "Inplace operation is not supported when recording with autograd. "
+      << "Inplace operations (+=, -=, x[:]=, etc) are not supported when "
+      << "recording with autograd. "
       << "Assigning to NDArrays that are already in a computational graph "
       << "will cause undefined behavior when evaluating gradients. "
       << "Please call backward first to clear the graph or do this out side of "
diff --git a/tests/python/unittest/test_nn.py b/tests/python/unittest/test_nn.py
index 5883978..d4514e2 100644
--- a/tests/python/unittest/test_nn.py
+++ b/tests/python/unittest/test_nn.py
@@ -13,6 +13,9 @@ def test_parameter():
     assert p.data(mx.cpu(0)).shape == (10, 10)
     assert p.var().name == 'weight'
 
+    p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
+    assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)]
+
 
 def test_paramdict():
     params = gluon.ParameterDict('net_')
@@ -65,6 +68,27 @@ def test_basic():
     x.wait_to_read()
 
 
+def test_symbol_block():
+    model = nn.HybridSequential()
+    model.add(nn.Dense(128, activation='tanh'))
+    model.add(nn.Dropout(0.5))
+    model.add(nn.Dense(64, activation='tanh'))
+    model.add(nn.Dense(32, in_units=64))
+    model.add(nn.Activation('relu'))
+
+    model.initialize()
+
+    inputs = mx.sym.var('data')
+    outputs = model(inputs).get_internals()
+
+    smodel = gluon.SymbolBlock(outputs, inputs, params=model.collect_params())
+
+    assert len(smodel(mx.nd.zeros((16, 10)))) == 14
+
+    out = smodel(mx.sym.var('in'))
+    assert len(out.get_internals().list_outputs()) == len(outputs.list_outputs())
+
+
 def check_layer_forward(layer, dshape):
     layer.collect_params().initialize()
     with mx.autograd.record():

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.