You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by lx...@apache.org on 2017/07/07 15:58:20 UTC

[07/50] [abbrv] incubator-mxnet-test git commit: Module forward reshape resolving conflicts (#6805)

Module forward reshape resolving conflicts (#6805)

* Add module reshape

* Module forward reshape

* Small fix

* Pass dtype

* Resolve conflict

* More fix

* Fix lint


Project: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/commit/ff968225
Tree: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/tree/ff968225
Diff: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/diff/ff968225

Branch: refs/heads/master
Commit: ff968225167a9f10e255600044028280f5696469
Parents: 459efce
Author: Yao Wang <ke...@gmail.com>
Authored: Tue Jun 27 10:15:12 2017 -0700
Committer: Eric Junyuan Xie <pi...@users.noreply.github.com>
Committed: Tue Jun 27 10:15:12 2017 -0700

----------------------------------------------------------------------
 python/mxnet/module/base_module.py   |  25 ++++--
 python/mxnet/module/module.py        |  33 +++++++-
 tests/python/unittest/test_module.py | 131 +++++++++++++++++++++++++++---
 3 files changed, 172 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/ff968225/python/mxnet/module/base_module.py
----------------------------------------------------------------------
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index a5c4c70..cb6cfcc 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -751,7 +751,11 @@ class BaseModule(object):
         pass

     def forward(self, data_batch, is_train=None):
-        """Forward computation.
+        """Forward computation. It supports data batches with different shapes, such as
+        different batch sizes or different image sizes.
+        If the new shapes imply a change to the symbol or module configuration (for example a
+        different image layout, or switching from training to inference), the module must be
+        rebound instead.

         Parameters
         ----------
@@ -762,16 +766,25 @@ class BaseModule(object):

         Examples
         --------
-        >>> # An example of forward computation.
+        >>> import mxnet as mx
         >>> from collections import namedtuple
         >>> Batch = namedtuple('Batch', ['data'])
-        >>> mod.bind(data_shapes=[('data', (1, 10, 10))])
+        >>> data = mx.sym.Variable('data')
+        >>> out = data * 2
+        >>> mod = mx.mod.Module(symbol=out, label_names=None)
+        >>> mod.bind(data_shapes=[('data', (1, 10))])
         >>> mod.init_params()
-        >>> data1 = [mx.nd.ones([1, 10, 10])]
+        >>> data1 = [mx.nd.ones((1, 10))]
         >>> mod.forward(Batch(data1))
         >>> print mod.get_outputs()[0].asnumpy()
-        [[ 0.09999977  0.10000153  0.10000716  0.10000195  0.09999853  0.09999743
-           0.10000272  0.10000113  0.09999088  0.09999888]]
+        [[ 2.  2.  2.  2.  2.  2.  2.  2.  2.  2.]]
+        >>> # Forward with a data batch of a different shape
+        >>> data2 = [mx.nd.ones((3, 5))]
+        >>> mod.forward(Batch(data2))
+        >>> print mod.get_outputs()[0].asnumpy()
+        [[ 2.  2.  2.  2.  2.]
+         [ 2.  2.  2.  2.  2.]
+         [ 2.  2.  2.  2.  2.]]
         """
         raise NotImplementedError()


http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/ff968225/python/mxnet/module/module.py
----------------------------------------------------------------------
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index f5f3c2a..1b5ecbc 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -15,6 +15,7 @@ from .executor_group import DataParallelExecutorGroup
 from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
 from ..model import load_checkpoint
 from ..initializer import Uniform, InitDesc
+from ..io import DataDesc

 from .base_module import BaseModule, _check_input_names, _parse_data_desc

@@ -535,7 +536,11 @@ class Module(BaseModule):
         self.optimizer_initialized = True

     def forward(self, data_batch, is_train=None):
-        """Forward computation.
+        """Forward computation. It supports data batches with different shapes, such as
+        different batch sizes or different image sizes.
+        If the new shapes imply a change to the symbol or module configuration (for example a
+        different image layout, or switching from training to inference), the module must be
+        rebound instead.

         See Also
         ----------
@@ -549,6 +554,32 @@ class Module(BaseModule):
             Default is ``None``, which means ``is_train`` takes the value of ``self.for_training``.
         """
         assert self.binded and self.params_initialized
+
+        # A module bound with labels cannot run label-less batches; it must be rebound first.
+        if self._label_shapes and not getattr(data_batch, "label", None):
+            raise RuntimeError("If you are trying to do inference, rebind module "
+                               "with 'force_rebind=True' and 'for_training=False'")
+
+        curr_data_shapes = tuple(i.shape for i in self._data_shapes)
+        new_data_shapes = tuple(i.shape for i in data_batch.data)
+        # Tuples compare elementwise (generators would compare by identity and always differ).
+        if curr_data_shapes != new_data_shapes:
+            if hasattr(data_batch, "provide_data") and data_batch.provide_data:
+                new_dshape = data_batch.provide_data
+            else:
+                new_dshape = [DataDesc(i.name, shape, i.dtype, i.layout)
+                              for i, shape in zip(self._data_shapes, new_data_shapes)]
+
+            if hasattr(data_batch, "provide_label") and data_batch.provide_label:
+                new_lshape = data_batch.provide_label
+            elif self._label_shapes and getattr(data_batch, "label", None):
+                new_lshape = [DataDesc(i.name, j.shape, i.dtype, i.layout)
+                              for i, j in zip(self._label_shapes, data_batch.label)]
+            else:
+                new_lshape = None
+
+            self.reshape(new_dshape, new_lshape)
+
         self._exec_group.forward(data_batch, is_train)

     def backward(self, out_grads=None):

http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/ff968225/tests/python/unittest/test_module.py
----------------------------------------------------------------------
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 8990aaf..766995d 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -4,6 +4,7 @@ import numpy as np
 from functools import reduce
 from mxnet.module.executor_group import DataParallelExecutorGroup
 from common import assertRaises
+from collections import namedtuple


 def test_module_dtype():
@@ -440,14 +441,124 @@ def test_executor_group():
                            shared_arg_names=shared_arg_names, extra_args=extra_args)


+def test_forward_reshape():
+    num_class = 10
+    data1 = mx.sym.Variable('data1')
+    data2 = mx.sym.Variable('data2')
+    conv1 = mx.sym.Convolution(data=data1, kernel=(2, 2), num_filter=2, stride=(2, 2))
+    conv2 = mx.sym.Convolution(data=data2, kernel=(3, 3), num_filter=3, stride=(1, 1))
+    pooling1 = mx.sym.Pooling(data=conv1, kernel=(2, 2), stride=(1, 1), pool_type="avg")
+    pooling2 = mx.sym.Pooling(data=conv2, kernel=(2, 2), stride=(1, 1), pool_type="max")
+    flatten1 = mx.sym.flatten(data=pooling1)
+    flatten2 = mx.sym.flatten(data=pooling2)
+    feat_sum = mx.sym.sum(data=flatten1, axis=1) + mx.sym.sum(data=flatten2, axis=1)
+    fc = mx.sym.FullyConnected(data=feat_sum, num_hidden=num_class)
+    sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
+
+    dshape1 = (10, 3, 64, 64)
+    dshape2 = (10, 3, 32, 32)
+    lshape = (10,)
+
+    mod = mx.mod.Module(symbol=sym, data_names=['data1', 'data2'],
+                        label_names=['softmax_label'])
+    mod.bind(data_shapes=[('data1', dshape1), ('data2', dshape2)],
+             label_shapes=[('softmax_label', lshape)])
+    mod.init_params()
+    mod.init_optimizer(optimizer_params={'learning_rate': 0.01})
+
+    # Train with original data shapes
+    data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(0, 9, dshape1),
+                                       mx.nd.random_uniform(5, 15, dshape2)],
+                                 label=[mx.nd.ones(lshape)])
+    mod.forward(data_batch)
+    assert mod.get_outputs()[0].shape == (lshape[0], num_class)
+    mod.backward()
+    mod.update()
+
+    # Train with different batch size
+    dshape1 = (3, 3, 64, 64)
+    dshape2 = (3, 3, 32, 32)
+    lshape = (3,)
+    data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(0, 9, dshape1),
+                                       mx.nd.random_uniform(5, 15, dshape2)],
+                                 label=[mx.nd.ones(lshape)])
+    mod.forward(data_batch)
+    assert mod.get_outputs()[0].shape == (lshape[0], num_class)
+    mod.backward()
+    mod.update()
+
+    dshape1 = (20, 3, 64, 64)
+    dshape2 = (20, 3, 32, 32)
+    lshape = (20,)
+    data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(3, 5, dshape1),
+                                       mx.nd.random_uniform(10, 25, dshape2)],
+                                 label=[mx.nd.ones(lshape)])
+    mod.forward(data_batch)
+    assert mod.get_outputs()[0].shape == (lshape[0], num_class)
+    mod.backward()
+    mod.update()
+
+    # Train with both different batch size and data shapes
+    dshape1 = (20, 3, 120, 120)
+    dshape2 = (20, 3, 32, 64)
+    lshape = (20,)
+    data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(0, 9, dshape1),
+                                       mx.nd.random_uniform(5, 15, dshape2)],
+                                 label=[mx.nd.ones(lshape)])
+    mod.forward(data_batch)
+    assert mod.get_outputs()[0].shape == (lshape[0], num_class)
+    mod.backward()
+    mod.update()
+
+    dshape1 = (5, 3, 28, 40)
+    dshape2 = (5, 3, 24, 16)
+    lshape = (5,)
+    data_batch = mx.io.DataBatch(data=[mx.nd.random_uniform(0, 9, dshape1),
+                                       mx.nd.random_uniform(15, 25, dshape2)],
+                                 label=[mx.nd.ones(lshape)])
+    mod.forward(data_batch)
+    assert mod.get_outputs()[0].shape == (lshape[0], num_class)
+    mod.backward()
+    mod.update()
+
+    # Test score
+    dataset_shape1 = (30, 3, 30, 30)
+    dataset_shape2 = (30, 3, 20, 40)
+    labelset_shape = (30,)
+
+    eval_dataiter = mx.io.NDArrayIter(data=[mx.nd.random_uniform(0, 9, dataset_shape1),
+                                            mx.nd.random_uniform(15, 25, dataset_shape2)],
+                                      label=[mx.nd.ones(labelset_shape)],
+                                      batch_size=5)
+    assert len(mod.score(eval_data=eval_dataiter, eval_metric='acc')) == 1
+
+    # Test prediction
+    dshape1 = (1, 3, 30, 30)
+    dshape2 = (1, 3, 20, 40)
+    dataset_shape1 = (10, 3, 30, 30)
+    dataset_shape2 = (10, 3, 20, 40)
+
+    pred_dataiter = mx.io.NDArrayIter(data=[mx.nd.random_uniform(0, 9, dataset_shape1),
+                                            mx.nd.random_uniform(15, 25, dataset_shape2)])
+    mod.bind(data_shapes=[('data1', dshape1), ('data2', dshape2)],
+             for_training=False, force_rebind=True)
+    assert mod.predict(pred_dataiter).shape == (10, num_class)
+
+    # Test forward with other data batch API
+    Batch = namedtuple('Batch', ['data'])
+    data = mx.sym.Variable('data')
+    out = data * 2
+    mod = mx.mod.Module(symbol=out, label_names=None)
+    mod.bind(data_shapes=[('data', (1, 10))])
+    mod.init_params()
+    data1 = [mx.nd.ones((1, 10))]
+    mod.forward(Batch(data1))
+    assert mod.get_outputs()[0].shape == (1, 10)
+    data2 = [mx.nd.ones((3, 5))]
+    mod.forward(Batch(data2))
+    assert mod.get_outputs()[0].shape == (3, 5)
+
+
 if __name__ == '__main__':
-    test_module_dtype()
-    test_module_input_grads()
-    test_module_states()
-    test_module_reshape()
-    test_module_set_params()
-    test_save_load()
-    test_module_layout()
-    test_module_switch_bucket()
-    test_monitor()
-    test_executor_group()
+    import nose
+    nose.runmodule()