You are viewing a plain-text version of this content. (The canonical link to it was a hyperlink on the word "here", which is not preserved in this text version.)
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/08 20:46:48 UTC

[GitHub] sxjscience closed pull request #10032: add axes support for dropouts in gluon

sxjscience closed pull request #10032: add axes support for dropouts in gluon
URL: https://github.com/apache/incubator-mxnet/pull/10032
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, the diff is reproduced below;
GitHub would not otherwise display it:

diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index d6402b769cb..b964c712ace 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -180,16 +180,12 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
         states = _get_begin_state(self, F, begin_state, inputs, batch_size)
 
         if self.drop_inputs:
-            first_input = inputs.slice_axis(axis, 0, 1).split(1, axis=axis, squeeze_axis=True)
-            self._initialize_input_masks(F, first_input, states)
-            inputs = F.broadcast_mul(inputs, self.drop_inputs_mask.expand_dims(axis=axis))
+            inputs = F.Dropout(inputs, p=self.drop_inputs, axes=(axis,))
 
         outputs, states = self.base_cell.unroll(length, inputs, states, layout, merge_outputs=True,
                                                 valid_length=valid_length)
         if self.drop_outputs:
-            first_output = outputs.slice_axis(axis, 0, 1).split(1, axis=axis, squeeze_axis=True)
-            self._initialize_output_mask(F, first_output)
-            outputs = F.broadcast_mul(outputs, self.drop_outputs_mask.expand_dims(axis=axis))
+            outputs = F.Dropout(outputs, p=self.drop_outputs, axes=(axis,))
         merge_outputs = isinstance(outputs, tensor_types) if merge_outputs is None else \
             merge_outputs
         outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs)
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index b61540dd61b..9dc1a240681 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -226,6 +226,8 @@ class Dropout(HybridBlock):
     ----------
     rate : float
         Fraction of the input units to drop. Must be a number between 0 and 1.
+    axes : tuple of int, default ()
+        The axes on which dropout mask is shared. If empty, regular dropout is applied.
 
 
     Inputs:
