Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/08 22:59:08 UTC
[incubator-mxnet] branch master updated: gluon with multiple data type (#8522)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8dc09ec gluon with multiple data type (#8522)
8dc09ec is described below
commit 8dc09ecc1f39e24e50cba445fcad0a6980ea5695
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Wed Nov 8 14:59:05 2017 -0800
gluon with multiple data type (#8522)
* gluon with multiple data type
* fix
* fix
---
include/mxnet/imperative.h | 1 +
python/mxnet/gluon/block.py | 95 +++++++++++++++++++++++------------
python/mxnet/gluon/nn/basic_layers.py | 18 +++----
python/mxnet/gluon/nn/conv_layers.py | 4 +-
python/mxnet/gluon/parameter.py | 2 +-
python/mxnet/gluon/rnn/rnn_cell.py | 24 ++++-----
src/imperative/cached_op.cc | 38 ++++++++++----
tests/python/unittest/test_gluon.py | 11 ++++
8 files changed, 125 insertions(+), 68 deletions(-)
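
In user-facing terms, this commit lets Gluon blocks infer parameter dtypes
from their inputs instead of assuming float32. A minimal sketch of the
intended usage (a hypothetical example, not taken from the commit):

    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Dense(10)
    net.initialize()
    net.hybridize()
    # With this commit, the weight/bias dtypes are inferred from the
    # first input rather than being fixed to float32.
    out = net(mx.nd.ones((2, 20), dtype='float64'))
    print(out.dtype)  # expected: float64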
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index d26e86f..88a9f4d 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -120,6 +120,7 @@ class Imperative {
nnvm::Graph fwd_graph_;
nnvm::Graph grad_graph_;
nnvm::Graph full_graph_;
+ std::vector<nnvm::NodeEntry> ograd_entries_;
std::vector<bool> curr_grad_req_;
std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
std::vector<uint32_t> bwd_input_eid_;
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 73dbfc1..2546711 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -21,6 +21,7 @@
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']
import copy
+import warnings
from .. import symbol, ndarray, initializer
from ..symbol import Symbol
@@ -325,7 +326,7 @@ class HybridBlock(Block):
self._reg_params = {}
self._cached_graph = ()
self._cached_op = None
- self._cached_params = None
+ self._cached_op_args = None
self._out_format = None
self._in_format = None
self._active = False
@@ -363,34 +364,47 @@ class HybridBlock(Block):
def _build_cache(self, *args):
inputs, out = self._get_graph(*args)
+ input_idx = {var.name: i for i, var in enumerate(inputs)}
self._cached_op = ndarray.CachedOp(out)
-
params = dict(self.collect_params().items())
- self._cached_params = [params.get(name, None) for name in out.list_inputs()]
- assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \
- "Wrong number of inputs."
- name2pos = {var.name: i for i, var in enumerate(inputs)}
- self._in_idx = [(i, name2pos[name]) for i, name in enumerate(out.list_inputs())
- if name not in params]
+ # verify graph inputs
+ expected_inputs = set(out.list_inputs())
+ for name in expected_inputs:
+ assert name in params or name in input_idx, \
+ "Unknown input to HybridBlock: %s"%name
+ for name, i in input_idx.items():
+ if name not in expected_inputs:
+ warnings.warn("The %d-th input to HybridBlock is not used by any "
+ "computation. Is this intended?"%i)
+ for name in params:
+ if name not in expected_inputs:
+ warnings.warn("Parameter %s is not used by any computation. "
+ "Is this intended?"%name)
+
+ self._cached_op_args = [(False, params[name]) if name in params
+ else (True, input_idx[name])
+ for name in out.list_inputs()]
+
+ def _finish_deferred_init(self, hybrid, *args):
+ self.infer_shape(*args)
+ self.infer_type(*args)
+ if hybrid:
+ for is_arg, i in self._cached_op_args:
+ if not is_arg:
+ i._finish_deferred_init()
+ else:
+ for _, i in self.params.items():
+ i._finish_deferred_init()
def _call_cached_op(self, *args):
if self._cached_op is None:
self._build_cache(*args)
- try:
- cargs = [i.data() if i else None for i in self._cached_params]
- except DeferredInitializationError:
- self.infer_shape(*args)
- for i in self._cached_params:
- if i is not None:
- i._finish_deferred_init()
- cargs = [i.data() if i else None for i in self._cached_params]
-
args, fmt = _flatten(args)
assert fmt == self._in_format, "Invalid input format"
- for i, j in self._in_idx:
- cargs[i] = args[j]
+ cargs = [args[i] if is_arg else i.data()
+ for is_arg, i in self._cached_op_args]
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
@@ -399,6 +413,7 @@ class HybridBlock(Block):
def _clear_cached_op(self):
self._cached_graph = ()
self._cached_op = None
+ self._cached_op_args = None
def register_child(self, block):
if not isinstance(block, HybridBlock):
@@ -414,17 +429,25 @@ class HybridBlock(Block):
self._active = active
super(HybridBlock, self).hybridize(active)
- def infer_shape(self, *args):
- """Infers shape of Parameters from inputs."""
+ def _infer_attrs(self, infer_fn, attr, *args):
+ """Generic infer attributes."""
inputs, out = self._get_graph(*args)
args, _ = _flatten(args)
- arg_shapes, _, aux_shapes = out.infer_shape(
- **{i.name: j.shape for i, j in zip(inputs, args)})
- sdict = {i: j for i, j in zip(out.list_arguments(), arg_shapes)}
- sdict.update({name : shape for name, shape in \
- zip(out.list_auxiliary_states(), aux_shapes)})
+ arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
+ **{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
+ sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)}
+ sdict.update({name : attr for name, attr in \
+ zip(out.list_auxiliary_states(), aux_attrs)})
for i in self.collect_params().values():
- i.shape = sdict[i.name]
+ setattr(i, attr, sdict[i.name])
+
+ def infer_shape(self, *args):
+ """Infers shape of Parameters from inputs."""
+ self._infer_attrs('infer_shape', 'shape', *args)
+
+ def infer_type(self, *args):
+ """Infers data type of Parameters from inputs."""
+ self._infer_attrs('infer_type', 'dtype', *args)
def export(self, path):
"""Export HybridBlock to json format that can be loaded by `mxnet.mod.Module`
@@ -462,15 +485,16 @@ class HybridBlock(Block):
:py:class:`NDArray` or :py:class:`Symbol`."""
if isinstance(x, NDArray):
with x.context as ctx:
- if self._active:
- return self._call_cached_op(x, *args)
try:
+ if self._active:
+ return self._call_cached_op(x, *args)
params = {i: j.data(ctx) for i, j in self._reg_params.items()}
except DeferredInitializationError:
- self.infer_shape(x, *args)
- for i in self.collect_params().values():
- i._finish_deferred_init()
- params = {i: j.data(ctx) for i, j in self._reg_params.items()}
+ self._finish_deferred_init(self._active, x, *args)
+
+ if self._active:
+ return self._call_cached_op(x, *args)
+ params = {i: j.data(ctx) for i, j in self._reg_params.items()}
return self.hybrid_forward(ndarray, x, *args, **params)
assert isinstance(x, Symbol), \
@@ -559,6 +583,11 @@ class SymbolBlock(HybridBlock):
def forward(self, x, *args):
if isinstance(x, NDArray):
with x.context:
+ try:
+ return self._call_cached_op(x, *args)
+ except DeferredInitializationError:
+ self._finish_deferred_init(True, x, *args)
+
return self._call_cached_op(x, *args)
assert isinstance(x, Symbol), \
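
The key data structure above is _cached_op_args: one (is_arg, value) pair per
graph input, where value is either the position of a call argument or a
Parameter object. A hedged illustration of how it is consumed (the names are
made up for the example):

    # For a block whose cached graph lists inputs
    # ['data', 'fc_weight', 'fc_bias']:
    cached_op_args = [(True, 0),                     # 'data'    -> args[0]
                      (False, params['fc_weight']),  # Parameter -> .data()
                      (False, params['fc_bias'])]
    # At call time (mirroring _call_cached_op):
    cargs = [args[i] if is_arg else p.data() for is_arg, p in cached_op_args]
    # p.data() raises DeferredInitializationError for an uninitialized
    # parameter, which forward() catches to trigger shape/dtype inference.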
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 906f03e..8034ab8 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -185,11 +185,11 @@ class Dense(HybridBlock):
self._units = units
self._in_units = in_units
self.weight = self.params.get('weight', shape=(units, in_units),
- init=weight_initializer,
+ dtype=None, init=weight_initializer,
allow_deferred_init=True)
if use_bias:
self.bias = self.params.get('bias', shape=(units,),
- init=bias_initializer,
+ dtype=None, init=bias_initializer,
allow_deferred_init=True)
else:
self.bias = None
@@ -336,20 +336,20 @@ class BatchNorm(HybridBlock):
self.in_channels = in_channels
self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
- shape=(in_channels,), init=gamma_initializer,
- allow_deferred_init=True,
+ shape=(in_channels,), dtype=None,
+ init=gamma_initializer, allow_deferred_init=True,
differentiable=scale)
self.beta = self.params.get('beta', grad_req='write' if center else 'null',
- shape=(in_channels,), init=beta_initializer,
- allow_deferred_init=True,
+ shape=(in_channels,), dtype=None,
+ init=beta_initializer, allow_deferred_init=True,
differentiable=center)
self.running_mean = self.params.get('running_mean', grad_req='null',
- shape=(in_channels,),
+ shape=(in_channels,), dtype=None,
init=running_mean_initializer,
allow_deferred_init=True,
differentiable=False)
self.running_var = self.params.get('running_var', grad_req='null',
- shape=(in_channels,),
+ shape=(in_channels,), dtype=None,
init=running_variance_initializer,
allow_deferred_init=True,
differentiable=False)
@@ -437,7 +437,7 @@ class Embedding(HybridBlock):
self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
'dtype': dtype}
self.weight = self.params.get('weight', shape=(input_dim, output_dim),
- init=weight_initializer,
+ dtype=None, init=weight_initializer,
allow_deferred_init=True)
def hybrid_forward(self, F, x, weight):
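
Passing dtype=None marks a parameter's dtype as unknown, so initialization is
deferred until the dtype can be inferred from the first input via the new
infer_type path. A hedged sketch, assuming a build that includes this commit:

    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Dense(4)  # weight/bias now created with dtype=None
    net.initialize()         # deferred: dtype and in_units still unknown
    net(mx.nd.ones((1, 8), dtype='float64'))
    print(net.weight.dtype)  # float64, inferred from the input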
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 645de98..0dd7069 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -113,11 +113,11 @@ class _Conv(HybridBlock):
dshape[layout.find('C')] = in_channels
wshapes = _infer_weight_shape(op_name, dshape, self._kwargs)
self.weight = self.params.get('weight', shape=wshapes[1],
- init=weight_initializer,
+ dtype=None, init=weight_initializer,
allow_deferred_init=True)
if use_bias:
self.bias = self.params.get('bias', shape=wshapes[2],
- init=bias_initializer,
+ dtype=None, init=bias_initializer,
allow_deferred_init=True)
else:
self.bias = None
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index c42fbaa..27297b5 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -306,7 +306,7 @@ class Parameter(object):
ctx = [ctx]
if init is None:
init = default_init if self.init is None else self.init
- if not self.shape or np.prod(self.shape) <= 0:
+ if self.dtype is None or not self.shape or np.prod(self.shape) <= 0:
if self._allow_deferred_init:
self._deferred_init = (init, ctx, default_init)
return
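
The change above extends the deferral test: initialization now also waits when
the dtype is unknown, not only when the shape is. Schematically (a paraphrase,
not the exact source):

    import numpy as np

    # Initialization is deferred (when allow_deferred_init is set)
    # if any of these hold:
    def _needs_deferral(shape, dtype):
        return dtype is None or not shape or np.prod(shape) <= 0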
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index ea0e32f..80bb8e3 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -326,16 +326,16 @@ class RNNCell(HybridRecurrentCell):
self._activation = activation
self._input_size = input_size
self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
- init=i2h_weight_initializer,
+ dtype=None, init=i2h_weight_initializer,
allow_deferred_init=True)
self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
- init=h2h_weight_initializer,
+ dtype=None, init=h2h_weight_initializer,
allow_deferred_init=True)
self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
- init=i2h_bias_initializer,
+ dtype=None, init=i2h_bias_initializer,
allow_deferred_init=True)
self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
- init=h2h_bias_initializer,
+ dtype=None, init=h2h_bias_initializer,
allow_deferred_init=True)
def state_info(self, batch_size=0):
@@ -434,16 +434,16 @@ class LSTMCell(HybridRecurrentCell):
self._hidden_size = hidden_size
self._input_size = input_size
self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
- init=i2h_weight_initializer,
+ dtype=None, init=i2h_weight_initializer,
allow_deferred_init=True)
self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
- init=h2h_weight_initializer,
+ dtype=None, init=h2h_weight_initializer,
allow_deferred_init=True)
self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
- init=i2h_bias_initializer,
+ dtype=None, init=i2h_bias_initializer,
allow_deferred_init=True)
self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
- init=h2h_bias_initializer,
+ dtype=None, init=h2h_bias_initializer,
allow_deferred_init=True)
def state_info(self, batch_size=0):
@@ -541,16 +541,16 @@ class GRUCell(HybridRecurrentCell):
self._hidden_size = hidden_size
self._input_size = input_size
self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
- init=i2h_weight_initializer,
+ dtype=None, init=i2h_weight_initializer,
allow_deferred_init=True)
self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
- init=h2h_weight_initializer,
+ dtype=None, init=h2h_weight_initializer,
allow_deferred_init=True)
self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
- init=i2h_bias_initializer,
+ dtype=None, init=i2h_bias_initializer,
allow_deferred_init=True)
self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
- init=h2h_bias_initializer,
+ dtype=None, init=h2h_bias_initializer,
allow_deferred_init=True)
def state_info(self, batch_size=0):
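
The recurrent cells get the same dtype=None treatment for their i2h/h2h
weights and biases, so a cell picks up the dtype of its inputs. A hedged
example (assuming this commit; state dtypes are passed through begin_state):

    import mxnet as mx
    from mxnet import gluon

    cell = gluon.rnn.LSTMCell(hidden_size=8)
    cell.initialize()
    x = mx.nd.ones((2, 4), dtype='float64')
    states = cell.begin_state(batch_size=2, dtype='float64')
    out, states = cell(x, states)
    print(cell.i2h_weight.dtype)  # float64, inferred from the input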
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index eb99aab..60d66db 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -62,11 +62,10 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
}
// construct backward graph
- std::vector<NodeEntry> ograd_entries;
{
- ograd_entries.reserve(fwd_graph_.outputs.size());
+ ograd_entries_.reserve(fwd_graph_.outputs.size());
for (size_t i = 0; i < fwd_graph_.outputs.size(); ++i) {
- ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0});
+ ograd_entries_.emplace_back(NodeEntry{Node::Create(), 0, 0});
}
std::vector<NodeEntry> xs;
@@ -77,7 +76,7 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
<< "There are no inputs in computation graph that require gradients.";
grad_graph_ = pass::Gradient(
- fwd_graph_, fwd_graph_.outputs, xs, ograd_entries,
+ fwd_graph_, fwd_graph_.outputs, xs, ograd_entries_,
exec::AggregateGradient, nullptr, nullptr,
zero_ops, "_copy");
}
@@ -105,12 +104,12 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
std::make_shared<dmlc::any>(std::move(full_ref_count));
size_t num_forward_inputs = num_inputs();
- for (uint32_t i = 0; i < ograd_entries.size(); ++i) {
- if (!idx.exist(ograd_entries[i].node.get())) continue;
- auto eid = idx.entry_id(ograd_entries[i]);
+ size_t num_forward_outputs = num_outputs();
+ for (uint32_t i = 0; i < ograd_entries_.size(); ++i) {
+ if (!idx.exist(ograd_entries_[i].node.get())) continue;
+ auto eid = idx.entry_id(ograd_entries_[i]);
if (ref_count[eid] > 0) {
bwd_ograd_dep_.push_back(i);
- bwd_input_eid_.push_back(eid);
}
}
save_inputs_.resize(num_forward_inputs, false);
@@ -119,16 +118,14 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
if (ref_count[eid] > 0) {
save_inputs_[i] = true;
bwd_in_dep_.push_back(i);
- bwd_input_eid_.push_back(eid);
}
}
save_outputs_.resize(idx.outputs().size(), false);
- for (uint32_t i = 0; i < idx.outputs().size(); ++i) {
+ for (uint32_t i = 0; i < num_forward_outputs; ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
if (ref_count[eid] > 0) {
save_outputs_[i] = true;
bwd_out_dep_.push_back(i);
- bwd_input_eid_.push_back(eid);
}
}
}
@@ -242,9 +239,28 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]);
}
+ bwd_input_eid_.clear();
}
const auto& idx = g.indexed_graph();
+
+ if (bwd_input_eid_.size() != inputs.size()) {
+ bwd_input_eid_.clear();
+ for (const auto& i : bwd_ograd_dep_) {
+ auto eid = idx.entry_id(ograd_entries_[i]);
+ bwd_input_eid_.push_back(eid);
+ }
+ for (const auto& i : bwd_in_dep_) {
+ auto eid = idx.entry_id(idx.input_nodes()[i], 0);
+ bwd_input_eid_.push_back(eid);
+ }
+ for (const auto& i : bwd_out_dep_) {
+ auto eid = idx.entry_id(idx.outputs()[i]);
+ bwd_input_eid_.push_back(eid);
+ }
+ CHECK_EQ(inputs.size(), bwd_input_eid_.size());
+ }
+
size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();
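
The C++ change moves bwd_input_eid_ from being filled eagerly in the
constructor to being rebuilt lazily in GetBackwardGraph: entry ids are only
valid for a specific indexed graph, and the backward graph changes whenever
curr_grad_req_ does. In Python-style pseudocode (a schematic of the logic
above, not an actual API):

    # Rebuild the backward-input entry ids against the current indexed graph.
    if graph_was_rebuilt:
        bwd_input_eid.clear()
    if len(bwd_input_eid) != len(inputs):
        bwd_input_eid = (
            [entry_id(ograd_entries[i]) for i in bwd_ograd_dep] +  # head grads
            [entry_id(input_nodes[i], 0) for i in bwd_in_dep] +    # saved inputs
            [entry_id(outputs[i]) for i in bwd_out_dep])           # saved outputs
        assert len(bwd_input_eid) == len(inputs)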
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index df0af34..751f1fb 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -567,6 +567,17 @@ def test_fill_shape_deferred():
assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1]
+def test_dtype():
+ net = mx.gluon.model_zoo.vision.resnet18_v1()
+ net.initialize()
+ net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+
+ net = mx.gluon.model_zoo.vision.resnet18_v1()
+ net.initialize()
+ net.hybridize()
+ net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+
+
def test_fill_shape_load():
ctx = mx.context.current_context()
net1 = nn.HybridSequential()
--
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.