You are viewing a plain-text version of this content; the link to its canonical version is not preserved here.
Posted to commits@mxnet.apache.org by gi...@git.apache.org on 2017/07/29 19:07:39 UTC

[GitHub] szha commented on a change in pull request #7221: add reset_ctx

szha commented on a change in pull request #7221: add reset_ctx
URL: https://github.com/apache/incubator-mxnet/pull/7221#discussion_r129973890
 
 

 ##########
 File path: python/mxnet/gluon/parameter.py
 ##########
 @@ -202,30 +165,96 @@ def _init_impl(self, data, ctx):
 
         autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req)
 
+    def _reduce(self):
+        """Reduce data from multiple context."""
+        block = self.list_data()
+        data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
+        return data
+
+    def initialize(self, init=None, ctx=None, default_init=initializer.Uniform()):
+        """Initializes parameter and gradient arrays. Only used for `NDArray` API.
+
+        Parameters
+        ----------
+        init : Initializer
+            The initializer to use. Overrides `Parameter.init` and default_init.
+        ctx : Context or list of Context, defaults to `context.current_context()`.
+            Initialize Parameter on given context. If ctx is a list of Context, a
+            copy will be made for each context.
+
+            .. note:: Copies are independent arrays. User is responsible for keeping
+            their values consistent when updating. Normally `gluon.Trainer` does this for you.
+        default_init : Initializer
+            Default initializer is used when both `init` and `Parameter.init` are `None`.
+
+        Examples
+        --------
+        >>> weight = mx.gluon.Parameter('weight', shape=(2, 2))
+        >>> weight.initialize(ctx=mx.cpu(0))
+        >>> weight.data()
+        [[-0.01068833  0.01729892]
+         [ 0.02042518 -0.01618656]]
+        <NDArray 2x2 @cpu(0)>
+        >>> weight.grad()
+        [[ 0.  0.]
+         [ 0.  0.]]
+        <NDArray 2x2 @cpu(0)>
+        >>> weight.initialize(ctx=[mx.gpu(0), mx.gpu(1)])
+        >>> weight.data(mx.gpu(0))
+        [[-0.00873779 -0.02834515]
+         [ 0.05484822 -0.06206018]]
+        <NDArray 2x2 @gpu(0)>
+        >>> weight.data(mx.gpu(1))
+        [[-0.00873779 -0.02834515]
+         [ 0.05484822 -0.06206018]]
+        <NDArray 2x2 @gpu(1)>
+        """
+        if ctx is None:
+            ctx = [context.current_context()]
+        if isinstance(ctx, Context):
+            ctx = [ctx]
+        if init is None:
+            init = default_init if self.init is None else self.init
+        if not self.shape or np.prod(self.shape) <= 0:
+            if self.allow_deferred_init:
+                self._defered_init = (init, ctx, default_init)
+                return
+            raise ValueError("Cannot initialize Parameter %s because it has " \
+                             "invalid shape: %s."%(self.name, str(self.shape)))
+
+        self._defered_init = (init, ctx, default_init)
+        self._finish_deferred_init()
+
+    def reset_ctx(self, ctx):
+        """Re-assign Parameter to other contexts.
+
+        ctx : Context or list of Context, default `context.current_context()`.
+            Assign Parameter to given context. If ctx is a list of Context, a
+            copy will be made for each context.
+        """
+        if ctx is None:
+            ctx = [context.current_context()]
+        if isinstance(ctx, Context):
+            ctx = [ctx]
+        if self._data:
+            data = self._reduce()
+            with autograd.pause():
+                self._init_impl(data, ctx)
+        elif self._defered_init:
+            init, _, default_init = self._defered_init
+            self._defered_init = (init, ctx, default_init)
+        else:
+            raise ValueError("Cannot reset context for Parameter %s because it "
+                             "has not been initialized."%self.name)
+
+
     def set_data(self, data):
         """Overwrite this parameter's value on every context with `data`.
 
         Each array returned by `list_data()` is assigned in place via slice
         assignment. An assertion fails if `_data` is None.
         NOTE(review): a parameter still awaiting deferred initialization
         presumably also has `_data` None and is rejected here; confirm intent.
         """
         assert self._data is not None, "Parameter %s has not been initialized" % self.name
         replicas = self.list_data()
         for replica in replicas:
             replica[:] = data
 
 Review comment:
   and None data?
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services