@@ -239,15 +241,16 @@ class Dropout(HybridBlock):
         `Dropout: A Simple Way to Prevent Neural Networks from Overfitting
         <http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_
     """
-    def __init__(self, rate, **kwargs):
+    def __init__(self, rate, axes=(), **kwargs):
         super(Dropout, self).__init__(**kwargs)
         self._rate = rate
+        self._axes = axes
 
     def hybrid_forward(self, F, x):
-        return F.Dropout(x, p=self._rate, name='fwd')
+        return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd')
 
     def __repr__(self):
-        s = '{name}(p = {_rate})'
+        s = '{name}(p = {_rate}, axes={_axes})'
         return s.format(name=self.__class__.__name__,
                         **self.__dict__)
 
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 61bf24e8cd1..f5c72f5f3e7 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -713,6 +713,8 @@ class DropoutCell(HybridRecurrentCell):
     rate : float
         Percentage of elements to drop out, which
         is 1 - percentage to retain.
+    axes : tuple of int, default ()
+        The axes on which dropout mask is shared. If empty, regular dropout is applied.
 
 
     Inputs:
@@ -723,13 +725,14 @@ class DropoutCell(HybridRecurrentCell):
         - **out**: output tensor with shape `(batch_size, size)`.
         - **next_states**: returns input `states` directly.
     """
-    def __init__(self, rate, prefix=None, params=None):
+    def __init__(self, rate, axes=(), prefix=None, params=None):
         super(DropoutCell, self).__init__(prefix, params)
         assert isinstance(rate, numeric_types), "rate must be a number"
-        self.rate = rate
+        self._rate = rate
+        self._axes = axes
 
     def __repr__(self):
-        s = '{name}(rate = {rate})'
+        s = '{name}(rate={_rate}, axes={_axes})'
         return s.format(name=self.__class__.__name__,
                         **self.__dict__)
 
@@ -740,8 +743,9 @@ def _alias(self):
         return 'dropout'
 
     def hybrid_forward(self, F, inputs, states):
-        if self.rate > 0:
-            inputs = F.Dropout(data=inputs, p=self.rate, name='t%d_fwd'%self._counter)
+        if self._rate > 0:
+            inputs = F.Dropout(data=inputs, p=self._rate, axes=self._axes,
+                               name='t%d_fwd'%self._counter)
         return inputs, states
 
     def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 89f52154370..889d210da34 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -827,6 +827,46 @@ def selu(x):
     x = point_to_validate.reshape((1, 3, 2))
     assert_almost_equal(prelu(x).asnumpy(), mx.nd.where(x >= 0, x, 0.25 * x).asnumpy())
 
+@with_seed()
+def test_dropout():
+    def get_slice(x, axis, idx):
+        ix = ()
+        for i in range(x.ndim):
+            if i == axis:
+                ix += (idx,)
+            else:
+                ix += (slice(None, None, None),)
+        return x[ix]
+
+    def check_dropout_axes(ratio, shape, axes):
+        compactshape = list(shape)
+        for axis in axes:
+            compactshape[axis] = 1
+        compactx = mx.random.uniform(shape=tuple(compactshape))
+        broadcastx = compactx.broadcast_to(shape)
+        dropouty = mx.gluon.nn.Dropout(rate=ratio, axes=axes)(broadcastx)
+        for axis in axes:
+            target = get_slice(dropouty, axis, 0).asnumpy()
+            for i in range(1, shape[axis]):
+                assert(get_slice(dropouty, axis, i).asnumpy() == target).all()
+
+    nshape = (10, 10, 10, 10)
+    with mx.autograd.train_mode():
+        check_dropout_axes(0.25, nshape, axes = (0,))
+        check_dropout_axes(0.25, nshape, axes = (1,))
+        check_dropout_axes(0.25, nshape, axes = (2,))
+        check_dropout_axes(0.25, nshape, axes = (3,))
+        check_dropout_axes(0.25, nshape, axes = (0, 1))
+        check_dropout_axes(0.25, nshape, axes = (0, 2))
+        check_dropout_axes(0.25, nshape, axes = (0, 3))
+        check_dropout_axes(0.25, nshape, axes = (1, 2))
+        check_dropout_axes(0.25, nshape, axes = (1, 3))
+        check_dropout_axes(0.25, nshape, axes = (2, 3))
+        check_dropout_axes(0.25, nshape, axes = (0, 1, 2))
+        check_dropout_axes(0.25, nshape, axes = (0, 2, 3))
+        check_dropout_axes(0.25, nshape, axes = (1, 2, 3))
+
+
 
 
 if __name__ == '__main__':
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index 03e4261ad16..29850dce6ae 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -120,11 +120,8 @@ def check_vardrop(drop_inputs, drop_states, drop_outputs):
         input_data = mx.nd.random_uniform(shape=(10, 3, 50), ctx=mx.context.current_context())
         with mx.autograd.record():
             outputs1, _ = cell.unroll(3, input_data, merge_outputs=True)
-            mask1 = cell.drop_outputs_mask.asnumpy()
             mx.nd.waitall()
             outputs2, _ = cell.unroll(3, input_data, merge_outputs=True)
-            mask2 = cell.drop_outputs_mask.asnumpy()
-        assert not almost_equal(mask1, mask2)
         assert not almost_equal(outputs1.asnumpy(), outputs2.asnumpy())
 
         inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 91b8faa49c1..2208a33e801 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -268,7 +268,7 @@ def check_regression(symbol, forward, backward, shape, stype='default', densitie
                      lambda x: x,
                      lambda x, y : x - y,
                      shape, stype='csr')
-   
+
 
 def check_softmax_grad(xpu):
     x = mx.sym.Variable('x')
@@ -4674,19 +4674,20 @@ def check_dropout_axes(ratio, shape, axes):
     check_dropout_ratio(0.25, shape)
 
     nshape = (10, 10, 10, 10)
-    check_dropout_axes(0.25, nshape, axes = (0,))
-    check_dropout_axes(0.25, nshape, axes = (1,))
-    check_dropout_axes(0.25, nshape, axes = (2,))
-    check_dropout_axes(0.25, nshape, axes = (3,))
-    check_dropout_axes(0.25, nshape, axes = (0, 1))
-    check_dropout_axes(0.25, nshape, axes = (0, 2))
-    check_dropout_axes(0.25, nshape, axes = (0, 3))
-    check_dropout_axes(0.25, nshape, axes = (1, 2))
-    check_dropout_axes(0.25, nshape, axes = (1, 3))
-    check_dropout_axes(0.25, nshape, axes = (2, 3))
-    check_dropout_axes(0.25, nshape, axes = (0, 1, 2))
-    check_dropout_axes(0.25, nshape, axes = (0, 2, 3))
-    check_dropout_axes(0.25, nshape, axes = (1, 2, 3))
+    with mx.autograd.train_mode():
+        check_dropout_axes(0.25, nshape, axes = (0,))
+        check_dropout_axes(0.25, nshape, axes = (1,))
+        check_dropout_axes(0.25, nshape, axes = (2,))
+        check_dropout_axes(0.25, nshape, axes = (3,))
+        check_dropout_axes(0.25, nshape, axes = (0, 1))
+        check_dropout_axes(0.25, nshape, axes = (0, 2))
+        check_dropout_axes(0.25, nshape, axes = (0, 3))
+        check_dropout_axes(0.25, nshape, axes = (1, 2))
+        check_dropout_axes(0.25, nshape, axes = (1, 3))
+        check_dropout_axes(0.25, nshape, axes = (2, 3))
+        check_dropout_axes(0.25, nshape, axes = (0, 1, 2))
+        check_dropout_axes(0.25, nshape, axes = (0, 2, 3))
+        check_dropout_axes(0.25, nshape, axes = (1, 2, 3))
 
 
 @with_seed()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services