You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/21 22:41:20 UTC

[incubator-mxnet] branch master updated: Add cast to Block and Parameter. Implicit dtype casting is removed. (#8735)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1852e2f  Add cast to Block and Parameter. Implicit dtype casting is removed. (#8735)
1852e2f is described below

commit 1852e2f47d68bb4c2373a359a2a8671b59cd14e5
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Tue Nov 21 14:41:17 2017 -0800

    Add cast to Block and Parameter. Implicit dtype casting is removed. (#8735)
    
    * fix
    
    * fix
    
    * fix
    
    * fix
    
    * Update parameter.py
---
 python/mxnet/gluon/block.py           | 18 +++++++-
 python/mxnet/gluon/nn/basic_layers.py | 24 +++++++----
 python/mxnet/gluon/nn/conv_layers.py  |  4 +-
 python/mxnet/gluon/parameter.py       | 81 +++++++++++++++++++++++++++--------
 python/mxnet/gluon/rnn/rnn_cell.py    | 24 +++++------
 tests/python/unittest/test_gluon.py   | 12 +++++-
 6 files changed, 118 insertions(+), 45 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 2546711..466f87f 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -286,6 +286,19 @@ class Block(object):
         for cld in self._children:
             cld.hybridize(active)
 
+    def cast(self, dtype):
+        """Cast this Block to use another data type.
+
+        Parameters
+        ----------
+        dtype : str or numpy.dtype
+            The new data type.
+        """
+        for child in self._children:
+            child.cast(dtype)
+        for _, param in self.params.items():
+            param.cast(dtype)
+
     def __call__(self, *args):
         """Calls forward. Only accepts positional arguments."""
         return self.forward(*args)
@@ -388,7 +401,6 @@ class HybridBlock(Block):
 
     def _finish_deferred_init(self, hybrid, *args):
         self.infer_shape(*args)
-        self.infer_type(*args)
         if hybrid:
             for is_arg, i in self._cached_op_args:
                 if not is_arg:
@@ -429,6 +441,10 @@ class HybridBlock(Block):
         self._active = active
         super(HybridBlock, self).hybridize(active)
 
+    def cast(self, dtype):
+        self._clear_cached_op()
+        super(HybridBlock, self).cast(dtype)
+
     def _infer_attrs(self, infer_fn, attr, *args):
         """Generic infer attributes."""
         inputs, out = self._get_graph(*args)
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 15c8285..c0b4b52 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -22,6 +22,7 @@ __all__ = ['Sequential', 'HybridSequential', 'Dense', 'Activation',
            'Dropout', 'BatchNorm', 'LeakyReLU', 'Embedding', 'Flatten',
            'Lambda', 'HybridLambda']
 import warnings
+import numpy as np
 
 from ..block import Block, HybridBlock
 from ..utils import _indent
@@ -185,11 +186,11 @@ class Dense(HybridBlock):
             self._units = units
             self._in_units = in_units
             self.weight = self.params.get('weight', shape=(units, in_units),
-                                          dtype=None, init=weight_initializer,
+                                          init=weight_initializer,
                                           allow_deferred_init=True)
             if use_bias:
                 self.bias = self.params.get('bias', shape=(units,),
-                                            dtype=None, init=bias_initializer,
+                                            init=bias_initializer,
                                             allow_deferred_init=True)
             else:
                 self.bias = None
@@ -336,24 +337,29 @@ class BatchNorm(HybridBlock):
             self.in_channels = in_channels
 
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
-                                     shape=(in_channels,), dtype=None,
-                                     init=gamma_initializer, allow_deferred_init=True,
+                                     shape=(in_channels,), init=gamma_initializer,
+                                     allow_deferred_init=True,
                                      differentiable=scale)
         self.beta = self.params.get('beta', grad_req='write' if center else 'null',
-                                    shape=(in_channels,), dtype=None,
-                                    init=beta_initializer, allow_deferred_init=True,
+                                    shape=(in_channels,), init=beta_initializer,
+                                    allow_deferred_init=True,
                                     differentiable=center)
         self.running_mean = self.params.get('running_mean', grad_req='null',
-                                            shape=(in_channels,), dtype=None,
+                                            shape=(in_channels,),
                                             init=running_mean_initializer,
                                             allow_deferred_init=True,
                                             differentiable=False)
         self.running_var = self.params.get('running_var', grad_req='null',
-                                           shape=(in_channels,), dtype=None,
+                                           shape=(in_channels,),
                                            init=running_variance_initializer,
                                            allow_deferred_init=True,
                                            differentiable=False)
 
+    def cast(self, dtype):
+        if np.dtype(dtype).name == 'float16':
+            dtype = 'float32'
+        super(BatchNorm, self).cast(dtype)
+
     def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
         return F.BatchNorm(x, gamma, beta, running_mean, running_var,
                            name='fwd', **self._kwargs)
@@ -437,7 +443,7 @@ class Embedding(HybridBlock):
         self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
                         'dtype': dtype}
         self.weight = self.params.get('weight', shape=(input_dim, output_dim),
-                                      dtype=None, init=weight_initializer,
+                                      init=weight_initializer,
                                       allow_deferred_init=True)
 
     def hybrid_forward(self, F, x, weight):
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 0dd7069..645de98 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -113,11 +113,11 @@ class _Conv(HybridBlock):
             dshape[layout.find('C')] = in_channels
             wshapes = _infer_weight_shape(op_name, dshape, self._kwargs)
             self.weight = self.params.get('weight', shape=wshapes[1],
-                                          dtype=None, init=weight_initializer,
+                                          init=weight_initializer,
                                           allow_deferred_init=True)
             if use_bias:
                 self.bias = self.params.get('bias', shape=wshapes[2],
-                                            dtype=None, init=bias_initializer,
+                                            init=bias_initializer,
                                             allow_deferred_init=True)
             else:
                 self.bias = None
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 27297b5..537d636 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -107,8 +107,8 @@ class Parameter(object):
         self._differentiable = differentiable
         self._allow_deferred_init = allow_deferred_init
         self._grad_req = None
