You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/03 17:34:59 UTC
[incubator-mxnet] branch master updated: fill parameter shape (#8528)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new c02c6b1 fill parameter shape (#8528)
c02c6b1 is described below
commit c02c6b149b40bc9a8db91c95453ff0e96f3edc3c
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Fri Nov 3 10:34:57 2017 -0700
fill parameter shape (#8528)
---
python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py | 5 ++--
python/mxnet/gluon/nn/basic_layers.py | 8 +++---
python/mxnet/gluon/nn/conv_layers.py | 5 ++--
python/mxnet/gluon/parameter.py | 7 +++--
python/mxnet/gluon/rnn/rnn_cell.py | 38 ++++++++++++++++++-------
python/mxnet/gluon/rnn/rnn_layer.py | 4 +--
tests/python/unittest/test_gluon.py | 36 +++++++++++++++++++++++
tests/python/unittest/test_gluon_contrib.py | 7 +++++
tests/python/unittest/test_gluon_rnn.py | 13 +++++++++
9 files changed, 99 insertions(+), 24 deletions(-)
diff --git a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
index cbb3f1a..09db547 100644
--- a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
@@ -131,8 +131,9 @@ class _BaseConvRNNCell(HybridRecurrentCell):
s += ', {_conv_layout}'
s += ')'
attrs = self.__dict__
- mapping = ('{_in_channels} -> {_hidden_channels}'.format(**attrs) if self._in_channels
- else self._hidden_channels)
+ shape = self.i2h_weight.shape
+ in_channels = shape[1 if self._channel_axis == 1 else -1]
+ mapping = ('{0} -> {1}'.format(in_channels if in_channels else None, shape[0]))
return s.format(name=self.__class__.__name__,
mapping=mapping,
**attrs)
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index e9fb2ff..906f03e 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -207,10 +207,10 @@ class Dense(HybridBlock):
def __repr__(self):
s = '{name}({layout}, {act})'
+ shape = self.weight.shape
return s.format(name=self.__class__.__name__,
act=self.act if self.act else 'linear',
- layout='{0} -> {1}'.format(self._in_units, self._units) if self._in_units
- else self._units)
+ layout='{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0]))
class Activation(HybridBlock):
@@ -360,8 +360,8 @@ class BatchNorm(HybridBlock):
def __repr__(self):
s = '{name}({content}'
- if hasattr(self, 'in_channels'):
- s += ', in_channels={0}'.format(self.in_channels)
+ in_channels = self.gamma.shape[0]
+ s += ', in_channels={0}'.format(in_channels if in_channels else None)
s += ')'
return s.format(name=self.__class__.__name__,
content=', '.join(['='.join([k, v.__repr__()])
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 8dcdbc3..645de98 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -153,10 +153,9 @@ class _Conv(HybridBlock):
if self.bias is None:
s += ', bias=False'
s += ')'
+ shape = self.weight.shape
return s.format(name=self.__class__.__name__,
- mapping=self._channels if not self._in_channels
- else '{0} -> {1}'.format(self._in_channels,
- self._channels),
+ mapping='{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0]),
**self._kwargs)
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index c73aee2..c42fbaa 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -171,11 +171,12 @@ class Parameter(object):
def _load_init(self, data, ctx):
"""(Re)initializes by loading from data."""
if self.shape:
- for i, j in zip(self.shape, data.shape):
- assert i == 0 or i == j, \
+ for self_dim, data_dim in zip(self.shape, data.shape):
+ assert self_dim == 0 or self_dim == data_dim, \
"Failed loading Parameter %s from saved params: " \
"shape incompatible expacted %s vs saved %s"%(
self.name, str(self.shape), str(data.shape))
+ self.shape = tuple(i if i != 0 else j for i, j in zip(self.shape, data.shape))
if self.dtype:
assert np.dtype(self.dtype).type == data.dtype, \
"Failed loading Parameter %s from saved params: " \
@@ -344,6 +345,8 @@ class Parameter(object):
"Parameter %s has not been initialized"%self.name
for arr in self.list_data():
arr[:] = data
+ if not self.shape or np.prod(self.shape) <= 0:
+ self.shape = data.shape
def data(self, ctx=None):
"""Returns a copy of this parameter on one context. Must have been
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 9d318eb..ea0e32f 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -111,17 +111,6 @@ class RecurrentCell(Block):
self._modified = False
self.reset()
- def __repr__(self):
- s = '{name}({mapping}'
- if hasattr(self, '_activation'):
- s += ', {_activation}'
- s += ')'
- mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size
- else self._hidden_size)
- return s.format(name=self.__class__.__name__,
- mapping=mapping,
- **self.__dict__)
-
def reset(self):
"""Reset before re-using the cell for another graph."""
self._init_counter = -1
@@ -355,6 +344,17 @@ class RNNCell(HybridRecurrentCell):
def _alias(self):
return 'rnn'
+ def __repr__(self):
+ s = '{name}({mapping}'
+ if hasattr(self, '_activation'):
+ s += ', {_activation}'
+ s += ')'
+ shape = self.i2h_weight.shape
+ mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
+ return s.format(name=self.__class__.__name__,
+ mapping=mapping,
+ **self.__dict__)
+
def hybrid_forward(self, F, inputs, states, i2h_weight,
h2h_weight, i2h_bias, h2h_bias):
prefix = 't%d_'%self._counter
@@ -453,6 +453,14 @@ class LSTMCell(HybridRecurrentCell):
def _alias(self):
return 'lstm'
+ def __repr__(self):
+ s = '{name}({mapping})'
+ shape = self.i2h_weight.shape
+ mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
+ return s.format(name=self.__class__.__name__,
+ mapping=mapping,
+ **self.__dict__)
+
def hybrid_forward(self, F, inputs, states, i2h_weight,
h2h_weight, i2h_bias, h2h_bias):
prefix = 't%d_'%self._counter
@@ -551,6 +559,14 @@ class GRUCell(HybridRecurrentCell):
def _alias(self):
return 'gru'
+ def __repr__(self):
+ s = '{name}({mapping})'
+ shape = self.i2h_weight.shape
+ mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
+ return s.format(name=self.__class__.__name__,
+ mapping=mapping,
+ **self.__dict__)
+
def hybrid_forward(self, F, inputs, states, i2h_weight,
h2h_weight, i2h_bias, h2h_bias):
# pylint: disable=too-many-locals
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 2d7c008..3a4f712 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -89,8 +89,8 @@ class _RNNLayer(Block):
if self._dir == 2:
s += ', bidirectional'
s += ')'
- mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size
- else self._hidden_size)
+ shape = self.i2h_weight[0].shape
+ mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 6f9966b..df0af34 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -553,8 +553,44 @@ def test_lambda():
assert_almost_equal(out1.asnumpy(), out3.asnumpy())
+def test_fill_shape_deferred():
+ net = nn.HybridSequential()
+ with net.name_scope():
+ net.add(nn.Conv2D(64, kernel_size=2, padding=1),
+ nn.BatchNorm(),
+ nn.Dense(10))
+ net.hybridize()
+ net.initialize()
+ net(mx.nd.ones((2,3,5,7)))
+ assert net[0].weight.shape[1] == 3, net[0].weight.shape[1]
+ assert net[1].gamma.shape[0] == 64, net[1].gamma.shape[0]
+ assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1]
+def test_fill_shape_load():
+ ctx = mx.context.current_context()
+ net1 = nn.HybridSequential()
+ with net1.name_scope():
+ net1.add(nn.Conv2D(64, kernel_size=2, padding=1),
+ nn.BatchNorm(),
+ nn.Dense(10))
+ net1.hybridize()
+ net1.initialize(ctx=ctx)
+ net1(mx.nd.ones((2,3,5,7), ctx))
+ net1.save_params('net_fill.params')
+
+ net2 = nn.HybridSequential()
+ with net2.name_scope():
+ net2.add(nn.Conv2D(64, kernel_size=2, padding=1),
+ nn.BatchNorm(),
+ nn.Dense(10))
+ net2.hybridize()
+ net2.initialize()
+ net2.load_params('net_fill.params', ctx)
+ assert net2[0].weight.shape[1] == 3, net2[0].weight.shape[1]
+ assert net2[1].gamma.shape[0] == 64, net2[1].gamma.shape[0]
+ assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1]
+
if __name__ == '__main__':
import nose
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index c99836c..07b8956 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -94,6 +94,13 @@ def test_convgru():
check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))
+def test_conv_fill_shape():
+ cell = contrib.rnn.Conv1DLSTMCell((0, 7), 10, (3,), (3,))
+ cell.hybridize()
+ check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7)))
+ assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]
+
+
def test_vardrop():
def check_vardrop(drop_inputs, drop_states, drop_outputs):
cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'),
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index f71ac18..2288842 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -274,6 +274,19 @@ def test_rnn_layers():
with mx.autograd.record():
net(mx.nd.ones((2, 3, 10))).backward()
+def test_cell_fill_shape():
+ cell = gluon.rnn.LSTMCell(10)
+ cell.hybridize()
+ check_rnn_forward(cell, mx.nd.ones((2, 3, 7)))
+ assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]
+
+def test_layer_fill_shape():
+ layer = gluon.rnn.LSTM(10)
+ layer.hybridize()
+ check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)))
+ print(layer)
+ assert layer.i2h_weight[0].shape[1] == 7, layer.i2h_weight[0].shape[1]
+
if __name__ == '__main__':
import nose
--
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.