You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/11/07 18:10:26 UTC
[incubator-mxnet] branch v1.3.x updated: Infer dtype in SymbolBlock
import from input symbol (v1.3.x) (#13117)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch v1.3.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.3.x by this push:
new dc5c877 Infer dtype in SymbolBlock import from input symbol (v1.3.x) (#13117)
dc5c877 is described below
commit dc5c8771943b1711da9ea181acbcee19d4d67491
Author: Anton Chernov <me...@gmail.com>
AuthorDate: Wed Nov 7 19:10:08 2018 +0100
Infer dtype in SymbolBlock import from input symbol (v1.3.x) (#13117)
* Infer dtype in SymbolBlock import from input symbol
* Fix lint issues and make existing tests pass
* Add tests for importing a fp64 model into symbol block
* Fixing failing test for test symbol block
* Set context in unit tests
* Add tests for fp16, add default dtype in infer_param_types
* Use tmp directory as root for loading from model zoo to avoid race condition
* Fixing naming and parameter selection in test case
* Fixing failing GPU tests
* Make unit test more deterministic to get param name
* Override cast in symbol block, handle grouped symbol
* Handle multiple symbolic input usecase
* Add tests to verify behavior of SymbolBlock.cast
---
python/mxnet/gluon/block.py | 86 ++++++++++++++++++++++++++++++++++---
python/mxnet/gluon/parameter.py | 2 +
tests/python/gpu/test_gluon_gpu.py | 31 +++++++++++++
tests/python/unittest/test_gluon.py | 38 ++++++++++++++++
4 files changed, 151 insertions(+), 6 deletions(-)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index d0830dc..6cb9fc6 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -26,6 +26,7 @@ import warnings
import re
from collections import OrderedDict
+from ..base import mx_real_t
from .. import symbol, ndarray, initializer
from ..symbol import Symbol
from ..ndarray import NDArray
@@ -1053,13 +1054,20 @@ class SymbolBlock(HybridBlock):
"SymbolBlock doesn't support Parameter '%s' because its storage " \
"type is 'row_sparse'." % j.name
- for i in out.list_arguments():
- if i not in input_names:
- self.params.get(i, allow_deferred_init=True)
+ # Infer type of parameters. Without this, every parameter will be created with
+ # default type i.e., fp32
+ arg_params = out.list_arguments()
+ aux_params = out.list_auxiliary_states()
- for i in out.list_auxiliary_states():
- if i not in input_names:
- self.params.get(i, grad_req='null', allow_deferred_init=True)
+ arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params)
+
+ for i, arg in enumerate(arg_params):
+ if arg not in input_names:
+ self.params.get(arg, allow_deferred_init=True, dtype=arg_types[i])
+
+ for i, aux in enumerate(aux_params):
+ if aux not in input_names:
+ self.params.get(aux, grad_req='null', allow_deferred_init=True, dtype=aux_types[i])
self._cached_graph = syms, out
len_prefix = len(_common_prefix(list(self._params.keys())))
@@ -1084,5 +1092,71 @@ class SymbolBlock(HybridBlock):
super(SymbolBlock, self)._clear_cached_op()
self._cached_graph = tmp
+ def cast(self, dtype):
+ self._clear_cached_op()
+ super(SymbolBlock, self).cast(dtype)
+
def hybrid_forward(self, F, x, *args, **kwargs):
raise NotImplementedError
+
+def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dtype=mx_real_t):
+ """Utility function that helps in inferring DType of args and auxs params
+ from given input param.
+
+ Parameters
+ ----------
+ in_params: List of Symbol
+ List of input symbol variables.
+ out_params: Symbol
+ Output symbol variable.
+ arg_params: List of Str
+ List of names of argument parameters.
+ aux_params: List of Str
+ List of names of auxiliary parameters.
+ default_dtype: numpy.dtype or str, default 'float32'
+ Default data type for arg_params and aux_params, if unable to infer the type.
+
+ Returns
+ -------
+ arg_types: List of numpy.dtype
+ List of arg_params type. Order is same as arg_params.
+ Defaults to 'float32', if unable to infer type.
+ aux_types: List of numpy.dtype
+ List of aux_params type. Order is same as aux_params.
+ Defaults to 'float32', if unable to infer type.
+ """
+ arg_types = None
+ aux_types = None
+
+ # Get Input symbol details. This will be used to infer types of
+ # other parameters.
+ input_sym_names = [in_param.name for in_param in in_params]
+
+ # Try to infer input types. If not successful, we will set default dtype.
+ # If successful, we will try to infer other params in the graph.
+ input_sym_arg_types = []
+ can_infer_input_type = True
+ for in_param in in_params:
+ input_sym_arg_type = in_param.infer_type()[0]
+ if not input_sym_arg_type or len(input_sym_arg_type) < 1:
+ can_infer_input_type = False
+ break
+ else:
+ input_sym_arg_types.append(in_param.infer_type()[0][0])
+
+ # Try to infer types of other parameters.
+ if can_infer_input_type:
+ params = {k:v for k, v in zip(input_sym_names, input_sym_arg_types)}
+ arg_types, _, aux_types = out_params.infer_type(**params)
+
+ if arg_types is None or len(arg_types) != len(arg_params):
+ arg_types = []
+ for _ in arg_params:
+ arg_types.append(default_dtype)
+
+ if aux_types is None or len(aux_types) != len(aux_params):
+ aux_types = []
+ for _ in aux_params:
+ aux_types.append(default_dtype)
+
+ return (arg_types, aux_types)
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 24c86f4..f53eeb0 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -727,6 +727,8 @@ class ParameterDict(object):
if matched:
param._shape = tuple(inferred_shape)
continue
+ elif k == 'dtype' and np.dtype(v) == np.dtype(existing):
+ continue
assert v is None or v == existing, \
"Cannot retrieve Parameter '%s' because desired attribute " \
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 42d65da..ac7df62 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -18,6 +18,7 @@
from __future__ import print_function
import sys
import os
+import tempfile
import time
import multiprocessing as mp
import unittest
@@ -198,6 +199,36 @@ def test_sync_batchnorm():
_check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
num_devices=ndev, cuda=True)
+@with_seed()
+def test_symbol_block_fp16():
+ # Test case to verify initializing the SymbolBlock from a model whose params
+ # have a dtype other than fp32.
+
+ # 1. Load a resnet model, cast it to fp16 and export
+ tmp = tempfile.mkdtemp()
+ tmpfile = os.path.join(tmp, 'resnet34_fp16')
+ ctx = mx.gpu(0)
+
+ net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp)
+ net_fp32.cast('float16')
+ net_fp32.hybridize()
+ data = mx.nd.zeros((1,3,224,224), dtype='float16', ctx=ctx)
+ net_fp32.forward(data)
+ net_fp32.export(tmpfile, 0)
+
+ # 2. Load the saved model and verify if all the params are loaded correctly.
+ # and choose one of the params to verify that its type is fp16.
+ sm = mx.sym.load(tmpfile + '-symbol.json')
+ inputs = mx.sym.var('data', dtype='float16')
+ net_fp16 = mx.gluon.SymbolBlock(sm, inputs)
+ net_fp16.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
+ # 3. Get a conv layer's weight parameter name. Conv layer's weight param is
+ # expected to be of the casted dtype, fp16.
+ for param_name in net_fp16.params.keys():
+ if 'conv' in param_name and 'weight' in param_name:
+ break
+ assert np.dtype(net_fp16.params[param_name].dtype) == np.dtype(np.float16)
+
if __name__ == '__main__':
import nose
nose.runmodule()
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 61b441a..4e13fc3 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
+import os
+import tempfile
+
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
@@ -336,6 +339,41 @@ def test_symbol_block():
net.hybridize()
assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)
+ # Test case to verify initializing the SymbolBlock from a model whose params
+ # have a dtype other than fp32.
+
+ # 1. Load a resnet model, cast it to fp64 and export
+ tmp = tempfile.mkdtemp()
+ tmpfile = os.path.join(tmp, 'resnet34_fp64')
+ ctx = mx.cpu(0)
+
+ net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp)
+ net_fp32.cast('float64')
+ net_fp32.hybridize()
+ data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx)
+ net_fp32.forward(data)
+ net_fp32.export(tmpfile, 0)
+
+ # 2. Load the saved model and verify if all the params are loaded correctly.
+ # and choose one of the params to verify that its type is fp64.
+ sm = mx.sym.load(tmpfile + '-symbol.json')
+ inputs = mx.sym.var('data', dtype='float64')
+ net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
+ net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
+ # 3. Get a conv layer's weight parameter name. Conv layer's weight param is
+ # expected to be of the casted dtype, fp64.
+ for param_name in net_fp64.params.keys():
+ if 'conv' in param_name and 'weight' in param_name:
+ break
+ assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64)
+
+ # Cast the symbol block to FP32 and try to forward a FP32 data.
+ # This will verify SymbolBlock.cast() functionality.
+ net_fp64.cast('float32')
+ fp32_data = mx.nd.zeros((1,3,224,224), dtype='float32', ctx=ctx)
+ prediction = net_fp64.forward(fp32_data)
+ assert np.dtype(prediction.dtype) == np.dtype(np.float32)
+
@with_seed()
@raises(AssertionError)
def test_sparse_symbol_block():