+        self._shape = shape
         self.name = name
-        self.shape = shape
         self.dtype = dtype
         self.lr_mult = lr_mult
         self.wd_mult = wd_mult
@@ -138,6 +138,23 @@ class Parameter(object):
         elif self._data is not None:
             self._init_grad()
 
+    @property
+    def shape(self):
+        return self._shape
+
+    @shape.setter
+    def shape(self, new_shape):
+        if self._shape is None:
+            self._shape = new_shape
+            return
+
+        assert len(self._shape) == len(new_shape) and \
+            all(j == 0 or i == j for i, j in zip(new_shape, self._shape)), \
+            "Expected shape %s is incompatible with given shape %s."%(
+                str(new_shape), str(self._shape))
+
+        self._shape = new_shape
+
     def _check_and_get(self, arr_list, ctx):
         if arr_list is not None:
             if ctx is list:
@@ -147,9 +164,12 @@ class Parameter(object):
                     return arr_list[0]
                 else:
                     ctx = context.current_context()
-            idx = self._ctx_map[ctx.device_typeid][ctx.device_id]
-            if idx is not None:
-                return arr_list[idx]
+            if ctx.device_typeid < len(self._ctx_map):
+                ctx_list = self._ctx_map[ctx.device_typeid]
+                if ctx.device_id < len(ctx_list):
+                    idx = ctx_list[ctx.device_id]
+                    if idx is not None:
+                        return arr_list[idx]
             raise RuntimeError(
                 "Parameter %s was not initialized on context %s. "
                 "It was only initialized on %s."%(
@@ -203,7 +223,7 @@ class Parameter(object):
         """Finishes deferred initialization."""
         if not self._deferred_init:
             return
-        init, ctx, default_init = self._deferred_init
+        init, ctx, default_init, data = self._deferred_init
         self._deferred_init = ()
         assert self.shape is not None and np.prod(self.shape) > 0, \
             "Cannot initialize Parameter %s because it has " \
@@ -212,10 +232,11 @@ class Parameter(object):
                 self.name, str(self.shape))
 
         with autograd.pause():
-            data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
-                                 ctx=context.cpu())
-            initializer.create(default_init)(
-                initializer.InitDesc(self.name, {'__init__': init}), data)
+            if data is None:
+                data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
+                                     ctx=context.cpu())
+                initializer.create(default_init)(
+                    initializer.InitDesc(self.name, {'__init__': init}), data)
 
             self._init_impl(data, ctx)
 
@@ -306,14 +327,14 @@ class Parameter(object):
             ctx = [ctx]
         if init is None:
             init = default_init if self.init is None else self.init
-        if self.dtype is None or not self.shape or np.prod(self.shape) <= 0:
+        if not self.shape or np.prod(self.shape) <= 0:
             if self._allow_deferred_init:
-                self._deferred_init = (init, ctx, default_init)
+                self._deferred_init = (init, ctx, default_init, None)
                 return
             raise ValueError("Cannot initialize Parameter %s because it has " \
                              "invalid shape: %s."%(self.name, str(self.shape)))
 
