Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/29 21:13:16 UTC

[GitHub] zhreshold closed pull request #11434: add ignore_reinit to initialize to skip warnings

zhreshold closed pull request #11434: add ignore_reinit to initialize to skip warnings
URL: https://github.com/apache/incubator-mxnet/pull/11434
 
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 0ef28496c20..776592de6d7 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -478,7 +478,7 @@ def apply(self, fn):
         return self
 
     def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
-                   force_reinit=False):
+                   force_reinit=False, ignore_reinit=False):
         """Initializes :py:class:`Parameter` s of this :py:class:`Block` and its children.
         Equivalent to ``block.collect_params().initialize(...)``
 
@@ -493,8 +493,10 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
             Whether to verbosely print out details on initialization.
         force_reinit : bool, default False
             Whether to force re-initialization if parameter is already initialized.
+        ignore_reinit : bool, default False
+            Whether to suppress the re-initialization warning when `force_reinit` is False.
         """
-        self.collect_params().initialize(init, ctx, verbose, force_reinit)
+        self.collect_params().initialize(init, ctx, verbose, force_reinit, ignore_reinit)
 
     def hybridize(self, active=True, **kwargs):
         """Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 0c6aae92135..4edd0377d5b 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -323,7 +323,7 @@ def _reduce(self):
         return data
 
     def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
-                   force_reinit=False):
+                   force_reinit=False, ignore_reinit=False):
         """Initializes parameter and gradient arrays. Only used for :py:class:`NDArray` API.
 
         Parameters
@@ -344,6 +344,8 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
             and :py:meth:`Parameter.init` are ``None``.
         force_reinit : bool, default False
             Whether to force re-initialization if parameter is already initialized.
+        ignore_reinit : bool, default False
+            Whether to suppress the re-initialization warning when `force_reinit` is False.
 
         Examples
         --------
@@ -368,9 +370,10 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
         <NDArray 2x2 @gpu(1)>
         """
         if self._data is not None and not force_reinit:
-            warnings.warn("Parameter '%s' is already initialized, ignoring. " \
-                          "Set force_reinit=True to re-initialize."%self.name,
-                          stacklevel=2)
+            if not ignore_reinit:
+                warnings.warn("Parameter '%s' is already initialized, ignoring. " \
+                              "Set force_reinit=True to re-initialize."%self.name,
+                              stacklevel=2)
             return
         self._data = self._grad = None
 
@@ -789,7 +792,7 @@ def update(self, other):
             self._params[k] = v
 
     def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
-                   force_reinit=False):
+                   force_reinit=False, ignore_reinit=False):
         """Initializes all Parameters managed by this dictionary to be used for :py:class:`NDArray`
         API. It has no effect when using :py:class:`Symbol` API.
 
@@ -804,11 +807,13 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
             Whether to verbosely print out details on initialization.
         force_reinit : bool, default False
             Whether to force re-initialization if parameter is already initialized.
+        ignore_reinit : bool, default False
+            Whether to suppress the re-initialization warning when `force_reinit` is False.
         """
         if verbose:
             init.set_verbosity(verbose=verbose)
         for _, v in self.items():
-            v.initialize(None, ctx, init, force_reinit=force_reinit)
+            v.initialize(None, ctx, init, force_reinit=force_reinit, ignore_reinit=ignore_reinit)
 
     def zero_grad(self):
         """Sets all Parameters' gradient buffer to 0."""
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index cd3cc685bdd..1f844691eaf 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1360,6 +1360,29 @@ def test_hybrid_static_memory_recording():
         net(x)
     net(x)
 
+@with_seed()
+def test_ignore_force_reinit():
+    net = nn.HybridSequential()
+    net.add(nn.Dense(4))
+    net.add(nn.Dense(2))
+    net.add(nn.Dense(1))
+    net.collect_params().initialize()
+    net(mx.nd.zeros((4,)))
+
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        net.initialize(force_reinit=True)
+        assert len(w) == 0
+
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        net.initialize(force_reinit=False, ignore_reinit=False)
+        assert len(w) == (3 * 2)  # weight and bias, 3 layers
+
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        net.initialize(force_reinit=False, ignore_reinit=True)
+        assert len(w) == 0
 
 if __name__ == '__main__':
     import nose

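For context, a minimal sketch of how the proposed flag would be used.
This is illustrative only: the network and layer sizes are made up, and
it assumes the Gluon API of this era with this PR's patch applied (the
`ignore_reinit` parameter does not exist without it):

import warnings
import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
with net.name_scope():
    net.add(nn.Dense(4))
net.initialize()           # first initialization: no warning
net(mx.nd.zeros((2, 8)))   # forward pass materializes the deferred parameters

# Default behavior: re-initializing without force_reinit is a no-op,
# but emits one warning per already-initialized parameter.
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    net.initialize()
    assert len(w) == 2     # weight and bias of the single Dense layer

# With the proposed flag: still a no-op, but silent.
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    net.initialize(ignore_reinit=True)
    assert len(w) == 0

Note that the flag only suppresses the warning; it does not change what
initialize() does. Actually re-initializing still requires
force_reinit=True, which is why the two options are separate parameters.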

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services