Posted to commits@mxnet.apache.org by lx...@apache.org on 2017/07/07 15:58:20 UTC
[07/50] [abbrv] incubator-mxnet-test git commit: Module forward reshape resolving conflicts (#6805)
Module forward reshape resolving conflicts (#6805)
* Add module reshape
* Module forward reshape
* Small fix
* Pass dtype
* Resolve conflict
* More fix
* Fix lint
Project: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/commit/ff968225
Tree: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/tree/ff968225
Diff: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/diff/ff968225
Branch: refs/heads/master
Commit: ff968225167a9f10e255600044028280f5696469
Parents: 459efce
Author: Yao Wang <ke...@gmail.com>
Authored: Tue Jun 27 10:15:12 2017 -0700
Committer: Eric Junyuan Xie <pi...@users.noreply.github.com>
Committed: Tue Jun 27 10:15:12 2017 -0700
----------------------------------------------------------------------
python/mxnet/module/base_module.py | 25 ++++--
python/mxnet/module/module.py | 33 +++++++-
tests/python/unittest/test_module.py | 131 +++++++++++++++++++++++++++---
3 files changed, 172 insertions(+), 17 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/ff968225/python/mxnet/module/base_module.py
----------------------------------------------------------------------
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index a5c4c70..cb6cfcc 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -751,7 +751,11 @@ class BaseModule(object):
pass
def forward(self, data_batch, is_train=None):
- """Forward computation.
+ """Forward computation. It supports data batches with different shapes, such as
+ different batch sizes or different image sizes.
+ If reshaping the data batch requires modifying the symbol or the module, for
+ example changing the image layout ordering or switching from training to
+ prediction, the module must be rebound first.
Parameters
----------
@@ -762,16 +766,25 @@ class BaseModule(object):
Examples
--------
- >>> # An example of forward computation.
+ >>> import mxnet as mx
>>> from collections import namedtuple
>>> Batch = namedtuple('Batch', ['data'])
- >>> mod.bind(data_shapes=[('data', (1, 10, 10))])
+ >>> data = mx.sym.Variable('data')
+ >>> out = data * 2
+ >>> mod = mx.mod.Module(symbol=out, label_names=None)
+ >>> mod.bind(data_shapes=[('data', (1, 10))])
>>> mod.init_params()
- >>> data1 = [mx.nd.ones([1, 10, 10])]
+ >>> data1 = [mx.nd.ones((1, 10))]
>>> mod.forward(Batch(data1))
>>> print mod.get_outputs()[0].asnumpy()
- [[ 0.09999977 0.10000153 0.10000716 0.10000195 0.09999853 0.09999743
- 0.10000272 0.10000113 0.09999088 0.09999888]]
+ [[ 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
+ >>> # Forward with data batch of different shape
+ >>> data2 = [mx.nd.ones((3, 5))]
+ >>> mod.forward(Batch(data2))
+ >>> print mod.get_outputs()[0].asnumpy()
+ [[ 2. 2. 2. 2. 2.]
+ [ 2. 2. 2. 2. 2.]
+ [ 2. 2. 2. 2. 2.]]
"""
raise NotImplementedError()
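
The docstring added above says that a reshape which also changes the symbol or the module, for example switching from training to prediction, needs a rebind rather than a plain forward(). A minimal sketch of that rebind pattern, reusing the toy `data * 2` module from the docstring example; the shapes and the second bind() call are illustrative and not part of this commit:

import mxnet as mx
from collections import namedtuple

Batch = namedtuple('Batch', ['data'])
mod = mx.mod.Module(symbol=mx.sym.Variable('data') * 2, label_names=None)

# First bind for training, then rebind for inference. force_rebind=True tells
# the module to drop the existing executors instead of refusing a second bind().
mod.bind(data_shapes=[('data', (1, 10))], for_training=True)
mod.init_params()
mod.bind(data_shapes=[('data', (4, 10))], for_training=False, force_rebind=True)

mod.forward(Batch([mx.nd.ones((4, 10))]), is_train=False)
print(mod.get_outputs()[0].shape)   # (4, 10)
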
http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/ff968225/python/mxnet/module/module.py
----------------------------------------------------------------------
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index f5f3c2a..1b5ecbc 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -15,6 +15,7 @@ from .executor_group import DataParallelExecutorGroup
from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
from ..model import load_checkpoint
from ..initializer import Uniform, InitDesc
+from ..io import DataDesc
from .base_module import BaseModule, _check_input_names, _parse_data_desc
@@ -535,7 +536,11 @@ class Module(BaseModule):
self.optimizer_initialized = True
def forward(self, data_batch, is_train=None):
- """Forward computation.
+ """Forward computation. It supports data batches with different shapes, such as
+ different batch sizes or different image sizes.
+ If reshaping the data batch requires modifying the symbol or the module, for
+ example changing the image layout ordering or switching from training to
+ prediction, the module must be rebound first.
See Also
----------
@@ -549,6 +554,32 @@ class Module(BaseModule):
Default is ``None``, which means ``is_train`` takes the value of ``self.for_training``.
"""
assert self.binded and self.params_initialized
+
+ # The module was bound with labels but this batch carries none: switching to inference requires an explicit rebind.
+ if self._label_shapes and not data_batch.label:
+ raise RuntimeError("If you are trying to do inference, rebind module "
+ "with 'force_rebind=True' and 'for_training=False'")
+
+ curr_data_shapes = tuple(i.shape for i in self._data_shapes)
+ new_data_shapes = tuple(i.shape for i in data_batch.data)
+
+ if curr_data_shapes != new_data_shapes:
+ if hasattr(data_batch, "provide_data") and data_batch.provide_data:
+ new_dshape = data_batch.provide_data
+ else:
+ new_dshape = [DataDesc(i.name, shape, i.dtype, i.layout) \
+ for i, shape in zip(self._data_shapes, new_data_shapes)]
+
+ if hasattr(data_batch, "provide_label") and data_batch.provide_label:
+ new_lshape = data_batch.provide_label
+ elif hasattr(data_batch, "label") and data_batch.label:
+ new_lshape = [DataDesc(i.name, j.shape, i.dtype, i.layout) \
+ for i, j in zip(self._label_shapes, data_batch.label)]
+ else:
+ new_lshape = None
+
+ self.reshape(new_dshape, new_lshape)
+
self._exec_group.forward(data_batch, is_train)
def backward(self, out_grads=None):
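
The reshape path added to forward() above prefers the batch's provide_data/provide_label descriptors when they are present and otherwise rebuilds DataDescs from the arrays, so a batch with explicit descriptors flows through without a manual reshape() call. A short usage sketch under that assumption; it presumes mx.io.DataBatch accepts provide_data/provide_label keyword arguments, as the hasattr checks above suggest, and the network and shapes are illustrative rather than taken from this commit:

import mxnet as mx

# A tiny classifier bound for batch size 10.
net = mx.sym.FullyConnected(data=mx.sym.Variable('data'), num_hidden=3)
net = mx.sym.SoftmaxOutput(data=net, name='softmax')
mod = mx.mod.Module(symbol=net, data_names=['data'], label_names=['softmax_label'])
mod.bind(data_shapes=[('data', (10, 20))], label_shapes=[('softmax_label', (10,))])
mod.init_params()

# A batch of size 4: forward() detects the shape change and reshapes the module
# from the DataDescs carried by the batch before running the computation.
batch = mx.io.DataBatch(data=[mx.nd.ones((4, 20))],
                        label=[mx.nd.zeros((4,))],
                        provide_data=[mx.io.DataDesc('data', (4, 20))],
                        provide_label=[mx.io.DataDesc('softmax_label', (4,))])
mod.forward(batch)
print(mod.get_outputs()[0].shape)   # (4, 3)
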
http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/ff968225/tests/python/unittest/test_module.py
----------------------------------------------------------------------
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 8990aaf..766995d 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -4,6 +4,7 @@ import numpy as np
from functools import reduce
from mxnet.module.executor_group import DataParallelExecutorGroup
from common import assertRaises
+from collections import namedtuple
def test_module_dtype():
@@ -440,14 +441,124 @@ def test_executor_group():
shared_arg_names=shared_arg_names, extra_args=extra_args)
+def test_forward_reshape():
+ num_class = 10
+ data1 = mx.sym.Variable('data1')
+ data2 = mx.sym.Variable('data2')
+ conv1 = mx.sym.Convolution(data=data1, kernel=(2, 2), num_filter=2, stride=(2, 2))
+ conv2 = mx.sym.Convolution(data=data2, kernel=(3, 3), num_filter=3, stride=(1, 1))
+ pooling1 = mx.sym.Pooling(data=conv1, kernel=(2, 2), stride=(1, 1), pool_type="avg")
+ pooling2 = mx.sym.Pooling(data=conv2, kernel=(2, 2), stride=(1, 1), pool_type="max")
+ flatten1 = mx.sym.flatten(data=pooling1)
+ flatten2 = mx.sym.flatten(data=pooling2)
+ sum = mx.sym.sum(data=flatten1, axis=1) + mx.sym.sum(data=flatten2, axis=1)
+ fc = mx.sym.FullyConnected(data=sum, num_hidden=num_class)
+ sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
+
+ dshape1 = (10, 3, 64, 64)
+ dshape2 = (10, 3, 32, 32)
+ lshape = (10,)
+
+ mod = mx.mod.Module(symbol=sym, data_names=['data1', 'data2'],
+ label_names=['softmax_label'])
+ mod.bind(data_shapes=[('data1', dshape1), ('data2', dshape2)],
+ label_shapes=[('softmax_label', lshape)])
+ mod.init_params()
+ mod.init_optimizer(optimizer_params={'learning_rate': 0.01})
+
+ # Train with original data shapes
+ data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(0, 9, dshape1),
+ mx.nd.random_uniform(5, 15, dshape2)],
+ label=[mx.nd.ones(lshape)])
+ mod.forward(data_batch)
+ assert mod.get_outputs()[0].shape == tuple([lshape[0], num_class])
+ mod.backward()
+ mod.update()
+
+ # Train with different batch size
+ dshape1 = (3, 3, 64, 64)
+ dshape2 = (3, 3, 32, 32)
+ lshape = (3,)
+ data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(0, 9, dshape1),
+ mx.nd.random_uniform(5, 15, dshape2)],
+ label=[mx.nd.ones(lshape)])
+ mod.forward(data_batch)
+ assert mod.get_outputs()[0].shape == tuple([lshape[0], num_class])
+ mod.backward()
+ mod.update()
+
+ dshape1 = (20, 3, 64, 64)
+ dshape2 = (20, 3, 32, 32)
+ lshape = (20,)
+ data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(3, 5, dshape1),
+ mx.nd.random_uniform(10, 25, dshape2)],
+ label=[mx.nd.ones(lshape)])
+ mod.forward(data_batch)
+ assert mod.get_outputs()[0].shape == tuple([lshape[0], num_class])
+ mod.backward()
+ mod.update()
+
+ # Train with both different batch size and data shapes
+ dshape1 = (20, 3, 120, 120)
+ dshape2 = (20, 3, 32, 64)
+ lshape = (20,)
+ data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(0, 9, dshape1),
+ mx.nd.random_uniform(5, 15, dshape2)],
+ label=[mx.nd.ones(lshape)])
+ mod.forward(data_batch)
+ assert mod.get_outputs()[0].shape == tuple([lshape[0], num_class])
+ mod.backward()
+ mod.update()
+
+ dshape1 = (5, 3, 28, 40)
+ dshape2 = (5, 3, 24, 16)
+ lshape = (5,)
+ data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(0, 9, dshape1),
+ mx.nd.random_uniform(15, 25, dshape2)],
+ label=[mx.nd.ones(lshape)])
+ mod.forward(data_batch)
+ assert mod.get_outputs()[0].shape == tuple([lshape[0], num_class])
+ mod.backward()
+ mod.update()
+
+ # Test score
+ dataset_shape1 = (30, 3, 30, 30)
+ dataset_shape2 = (30, 3, 20, 40)
+ labelset_shape = (30,)
+
+ eval_dataiter = mx.io.NDArrayIter(data=[mx.nd.random_uniform(0, 9, dataset_shape1),
+ mx.nd.random_uniform(15, 25, dataset_shape2)],
+ label=[mx.nd.ones(labelset_shape)],
+ batch_size=5)
+ assert len(mod.score(eval_data=eval_dataiter, eval_metric='acc')) == 1
+
+ # Test prediction
+ dshape1 = (1, 3, 30, 30)
+ dshape2 = (1, 3, 20, 40)
+ dataset_shape1 = (10, 3, 30, 30)
+ dataset_shape2 = (10, 3, 20, 40)
+
+ pred_dataiter = mx.io.NDArrayIter(data=[mx.nd.random_uniform(0, 9, dataset_shape1),
+ mx.nd.random_uniform(15, 25, dataset_shape2)])
+ mod.bind(data_shapes=[('data1', dshape1), ('data2', dshape2)],
+ for_training=False, force_rebind=True)
+ assert mod.predict(pred_dataiter).shape == tuple([10, num_class])
+
+ # Test forward with other data batch API
+ Batch = namedtuple('Batch', ['data'])
+ data = mx.sym.Variable('data')
+ out = data * 2
+ mod = mx.mod.Module(symbol=out, label_names=None)
+ mod.bind(data_shapes=[('data', (1, 10))])
+ mod.init_params()
+ data1 = [mx.nd.ones((1, 10))]
+ mod.forward(Batch(data1))
+ assert mod.get_outputs()[0].shape == (1, 10)
+ data2 = [mx.nd.ones((3, 5))]
+ mod.forward(Batch(data2))
+ assert mod.get_outputs()[0].shape == (3, 5)
+
+
if __name__ == '__main__':
- test_module_dtype()
- test_module_input_grads()
- test_module_states()
- test_module_reshape()
- test_module_set_params()
- test_save_load()
- test_module_layout()
- test_module_switch_bucket()
- test_monitor()
- test_executor_group()
+ import nose
+ nose.runmodule()