-        self._deferred_init = (init, ctx, default_init)
+        self._deferred_init = (init, ctx, default_init, None)
         self._finish_deferred_init()
 
     def reset_ctx(self, ctx):
@@ -332,21 +353,25 @@ class Parameter(object):
             with autograd.pause():
                 self._init_impl(data, ctx)
         elif self._deferred_init:
-            init, _, default_init = self._deferred_init
-            self._deferred_init = (init, ctx, default_init)
+            init, _, default_init, data = self._deferred_init
+            self._deferred_init = (init, ctx, default_init, data)
         else:
             raise ValueError("Cannot reset context for Parameter %s because it "
                              "has not been initialized."%self.name)
 
 
     def set_data(self, data):
-        """Sets this parameter's value on all contexts to data."""
-        assert self._data is not None, \
-            "Parameter %s has not been initialized"%self.name
+        """Sets this parameter's value on all contexts."""
+        self.shape = data.shape
+
+        if self._data is None:
+            assert self._deferred_init is not None, \
+                "Parameter %s has not been initialized"%self.name
+            self._deferred_init = self._deferred_init[:3] + (data,)
+            return
+
         for arr in self.list_data():
             arr[:] = data
-        if not self.shape or np.prod(self.shape) <= 0:
-            self.shape = data.shape
 
     def data(self, ctx=None):
         """Returns a copy of this parameter on one context. Must have been
@@ -415,6 +440,24 @@ class Parameter(object):
                                    init=self.init)
         return self._var
 
+    def cast(self, dtype):
+        """Cast data and gradient of this Parameter to a new data type.
+
+        Parameters
+        ----------
+        dtype : str or numpy.dtype
+            The new data type.
+        """
+        self.dtype = dtype
+        if self._data is None:
+            return
+        with autograd.pause():
+            self._data = [i.astype(dtype) for i in self._data]
+            if self._grad is None:
+                return
+            self._grad = [i.astype(dtype) for i in self._grad]
+            autograd.mark_variables(self._data, self._grad, self.grad_req)
+
 
 class ParameterDict(object):
     """A dictionary managing a set of parameters.
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 80bb8e3..ea0e32f 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -326,16 +326,16 @@ class RNNCell(HybridRecurrentCell):
         self._activation = activation
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
-                                          dtype=None, init=i2h_weight_initializer,
+                                          init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
-                                          dtype=None, init=h2h_weight_initializer,
+                                          init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
-                                        dtype=None, init=i2h_bias_initializer,
+                                        init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
-                                        dtype=None, init=h2h_bias_initializer,
+                                        init=h2h_bias_initializer,
                                         allow_deferred_init=True)
 
     def state_info(self, batch_size=0):
@@ -434,16 +434,16 @@ class LSTMCell(HybridRecurrentCell):
         self._hidden_size = hidden_size
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
-                                          dtype=None, init=i2h_weight_initializer,
+                                          init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
-                                          dtype=None, init=h2h_weight_initializer,
+                                          init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
-                                        dtype=None, init=i2h_bias_initializer,
+                                        init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
-                                        dtype=None, init=h2h_bias_initializer,
+                                        init=h2h_bias_initializer,
                                         allow_deferred_init=True)
 
     def state_info(self, batch_size=0):
@@ -541,16 +541,16 @@ class GRUCell(HybridRecurrentCell):
         self._hidden_size = hidden_size
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
-                                          dtype=None, init=i2h_weight_initializer,
+                                          init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
-                                          dtype=None, init=h2h_weight_initializer,
+                                          init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
-                                        dtype=None, init=i2h_bias_initializer,
+                                        init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
-                                        dtype=None, init=h2h_bias_initializer,
+                                        init=h2h_bias_initializer,
                                         allow_deferred_init=True)
 
     def state_info(self, batch_size=0):
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 751f1fb..df9f78e 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -570,12 +570,20 @@ def test_fill_shape_deferred():
 def test_dtype():
     net = mx.gluon.model_zoo.vision.resnet18_v1()
     net.initialize()
-    net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+    net.cast('float64')
+    with mx.autograd.record():
+        y = net(mx.nd.ones((16, 3, 32, 32), dtype='float64'))
+        y.backward()
 
     net = mx.gluon.model_zoo.vision.resnet18_v1()
     net.initialize()
     net.hybridize()
-    net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+    net(mx.nd.ones((16, 3, 32, 32), dtype='float32'))
+
+    net.cast('float64')
+    net(mx.nd.ones((16, 3, 32, 32), dtype='float64'))
+
+    mx.nd.waitall()
 
 
 def test_fill_shape_load():

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.