Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/21 22:41:20 UTC
[incubator-mxnet] branch master updated: Add cast to Block and Parameter. Implicit dtype casting is removed. (#8735)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 1852e2f Add cast to Block and Parameter. Implicit dtype casting is removed. (#8735)
1852e2f is described below
commit 1852e2f47d68bb4c2373a359a2a8671b59cd14e5
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Tue Nov 21 14:41:17 2017 -0800
Add cast to Block and Parameter. Implicit dtype casting is removed. (#8735)
* fix
* fix
* fix
* fix
* Update parameter.py
---
 python/mxnet/gluon/block.py           | 18 +++++++-
 python/mxnet/gluon/nn/basic_layers.py | 24 +++++++----
 python/mxnet/gluon/nn/conv_layers.py  |  4 +-
 python/mxnet/gluon/parameter.py       | 81 +++++++++++++++++++++++++++--------
 python/mxnet/gluon/rnn/rnn_cell.py    | 24 +++++------
 tests/python/unittest/test_gluon.py   | 12 +++++-
 6 files changed, 118 insertions(+), 45 deletions(-)
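
In practical terms, parameters no longer follow the dtype of the input
implicitly; a network initialized in float32 must be cast explicitly before
it can consume data of another type. A minimal usage sketch of the new API,
mirroring the updated test at the end of this diff:

    import mxnet as mx

    net = mx.gluon.model_zoo.vision.resnet18_v1()
    net.initialize()       # parameters start as float32
    net.cast('float64')    # recursively casts every child Block and Parameter
    out = net(mx.nd.ones((16, 3, 32, 32), dtype='float64'))
    out.wait_to_read()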
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 2546711..466f87f 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -286,6 +286,19 @@ class Block(object):
         for cld in self._children:
             cld.hybridize(active)
 
+    def cast(self, dtype):
+        """Cast this Block to use another data type.
+
+        Parameters
+        ----------
+        dtype : str or numpy.dtype
+            The new data type.
+        """
+        for child in self._children:
+            child.cast(dtype)
+        for _, param in self.params.items():
+            param.cast(dtype)
+
     def __call__(self, *args):
         """Calls forward. Only accepts positional arguments."""
         return self.forward(*args)
@@ -388,7 +401,6 @@ class HybridBlock(Block):
     def _finish_deferred_init(self, hybrid, *args):
         self.infer_shape(*args)
-        self.infer_type(*args)
         if hybrid:
             for is_arg, i in self._cached_op_args:
                 if not is_arg:
@@ -429,6 +441,10 @@ class HybridBlock(Block):
         self._active = active
         super(HybridBlock, self).hybridize(active)
 
+    def cast(self, dtype):
+        self._clear_cached_op()
+        super(HybridBlock, self).cast(dtype)
+
     def _infer_attrs(self, infer_fn, attr, *args):
         """Generic infer attributes."""
         inputs, out = self._get_graph(*args)
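
The HybridBlock override above exists because a hybridized block caches a
compiled graph built for the old dtype; cast() clears that cache so the next
forward pass rebuilds it with the new type. A sketch of the interaction,
following the updated test_dtype at the end of this diff:

    import mxnet as mx

    net = mx.gluon.model_zoo.vision.resnet18_v1()
    net.initialize()
    net.hybridize()
    net(mx.nd.ones((1, 3, 32, 32), dtype='float32'))   # builds the cached op in float32

    net.cast('float64')                                # also drops the cached op
    net(mx.nd.ones((1, 3, 32, 32), dtype='float64'))   # rebuilds in float64
    mx.nd.waitall()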
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 15c8285..c0b4b52 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -22,6 +22,7 @@ __all__ = ['Sequential', 'HybridSequential', 'Dense', 'Activation',
            'Dropout', 'BatchNorm', 'LeakyReLU', 'Embedding', 'Flatten',
            'Lambda', 'HybridLambda']
 import warnings
+import numpy as np
 
 from ..block import Block, HybridBlock
 from ..utils import _indent
@@ -185,11 +186,11 @@ class Dense(HybridBlock):
             self._units = units
             self._in_units = in_units
             self.weight = self.params.get('weight', shape=(units, in_units),
-                                          dtype=None, init=weight_initializer,
+                                          init=weight_initializer,
                                           allow_deferred_init=True)
             if use_bias:
                 self.bias = self.params.get('bias', shape=(units,),
-                                            dtype=None, init=bias_initializer,
+                                            init=bias_initializer,
                                             allow_deferred_init=True)
             else:
                 self.bias = None
@@ -336,24 +337,29 @@ class BatchNorm(HybridBlock):
             self.in_channels = in_channels
 
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
-                                     shape=(in_channels,), dtype=None,
-                                     init=gamma_initializer, allow_deferred_init=True,
+                                     shape=(in_channels,), init=gamma_initializer,
+                                     allow_deferred_init=True,
                                      differentiable=scale)
         self.beta = self.params.get('beta', grad_req='write' if center else 'null',
-                                    shape=(in_channels,), dtype=None,
-                                    init=beta_initializer, allow_deferred_init=True,
+                                    shape=(in_channels,), init=beta_initializer,
+                                    allow_deferred_init=True,
                                     differentiable=center)
         self.running_mean = self.params.get('running_mean', grad_req='null',
-                                            shape=(in_channels,), dtype=None,
+                                            shape=(in_channels,),
                                             init=running_mean_initializer,
                                             allow_deferred_init=True,
                                             differentiable=False)
         self.running_var = self.params.get('running_var', grad_req='null',
-                                           shape=(in_channels,), dtype=None,
+                                           shape=(in_channels,),
                                            init=running_variance_initializer,
                                            allow_deferred_init=True,
                                            differentiable=False)
 
+    def cast(self, dtype):
+        if np.dtype(dtype).name == 'float16':
+            dtype = 'float32'
+        super(BatchNorm, self).cast(dtype)
+
     def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
         return F.BatchNorm(x, gamma, beta, running_mean, running_var,
                            name='fwd', **self._kwargs)
@@ -437,7 +443,7 @@ class Embedding(HybridBlock):
         self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
                         'dtype': dtype}
         self.weight = self.params.get('weight', shape=(input_dim, output_dim),
-                                      dtype=None, init=weight_initializer,
+                                      init=weight_initializer,
                                       allow_deferred_init=True)
 
     def hybrid_forward(self, F, x, weight):
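
Note the BatchNorm.cast override above: a request to cast to float16 is
rewritten to float32, presumably to keep the batch statistics numerically
stable under reduced precision. A small sketch of the resulting behavior:

    import mxnet as mx

    bn = mx.gluon.nn.BatchNorm(in_channels=4)
    bn.initialize()
    bn.cast('float16')
    print(bn.gamma.dtype)   # stays float32 by design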
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 0dd7069..645de98 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -113,11 +113,11 @@ class _Conv(HybridBlock):
             dshape[layout.find('C')] = in_channels
             wshapes = _infer_weight_shape(op_name, dshape, self._kwargs)
             self.weight = self.params.get('weight', shape=wshapes[1],
-                                          dtype=None, init=weight_initializer,
+                                          init=weight_initializer,
                                           allow_deferred_init=True)
             if use_bias:
                 self.bias = self.params.get('bias', shape=wshapes[2],
-                                            dtype=None, init=bias_initializer,
+                                            init=bias_initializer,
                                             allow_deferred_init=True)
             else:
                 self.bias = None
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 27297b5..537d636 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -107,8 +107,8 @@ class Parameter(object):
         self._differentiable = differentiable
         self._allow_deferred_init = allow_deferred_init
         self._grad_req = None
+        self._shape = shape
         self.name = name
-        self.shape = shape
         self.dtype = dtype
         self.lr_mult = lr_mult
         self.wd_mult = wd_mult
@@ -138,6 +138,23 @@ class Parameter(object):
         elif self._data is not None:
             self._init_grad()
 
+    @property
+    def shape(self):
+        return self._shape
+
+    @shape.setter
+    def shape(self, new_shape):
+        if self._shape is None:
+            self._shape = new_shape
+            return
+
+        assert len(self._shape) == len(new_shape) and \
+            all(j == 0 or i == j for i, j in zip(new_shape, self._shape)), \
+            "Expected shape %s is incompatible with given shape %s."%(
+                str(new_shape), str(self._shape))
+
+        self._shape = new_shape
+
     def _check_and_get(self, arr_list, ctx):
         if arr_list is not None:
             if ctx is list:
def _check_and_get(self, arr_list, ctx):
if arr_list is not None:
if ctx is list:
@@ -147,9 +164,12 @@ class Parameter(object):
                     return arr_list[0]
                 else:
                     ctx = context.current_context()
-            idx = self._ctx_map[ctx.device_typeid][ctx.device_id]
-            if idx is not None:
-                return arr_list[idx]
+            if ctx.device_typeid < len(self._ctx_map):
+                ctx_list = self._ctx_map[ctx.device_typeid]
+                if ctx.device_id < len(ctx_list):
+                    idx = ctx_list[ctx.device_id]
+                    if idx is not None:
+                        return arr_list[idx]
         raise RuntimeError(
             "Parameter %s was not initialized on context %s. "
             "It was only initialized on %s."%(
@@ -203,7 +223,7 @@ class Parameter(object):
"""Finishes deferred initialization."""
if not self._deferred_init:
return
- init, ctx, default_init = self._deferred_init
+ init, ctx, default_init, data = self._deferred_init
self._deferred_init = ()
assert self.shape is not None and np.prod(self.shape) > 0, \
"Cannot initialize Parameter %s because it has " \
@@ -212,10 +232,11 @@ class Parameter(object):
                 self.name, str(self.shape))
 
         with autograd.pause():
-            data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
-                                 ctx=context.cpu())
-            initializer.create(default_init)(
-                initializer.InitDesc(self.name, {'__init__': init}), data)
+            if data is None:
+                data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
+                                     ctx=context.cpu())
+                initializer.create(default_init)(
+                    initializer.InitDesc(self.name, {'__init__': init}), data)
 
         self._init_impl(data, ctx)
@@ -306,14 +327,14 @@ class Parameter(object):
             ctx = [ctx]
         if init is None:
             init = default_init if self.init is None else self.init
-        if self.dtype is None or not self.shape or np.prod(self.shape) <= 0:
+        if not self.shape or np.prod(self.shape) <= 0:
             if self._allow_deferred_init:
-                self._deferred_init = (init, ctx, default_init)
+                self._deferred_init = (init, ctx, default_init, None)
                 return
             raise ValueError("Cannot initialize Parameter %s because it has " \
                              "invalid shape: %s."%(self.name, str(self.shape)))
 
-        self._deferred_init = (init, ctx, default_init)
+        self._deferred_init = (init, ctx, default_init, None)
         self._finish_deferred_init()
 
     def reset_ctx(self, ctx):
@@ -332,21 +353,25 @@ class Parameter(object):
             with autograd.pause():
                 self._init_impl(data, ctx)
         elif self._deferred_init:
-            init, _, default_init = self._deferred_init
-            self._deferred_init = (init, ctx, default_init)
+            init, _, default_init, data = self._deferred_init
+            self._deferred_init = (init, ctx, default_init, data)
         else:
             raise ValueError("Cannot reset context for Parameter %s because it "
                              "has not been initialized."%self.name)
 
     def set_data(self, data):
-        """Sets this parameter's value on all contexts to data."""
-        assert self._data is not None, \
-            "Parameter %s has not been initialized"%self.name
+        """Sets this parameter's value on all contexts."""
+        self.shape = data.shape
+
+        if self._data is None:
+            assert self._deferred_init is not None, \
+                "Parameter %s has not been initialized"%self.name
+            self._deferred_init = self._deferred_init[:3] + (data,)
+            return
+
         for arr in self.list_data():
             arr[:] = data
-        if not self.shape or np.prod(self.shape) <= 0:
-            self.shape = data.shape
 
     def data(self, ctx=None):
         """Returns a copy of this parameter on one context. Must have been
@@ -415,6 +440,24 @@ class Parameter(object):
                                init=self.init)
         return self._var
 
+    def cast(self, dtype):
+        """Cast data and gradient of this Parameter to a new data type.
+
+        Parameters
+        ----------
+        dtype : str or numpy.dtype
+            The new data type.
+        """
+        self.dtype = dtype
+        if self._data is None:
+            return
+        with autograd.pause():
+            self._data = [i.astype(dtype) for i in self._data]
+            if self._grad is None:
+                return
+            self._grad = [i.astype(dtype) for i in self._grad]
+            autograd.mark_variables(self._data, self._grad, self.grad_req)
+
 
 class ParameterDict(object):
     """A dictionary managing a set of parameters.
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 80bb8e3..ea0e32f 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -326,16 +326,16 @@ class RNNCell(HybridRecurrentCell):
         self._activation = activation
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
-                                          dtype=None, init=i2h_weight_initializer,
+                                          init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
-                                          dtype=None, init=h2h_weight_initializer,
+                                          init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
-                                        dtype=None, init=i2h_bias_initializer,
+                                        init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
-                                        dtype=None, init=h2h_bias_initializer,
+                                        init=h2h_bias_initializer,
                                         allow_deferred_init=True)
 
     def state_info(self, batch_size=0):
@@ -434,16 +434,16 @@ class LSTMCell(HybridRecurrentCell):
         self._hidden_size = hidden_size
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
-                                          dtype=None, init=i2h_weight_initializer,
+                                          init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
-                                          dtype=None, init=h2h_weight_initializer,
+                                          init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
-                                        dtype=None, init=i2h_bias_initializer,
+                                        init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
-                                        dtype=None, init=h2h_bias_initializer,
+                                        init=h2h_bias_initializer,
                                         allow_deferred_init=True)
 
     def state_info(self, batch_size=0):
@@ -541,16 +541,16 @@ class GRUCell(HybridRecurrentCell):
         self._hidden_size = hidden_size
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
-                                          dtype=None, init=i2h_weight_initializer,
+                                          init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
-                                          dtype=None, init=h2h_weight_initializer,
+                                          init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
-                                        dtype=None, init=i2h_bias_initializer,
+                                        init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
-                                        dtype=None, init=h2h_bias_initializer,
+                                        init=h2h_bias_initializer,
                                         allow_deferred_init=True)
 
     def state_info(self, batch_size=0):
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 751f1fb..df9f78e 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -570,12 +570,20 @@ def test_fill_shape_deferred():
 def test_dtype():
     net = mx.gluon.model_zoo.vision.resnet18_v1()
     net.initialize()
-    net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+    net.cast('float64')
+    with mx.autograd.record():
+        y = net(mx.nd.ones((16, 3, 32, 32), dtype='float64'))
+        y.backward()
 
     net = mx.gluon.model_zoo.vision.resnet18_v1()
     net.initialize()
     net.hybridize()
-    net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+    net(mx.nd.ones((16, 3, 32, 32), dtype='float32'))
+
+    net.cast('float64')
+    net(mx.nd.ones((16, 3, 32, 32), dtype='float64'))
+
+    mx.nd.waitall()
 
 
 def test_fill_shape_load():
--
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.