Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/18 22:46:55 UTC

[GitHub] sandeep-krishnamurthy closed pull request #12412: Infer dtype in SymbolBlock import from input symbol

URL: https://github.com/apache/incubator-mxnet/pull/12412
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index d0830dcc8ca..6cb9fc690b5 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -26,6 +26,7 @@
 import re
 from collections import OrderedDict
 
+from ..base import mx_real_t
 from .. import symbol, ndarray, initializer
 from ..symbol import Symbol
 from ..ndarray import NDArray
@@ -1053,13 +1054,20 @@ def __init__(self, outputs, inputs, params=None):
                     "SymbolBlock doesn't support Parameter '%s' because its storage " \
                     "type is 'row_sparse'." % j.name
 
-        for i in out.list_arguments():
-            if i not in input_names:
-                self.params.get(i, allow_deferred_init=True)
+        # Infer the parameter dtypes. Without this, every parameter would be
+        # created with the default dtype, i.e., fp32.
+        arg_params = out.list_arguments()
+        aux_params = out.list_auxiliary_states()
 
-        for i in out.list_auxiliary_states():
-            if i not in input_names:
-                self.params.get(i, grad_req='null', allow_deferred_init=True)
+        arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params)
+
+        for i, arg in enumerate(arg_params):
+            if arg not in input_names:
+                self.params.get(arg, allow_deferred_init=True, dtype=arg_types[i])
+
+        for i, aux in enumerate(aux_params):
+            if aux not in input_names:
+                self.params.get(aux, grad_req='null', allow_deferred_init=True, dtype=aux_types[i])
 
         self._cached_graph = syms, out
         len_prefix = len(_common_prefix(list(self._params.keys())))
@@ -1084,5 +1092,71 @@ def _clear_cached_op(self):
         super(SymbolBlock, self)._clear_cached_op()
         self._cached_graph = tmp
 
+    def cast(self, dtype):
+        self._clear_cached_op()
+        super(SymbolBlock, self).cast(dtype)
+
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
+
+def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dtype=mx_real_t):
+    """Utility function that helps in inferring DType of args and auxs params
+    from given input param.
+
+    Parameters
+    ----------
+    in_params: List of Symbol
+        List of input symbol variables.
+    out_params: Symbol
+        Output symbol variable.
+    arg_params: List of Str
+        List of names of argument parameters.
+    aux_params: List of Str
+        List of names of auxiliary parameters.
+    default_dtype: numpy.dtype or str, default 'float32'
+        Default data type for arg_params and aux_params, if unable to infer the type.
+
+    Returns
+    -------
+    arg_types: List of numpy.dtype
+        Dtypes of arg_params, in the same order as arg_params.
+        Defaults to 'float32' if the type cannot be inferred.
+    aux_types: List of numpy.dtype
+        Dtypes of aux_params, in the same order as aux_params.
+        Defaults to 'float32' if the type cannot be inferred.
+    """
+    arg_types = None
+    aux_types = None
+
+    # Get the input symbol names. These will be used to infer the types of
+    # the other parameters.
+    input_sym_names = [in_param.name for in_param in in_params]
+
+    # Try to infer the input types. If unsuccessful, fall back to the default dtype;
+    # otherwise, use them to infer the dtypes of the other params in the graph.
+    input_sym_arg_types = []
+    can_infer_input_type = True
+    for in_param in in_params:
+        input_sym_arg_type = in_param.infer_type()[0]
+        if not input_sym_arg_type or len(input_sym_arg_type) < 1:
+            can_infer_input_type = False
+            break
+        else:
+            input_sym_arg_types.append(input_sym_arg_type[0])
+
+    # Try to infer types of other parameters.
+    if can_infer_input_type:
+        params = {k:v for k, v in zip(input_sym_names, input_sym_arg_types)}
+        arg_types, _, aux_types = out_params.infer_type(**params)
+
+    if arg_types is None or len(arg_types) != len(arg_params):
+        arg_types = []
+        for _ in arg_params:
+            arg_types.append(default_dtype)
+
+    if aux_types is None or len(aux_types) != len(aux_params):
+        aux_types = []
+        for _ in aux_params:
+            aux_types.append(default_dtype)
+
+    return (arg_types, aux_types)
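
For reference, the helper above relies on Symbol.infer_type to propagate the dtype of the typed input variables through the graph to every argument and auxiliary state. A minimal sketch of that mechanism, separate from the diff and with illustrative variable names:

    import numpy as np
    import mxnet as mx

    # A small graph with one typed input and one untyped weight.
    data = mx.sym.var('data', dtype='float16')
    weight = mx.sym.var('fc_weight')
    out = mx.sym.FullyConnected(data=data, weight=weight, no_bias=True,
                                num_hidden=10, name='fc')

    # infer_type propagates float16 from 'data' to 'fc_weight'.
    # arg_types follows the order of out.list_arguments().
    arg_types, _, aux_types = out.infer_type(data=np.float16)
    print(dict(zip(out.list_arguments(), arg_types)))  # both come back as float16
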
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 24c86f4e0fa..f53eeb00694 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -727,6 +727,8 @@ def get(self, name, **kwargs):
                         if matched:
                             param._shape = tuple(inferred_shape)
                             continue
+                    elif k == 'dtype' and np.dtype(v) == np.dtype(existing):
+                        continue
 
                     assert v is None or v == existing, \
                         "Cannot retrieve Parameter '%s' because desired attribute " \
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 69375afdfe0..8394276c8ef 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -18,6 +18,7 @@
 from __future__ import print_function
 import sys
 import os
+import tempfile
 import time
 import multiprocessing as mp
 import unittest
@@ -202,6 +203,36 @@ def get_num_devices():
         _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
                                 num_devices=ndev, cuda=True)
 
+@with_seed()
+def test_symbol_block_fp16():
+    # Test case to verify initializing a SymbolBlock from a model whose params
+    # have a dtype other than fp32.
+
+    # 1. Load a resnet model, cast it to fp16 and export
+    tmp = tempfile.mkdtemp()
+    tmpfile = os.path.join(tmp, 'resnet34_fp16')
+    ctx = mx.gpu(0)
+
+    net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp)
+    net_fp32.cast('float16')
+    net_fp32.hybridize()
+    data = mx.nd.zeros((1,3,224,224), dtype='float16', ctx=ctx)
+    net_fp32.forward(data)
+    net_fp32.export(tmpfile, 0)
+
+    # 2. Load the saved model and verify that all the params are loaded correctly,
+    # then pick one of the params to verify that its dtype is fp16.
+    sm = mx.sym.load(tmpfile + '-symbol.json')
+    inputs = mx.sym.var('data', dtype='float16')
+    net_fp16 = mx.gluon.SymbolBlock(sm, inputs)
+    net_fp16.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
+    # 3. Get a conv layer's weight parameter name. The conv layer's weight param
+    # is expected to have the casted dtype, fp16.
+    for param_name in net_fp16.params.keys():
+        if 'conv' in param_name and 'weight' in param_name:
+            break
+    assert np.dtype(net_fp16.params[param_name].dtype) == np.dtype(np.float16)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
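
The GPU test above exercises the full export-and-reimport round trip. The user-facing effect of the block.py change can also be seen before any weights are loaded: constructing a SymbolBlock from a symbol whose input variable carries a dtype now yields parameters with the inferred dtype instead of the fp32 default. A minimal sketch, with illustrative file and variable names:

    import mxnet as mx

    # 'resnet34_fp16-symbol.json' stands in for any previously exported
    # fp16 model; the file name here is only illustrative.
    sym = mx.sym.load('resnet34_fp16-symbol.json')
    inputs = mx.sym.var('data', dtype='float16')
    net = mx.gluon.SymbolBlock(sym, inputs)

    # The parameters are created as float16, so loading float16 weights
    # no longer trips a dtype mismatch.
    print({name: p.dtype for name, p in net.params.items()})
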
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index bf9f5a77c84..796182e2b73 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -15,6 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import os
+import tempfile
+
 import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
@@ -336,6 +339,41 @@ def hybrid_forward(self, F, x):
     net.hybridize()
     assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)
 
+    # Test case to verify initializing a SymbolBlock from a model whose params
+    # have a dtype other than fp32.
+
+    # 1. Load a resnet model, cast it to fp64 and export
+    tmp = tempfile.mkdtemp()
+    tmpfile = os.path.join(tmp, 'resnet34_fp64')
+    ctx = mx.cpu(0)
+
+    net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp)
+    net_fp32.cast('float64')
+    net_fp32.hybridize()
+    data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx)
+    net_fp32.forward(data)
+    net_fp32.export(tmpfile, 0)
+
+    # 2. Load the saved model and verify that all the params are loaded correctly,
+    # then pick one of the params to verify that its dtype is fp64.
+    sm = mx.sym.load(tmpfile + '-symbol.json')
+    inputs = mx.sym.var('data', dtype='float64')
+    net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
+    net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
+    # 3. Get a conv layer's weight parameter name. The conv layer's weight param
+    # is expected to have the casted dtype, fp64.
+    for param_name in net_fp64.params.keys():
+        if 'conv' in param_name and 'weight' in param_name:
+            break
+    assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64)
+
+    # Cast the symbol block to FP32 and try to forward FP32 data.
+    # This will verify SymbolBlock.cast() functionality.
+    net_fp64.cast('float32')
+    fp32_data = mx.nd.zeros((1,3,224,224), dtype='float32', ctx=ctx)
+    prediction = net_fp64.forward(fp32_data)
+    assert np.dtype(prediction.dtype) == np.dtype(np.float32)
+
 @with_seed()
 @raises(AssertionError)
 def test_sparse_symbol_block():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services