Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/03 17:34:59 UTC

[incubator-mxnet] branch master updated: fill parameter shape (#8528)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c02c6b1  fill parameter shape (#8528)
c02c6b1 is described below

commit c02c6b149b40bc9a8db91c95453ff0e96f3edc3c
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Fri Nov 3 10:34:57 2017 -0700

    fill parameter shape (#8528)
---
 python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py |  5 ++--
 python/mxnet/gluon/nn/basic_layers.py           |  8 +++---
 python/mxnet/gluon/nn/conv_layers.py            |  5 ++--
 python/mxnet/gluon/parameter.py                 |  7 +++--
 python/mxnet/gluon/rnn/rnn_cell.py              | 38 ++++++++++++++++++-------
 python/mxnet/gluon/rnn/rnn_layer.py             |  4 +--
 tests/python/unittest/test_gluon.py             | 36 +++++++++++++++++++++++
 tests/python/unittest/test_gluon_contrib.py     |  7 +++++
 tests/python/unittest/test_gluon_rnn.py         | 13 +++++++++
 9 files changed, 99 insertions(+), 24 deletions(-)
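In brief: Gluon parameters declared with an unknown (zero) dimension now have that
dimension filled in from real data, either on the first forward pass
(Parameter.set_data) or when loading saved params (Parameter._load_init), and the
affected __repr__ methods now read the input size off the parameter shape instead of
constructor attributes. A minimal sketch of the user-visible behavior, mirroring the
new test_fill_shape_deferred below:

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.HybridSequential()
    with net.name_scope():
        # in_channels / in_units deliberately left unspecified (stored as 0)
        net.add(nn.Conv2D(64, kernel_size=2, padding=1),
                nn.BatchNorm(),
                nn.Dense(10))
    net.hybridize()
    net.initialize()                 # deferred: nothing is allocated yet
    net(mx.nd.ones((2, 3, 5, 7)))    # first forward pass infers and fills shapes
    assert net[0].weight.shape[1] == 3      # Conv2D in_channels
    assert net[2].weight.shape[1] == 3072   # Dense in_units (64 * 6 * 8)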

diff --git a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
index cbb3f1a..09db547 100644
--- a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
@@ -131,8 +131,9 @@ class _BaseConvRNNCell(HybridRecurrentCell):
         s += ', {_conv_layout}'
         s += ')'
         attrs = self.__dict__
-        mapping = ('{_in_channels} -> {_hidden_channels}'.format(**attrs) if self._in_channels
-                   else self._hidden_channels)
+        shape = self.i2h_weight.shape
+        in_channels = shape[1 if self._channel_axis == 1 else -1]
+        mapping = ('{0} -> {1}'.format(in_channels if in_channels else None, shape[0]))
         return s.format(name=self.__class__.__name__,
                         mapping=mapping,
                         **attrs)
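The conv cell's repr now reads the input-channel count straight off i2h_weight,
indexing by the cell's channel axis, so a cell built with unknown in-channels prints
None until a forward pass fills the shape in. A hedged illustration following the
new test_conv_fill_shape below (the exact repr text may vary):

    import mxnet as mx
    from mxnet.gluon import contrib

    # 0 marks the in-channel dimension as unknown; input shape is (channels, width)
    cell = contrib.rnn.Conv1DLSTMCell((0, 7), 10, (3,), (3,))
    print(cell)  # mapping prints as 'None -> ...' while i2h_weight.shape[1] is 0
    # After unrolling over data of shape (batch, seq_len, channels, width), e.g.
    # mx.nd.ones((8, 3, 5, 7)), i2h_weight.shape[1] becomes 5 and the repr
    # shows '5 -> ...'.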
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index e9fb2ff..906f03e 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -207,10 +207,10 @@ class Dense(HybridBlock):
 
     def __repr__(self):
         s = '{name}({layout}, {act})'
+        shape = self.weight.shape
         return s.format(name=self.__class__.__name__,
                         act=self.act if self.act else 'linear',
-                        layout='{0} -> {1}'.format(self._in_units, self._units) if self._in_units
-                        else self._units)
+                        layout='{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0]))
 
 
 class Activation(HybridBlock):
@@ -360,8 +360,8 @@ class BatchNorm(HybridBlock):
 
     def __repr__(self):
         s = '{name}({content}'
-        if hasattr(self, 'in_channels'):
-            s += ', in_channels={0}'.format(self.in_channels)
+        in_channels = self.gamma.shape[0]
+        s += ', in_channels={0}'.format(in_channels if in_channels else None)
         s += ')'
         return s.format(name=self.__class__.__name__,
                         content=', '.join(['='.join([k, v.__repr__()])
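Because in_units / in_channels may start out as 0 and be filled later, these reprs
now map a zero dimension to None instead of printing only the output size. Roughly
(repr text is approximate):

    import mxnet as mx
    from mxnet.gluon import nn

    dense = nn.Dense(10)
    print(dense)               # roughly: Dense(None -> 10, linear)
    dense.initialize()
    dense(mx.nd.ones((2, 7)))  # forward fills the weight shape to (10, 7)
    print(dense)               # roughly: Dense(7 -> 10, linear)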
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 8dcdbc3..645de98 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -153,10 +153,9 @@ class _Conv(HybridBlock):
         if self.bias is None:
             s += ', bias=False'
         s += ')'
+        shape = self.weight.shape
         return s.format(name=self.__class__.__name__,
-                        mapping=self._channels if not self._in_channels
-                        else '{0} -> {1}'.format(self._in_channels,
-                                                 self._channels),
+                        mapping='{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0]),
                         **self._kwargs)
 
 
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index c73aee2..c42fbaa 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -171,11 +171,12 @@ class Parameter(object):
     def _load_init(self, data, ctx):
         """(Re)initializes by loading from data."""
         if self.shape:
-            for i, j in zip(self.shape, data.shape):
-                assert i == 0 or i == j, \
+            for self_dim, data_dim in zip(self.shape, data.shape):
+                assert self_dim == 0 or self_dim == data_dim, \
                     "Failed loading Parameter %s from saved params: " \
                     "shape incompatible expacted %s vs saved %s"%(
                         self.name, str(self.shape), str(data.shape))
+            self.shape = tuple(i if i != 0 else j for i, j in zip(self.shape, data.shape))
         if self.dtype:
             assert np.dtype(self.dtype).type == data.dtype, \
                 "Failed loading Parameter %s from saved params: " \
@@ -344,6 +345,8 @@ class Parameter(object):
             "Parameter %s has not been initialized"%self.name
         for arr in self.list_data():
             arr[:] = data
+        if not self.shape or np.prod(self.shape) <= 0:
+            self.shape = data.shape
 
     def data(self, ctx=None):
         """Returns a copy of this parameter on one context. Must have been
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 9d318eb..ea0e32f 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -111,17 +111,6 @@ class RecurrentCell(Block):
         self._modified = False
         self.reset()
 
-    def __repr__(self):
-        s = '{name}({mapping}'
-        if hasattr(self, '_activation'):
-            s += ', {_activation}'
-        s += ')'
-        mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size
-                   else self._hidden_size)
-        return s.format(name=self.__class__.__name__,
-                        mapping=mapping,
-                        **self.__dict__)
-
     def reset(self):
         """Reset before re-using the cell for another graph."""
         self._init_counter = -1
@@ -355,6 +344,17 @@ class RNNCell(HybridRecurrentCell):
     def _alias(self):
         return 'rnn'
 
+    def __repr__(self):
+        s = '{name}({mapping}'
+        if hasattr(self, '_activation'):
+            s += ', {_activation}'
+        s += ')'
+        shape = self.i2h_weight.shape
+        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **self.__dict__)
+
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
         prefix = 't%d_'%self._counter
@@ -453,6 +453,14 @@ class LSTMCell(HybridRecurrentCell):
     def _alias(self):
         return 'lstm'
 
+    def __repr__(self):
+        s = '{name}({mapping})'
+        shape = self.i2h_weight.shape
+        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **self.__dict__)
+
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
         prefix = 't%d_'%self._counter
@@ -551,6 +559,14 @@ class GRUCell(HybridRecurrentCell):
     def _alias(self):
         return 'gru'
 
+    def __repr__(self):
+        s = '{name}({mapping})'
+        shape = self.i2h_weight.shape
+        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **self.__dict__)
+
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
         # pylint: disable=too-many-locals
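Note that these reprs print i2h_weight.shape[0] as-is, and that dimension is the
stacked gate dimension: 4x the hidden size for an LSTMCell, 3x for a GRUCell. A
hedged example (repr text approximate):

    from mxnet import gluon

    cell = gluon.rnn.LSTMCell(10)
    print(cell.i2h_weight.shape)  # (40, 0): four gates stacked, input size unknown
    print(cell)                   # roughly: LSTMCell(None -> 40)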
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 2d7c008..3a4f712 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -89,8 +89,8 @@ class _RNNLayer(Block):
         if self._dir == 2:
             s += ', bidirectional'
         s += ')'
-        mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size
-                   else self._hidden_size)
+        shape = self.i2h_weight[0].shape
+        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
         return s.format(name=self.__class__.__name__,
                         mapping=mapping,
                         **self.__dict__)
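For the fused layers, i2h_weight is a list with one parameter per layer and
direction, so the repr reads the input size from layer 0's weight. A sketch of what
gets filled in after a forward pass (shapes follow the stacked-gate layout noted
above):

    import mxnet as mx
    from mxnet import gluon

    layer = gluon.rnn.LSTM(10, num_layers=2)
    layer.initialize()
    layer(mx.nd.ones((3, 2, 7)))      # TNC layout: (seq_len, batch, input_size)
    print(layer.i2h_weight[0].shape)  # (40, 7): layer 0 sees the raw input
    print(layer.i2h_weight[1].shape)  # (40, 10): layer 1 sees layer 0's hidden state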
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 6f9966b..df0af34 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -553,8 +553,44 @@ def test_lambda():
     assert_almost_equal(out1.asnumpy(), out3.asnumpy())
 
 
+def test_fill_shape_deferred():
+    net = nn.HybridSequential()
+    with net.name_scope():
+        net.add(nn.Conv2D(64, kernel_size=2, padding=1),
+                nn.BatchNorm(),
+                nn.Dense(10))
+    net.hybridize()
+    net.initialize()
+    net(mx.nd.ones((2,3,5,7)))
+    assert net[0].weight.shape[1] == 3, net[0].weight.shape[1]
+    assert net[1].gamma.shape[0] == 64, net[1].gamma.shape[0]
+    assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1]
 
 
+def test_fill_shape_load():
+    ctx = mx.context.current_context()
+    net1 = nn.HybridSequential()
+    with net1.name_scope():
+        net1.add(nn.Conv2D(64, kernel_size=2, padding=1),
+                 nn.BatchNorm(),
+                 nn.Dense(10))
+    net1.hybridize()
+    net1.initialize(ctx=ctx)
+    net1(mx.nd.ones((2,3,5,7), ctx))
+    net1.save_params('net_fill.params')
+
+    net2 = nn.HybridSequential()
+    with net2.name_scope():
+        net2.add(nn.Conv2D(64, kernel_size=2, padding=1),
+                 nn.BatchNorm(),
+                 nn.Dense(10))
+    net2.hybridize()
+    net2.initialize()
+    net2.load_params('net_fill.params', ctx)
+    assert net2[0].weight.shape[1] == 3, net2[0].weight.shape[1]
+    assert net2[1].gamma.shape[0] == 64, net2[1].gamma.shape[0]
+    assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1]
+
 
 if __name__ == '__main__':
     import nose
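The second test exercises the _load_init path: parameters saved from a fully shaped
network can be loaded into a freshly built one whose dimensions are still unknown,
and the zero dimensions are filled from the saved arrays. A distilled example (the
file name is illustrative):

    import mxnet as mx
    from mxnet.gluon import nn

    src = nn.Dense(10)
    src.initialize()
    src(mx.nd.ones((2, 7)))             # weight becomes (10, 7)
    src.save_params('dense.params')

    dst = nn.Dense(10)                  # in_units unknown: weight starts as (10, 0)
    dst.initialize()
    dst.load_params('dense.params', mx.cpu())
    assert dst.weight.shape == (10, 7)  # filled from the saved array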
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index c99836c..07b8956 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -94,6 +94,13 @@ def test_convgru():
     check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))
 
 
+def test_conv_fill_shape():
+    cell = contrib.rnn.Conv1DLSTMCell((0, 7), 10, (3,), (3,))
+    cell.hybridize()
+    check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7)))
+    assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]
+
+
 def test_vardrop():
     def check_vardrop(drop_inputs, drop_states, drop_outputs):
         cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'),
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index f71ac18..2288842 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -274,6 +274,19 @@ def test_rnn_layers():
     with mx.autograd.record():
         net(mx.nd.ones((2, 3, 10))).backward()
 
+def test_cell_fill_shape():
+    cell = gluon.rnn.LSTMCell(10)
+    cell.hybridize()
+    check_rnn_forward(cell, mx.nd.ones((2, 3, 7)))
+    assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]
+
+def test_layer_fill_shape():
+    layer = gluon.rnn.LSTM(10)
+    layer.hybridize()
+    check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)))
+    print(layer)
+    assert layer.i2h_weight[0].shape[1] == 7, layer.i2h_weight[0].shape[1]
+
 
 if __name__ == '__main__':
     import nose

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.