Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/11/07 18:10:26 UTC

[incubator-mxnet] branch v1.3.x updated: Infer dtype in SymbolBlock import from input symbol (v1.3.x) (#13117)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.3.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.3.x by this push:
     new dc5c877  Infer dtype in SymbolBlock import from input symbol (v1.3.x) (#13117)
dc5c877 is described below

commit dc5c8771943b1711da9ea181acbcee19d4d67491
Author: Anton Chernov <me...@gmail.com>
AuthorDate: Wed Nov 7 19:10:08 2018 +0100

    Infer dtype in SymbolBlock import from input symbol (v1.3.x) (#13117)
    
    * Infer dtype in SymbolBlock import from input symbol
    
    * Fix lint issues and make existing tests pass
    
    * Add tests for importing a fp64 model into symbol block
    
    * Fixing failing test for test symbol block
    
    * Set context in unit tests
    
    * Add tests for fp16, add default dtype in infer_param_types
    
    * Use tmp directory as root for loading from model zoo to avoid race condition
    
    * Fixing naming and parameter selection in test case
    
    * Fixing failing GPU tests
    
    * Make unit test more deterministic to get param name
    
    * Override cast in symbol block, handle grouped symbol
    
    * Handle the multiple symbolic inputs use case
    
    * Add tests to verify behavior of SymbolBlock.cast
---
 python/mxnet/gluon/block.py         | 86 ++++++++++++++++++++++++++++++++++---
 python/mxnet/gluon/parameter.py     |  2 +
 tests/python/gpu/test_gluon_gpu.py  | 31 +++++++++++++
 tests/python/unittest/test_gluon.py | 38 ++++++++++++++++
 4 files changed, 151 insertions(+), 6 deletions(-)
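
In short, SymbolBlock now infers parameter dtypes from the input symbol(s)
instead of creating every parameter with the fp32 default. A minimal sketch of
the user-visible effect, using a hypothetical one-layer graph (the 'fc' names
are placeholders, not part of this commit):

    import mxnet as mx

    # Declare the input with an explicit non-default dtype.
    data = mx.sym.var('data', dtype='float64')
    out = mx.sym.FullyConnected(data, num_hidden=10, name='fc')

    net = mx.gluon.SymbolBlock(out, data)
    # fc_weight and fc_bias are now created as float64, matching the
    # input symbol, rather than the previous float32 default.
    for name, param in net.collect_params().items():
        print(name, param.dtype)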

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index d0830dc..6cb9fc6 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -26,6 +26,7 @@ import warnings
 import re
 from collections import OrderedDict
 
+from ..base import mx_real_t
 from .. import symbol, ndarray, initializer
 from ..symbol import Symbol
 from ..ndarray import NDArray
@@ -1053,13 +1054,20 @@ class SymbolBlock(HybridBlock):
                     "SymbolBlock doesn't support Parameter '%s' because its storage " \
                     "type is 'row_sparse'." % j.name
 
-        for i in out.list_arguments():
-            if i not in input_names:
-                self.params.get(i, allow_deferred_init=True)
+        # Infer the dtypes of parameters. Without this, every parameter would be
+        # created with the default dtype, i.e. fp32.
+        arg_params = out.list_arguments()
+        aux_params = out.list_auxiliary_states()
 
-        for i in out.list_auxiliary_states():
-            if i not in input_names:
-                self.params.get(i, grad_req='null', allow_deferred_init=True)
+        arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params)
+
+        for i, arg in enumerate(arg_params):
+            if arg not in input_names:
+                self.params.get(arg, allow_deferred_init=True, dtype=arg_types[i])
+
+        for i, aux in enumerate(aux_params):
+            if aux not in input_names:
+                self.params.get(aux, grad_req='null', allow_deferred_init=True, dtype=aux_types[i])
 
         self._cached_graph = syms, out
         len_prefix = len(_common_prefix(list(self._params.keys())))
@@ -1084,5 +1092,71 @@ class SymbolBlock(HybridBlock):
         super(SymbolBlock, self)._clear_cached_op()
         self._cached_graph = tmp
 
+    def cast(self, dtype):
+        self._clear_cached_op()
+        super(SymbolBlock, self).cast(dtype)
+
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
+
+def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dtype=mx_real_t):
+    """Utility function that helps in inferring DType of args and auxs params
+    from given input param.
+
+    Parameters
+    ----------
+    in_params: List of Symbol
+        List of input symbol variables.
+    out_params: Symbol
+        Output symbol variable.
+    arg_params: List of Str
+        List of names of argument parameters.
+    aux_params: List of Str
+        List of names of auxiliary parameters.
+    default_dtype: numpy.dtype or str, default 'float32'
+        Default data type for arg_params and aux_params, if unable to infer the type.
+
+    Returns
+    -------
+    arg_types: List of numpy.dtype
+        Dtypes of arg_params, in the same order as arg_params.
+        Defaults to 'float32' if the type cannot be inferred.
+    aux_types: List of numpy.dtype
+        Dtypes of aux_params, in the same order as aux_params.
+        Defaults to 'float32' if the type cannot be inferred.
+    """
+    arg_types = None
+    aux_types = None
+
+    # Get the input symbol names. These will be used to infer the types
+    # of the other parameters.
+    input_sym_names = [in_param.name for in_param in in_params]
+
+    # Try to infer the input types. If unsuccessful, fall back to the default
+    # dtype; if successful, use them to infer the types of the other params.
+    input_sym_arg_types = []
+    can_infer_input_type = True
+    for in_param in in_params:
+        input_sym_arg_type = in_param.infer_type()[0]
+        if not input_sym_arg_type or len(input_sym_arg_type) < 1:
+            can_infer_input_type = False
+            break
+        else:
+            input_sym_arg_types.append(input_sym_arg_type[0])
+
+    # Try to infer types of other parameters.
+    if can_infer_input_type:
+        params = {k:v for k, v in zip(input_sym_names, input_sym_arg_types)}
+        arg_types, _, aux_types = out_params.infer_type(**params)
+
+    if arg_types is None or len(arg_types) != len(arg_params):
+        arg_types = []
+        for _ in arg_params:
+            arg_types.append(default_dtype)
+
+    if aux_types is None or len(aux_types) != len(aux_params):
+        aux_types = []
+        for _ in aux_params:
+            aux_types.append(default_dtype)
+
+    return (arg_types, aux_types)
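
The helper above leans on Symbol.infer_type, which propagates known dtypes
through the graph and returns (arg_types, out_types, aux_types) in declaration
order. A minimal sketch, again on a hypothetical one-layer graph:

    import numpy as np
    import mxnet as mx

    data = mx.sym.var('data')
    fc = mx.sym.FullyConnected(data, num_hidden=4, name='fc')
    # Pass known input dtypes by argument name; infer_type fills in the
    # rest, here the dtypes of fc_weight and fc_bias.
    arg_types, out_types, aux_types = fc.infer_type(data=np.float16)
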
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 24c86f4..f53eeb0 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -727,6 +727,8 @@ class ParameterDict(object):
                         if matched:
                             param._shape = tuple(inferred_shape)
                             continue
+                    elif k == 'dtype' and np.dtype(v) == np.dtype(existing):
+                        continue
 
                     assert v is None or v == existing, \
                         "Cannot retrieve Parameter '%s' because desired attribute " \
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 42d65da..ac7df62 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -18,6 +18,7 @@
 from __future__ import print_function
 import sys
 import os
+import tempfile
 import time
 import multiprocessing as mp
 import unittest
@@ -198,6 +199,36 @@ def test_sync_batchnorm():
         _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
                                 num_devices=ndev, cuda=True)
 
+@with_seed()
+def test_symbol_block_fp16():
+    # Test case to verify initializing a SymbolBlock from a model whose params
+    # have a dtype other than fp32.
+
+    # 1. Load a resnet model, cast it to fp16 and export
+    tmp = tempfile.mkdtemp()
+    tmpfile = os.path.join(tmp, 'resnet34_fp16')
+    ctx = mx.gpu(0)
+
+    net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp)
+    net_fp32.cast('float16')
+    net_fp32.hybridize()
+    data = mx.nd.zeros((1,3,224,224), dtype='float16', ctx=ctx)
+    net_fp32.forward(data)
+    net_fp32.export(tmpfile, 0)
+
+    # 2. Load the saved model, verify that all the params are loaded correctly,
+    # and choose one of the params to verify that its dtype is fp16.
+    sm = mx.sym.load(tmpfile + '-symbol.json')
+    inputs = mx.sym.var('data', dtype='float16')
+    net_fp16 = mx.gluon.SymbolBlock(sm, inputs)
+    net_fp16.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
+    # 3. Get a conv layer's weight parameter name. The conv layer's weight
+    # param is expected to have the casted dtype, fp16.
+    for param_name in net_fp16.params.keys():
+        if 'conv' in param_name and 'weight' in param_name:
+            break
+    assert np.dtype(net_fp16.params[param_name].dtype) == np.dtype(np.float16)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 61b441a..4e13fc3 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -15,6 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import os
+import tempfile
+
 import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
@@ -336,6 +339,41 @@ def test_symbol_block():
     net.hybridize()
     assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)
 
+    # Test case to verify initializing a SymbolBlock from a model whose params
+    # have a dtype other than fp32.
+
+    # 1. Load a resnet model, cast it to fp64 and export
+    tmp = tempfile.mkdtemp()
+    tmpfile = os.path.join(tmp, 'resnet34_fp64')
+    ctx = mx.cpu(0)
+
+    net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp)
+    net_fp32.cast('float64')
+    net_fp32.hybridize()
+    data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx)
+    net_fp32.forward(data)
+    net_fp32.export(tmpfile, 0)
+
+    # 2. Load the saved model, verify that all the params are loaded correctly,
+    # and choose one of the params to verify that its dtype is fp64.
+    sm = mx.sym.load(tmpfile + '-symbol.json')
+    inputs = mx.sym.var('data', dtype='float64')
+    net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
+    net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
+    # 3. Get a conv layer's weight parameter name. The conv layer's weight
+    # param is expected to have the casted dtype, fp64.
+    for param_name in net_fp64.params.keys():
+        if 'conv' in param_name and 'weight' in param_name:
+            break
+    assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64)
+
+    # Cast the SymbolBlock to fp32 and forward fp32 data through it.
+    # This verifies the SymbolBlock.cast() functionality.
+    net_fp64.cast('float32')
+    fp32_data = mx.nd.zeros((1,3,224,224), dtype='float32', ctx=ctx)
+    prediction = net_fp64.forward(fp32_data)
+    assert np.dtype(prediction.dtype) == np.dtype(np.float32)
+
 @with_seed()
 @raises(AssertionError)
 def test_sparse_symbol_block():
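
For reference, the cast behavior exercised by the tests above can be
reproduced on a hypothetical small graph, with no model-zoo download; this is
a sketch of the same flow, not code from the commit:

    import numpy as np
    import mxnet as mx

    data = mx.sym.var('data', dtype='float64')
    out = mx.sym.FullyConnected(data, num_hidden=2, name='fc')
    net = mx.gluon.SymbolBlock(out, data)
    net.initialize()

    # SymbolBlock.cast now clears the cached graph before casting the
    # params, so the next forward pass recompiles with the new dtype.
    net.cast('float32')
    y = net(mx.nd.zeros((1, 4), dtype='float32'))
    assert np.dtype(y.dtype) == np.dtype(np.float32)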