You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/08 22:59:10 UTC
[GitHub] piiswrong closed pull request #8522: gluon with multiple data type
piiswrong closed pull request #8522: gluon with multiple data type
URL: https://github.com/apache/incubator-mxnet/pull/8522
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
Because this pull request comes from a fork, GitHub would not otherwise
show the diff after the merge, so it is supplied below:
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index d26e86f409..88a9f4d597 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -120,6 +120,7 @@ class Imperative {
nnvm::Graph fwd_graph_;
nnvm::Graph grad_graph_;
nnvm::Graph full_graph_;
+ std::vector<nnvm::NodeEntry> ograd_entries_;
std::vector<bool> curr_grad_req_;
std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
std::vector<uint32_t> bwd_input_eid_;
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 73dbfc10fe..2546711246 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -21,6 +21,7 @@
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']
import copy
+import warnings
from .. import symbol, ndarray, initializer
from ..symbol import Symbol
@@ -325,7 +326,7 @@ def __init__(self, prefix=None, params=None):
self._reg_params = {}
self._cached_graph = ()
self._cached_op = None
- self._cached_params = None
+ self._cached_op_args = None
self._out_format = None
self._in_format = None
self._active = False
@@ -363,34 +364,47 @@ def _get_graph(self, *args):
def _build_cache(self, *args):
inputs, out = self._get_graph(*args)
+ input_idx = {var.name: i for i, var in enumerate(inputs)}
self._cached_op = ndarray.CachedOp(out)
-
params = dict(self.collect_params().items())
- self._cached_params = [params.get(name, None) for name in out.list_inputs()]
- assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \
- "Wrong number of inputs."
- name2pos = {var.name: i for i, var in enumerate(inputs)}
- self._in_idx = [(i, name2pos[name]) for i, name in enumerate(out.list_inputs())
- if name not in params]
+ # verify graph inputs
+ expected_inputs = set(out.list_inputs())
+ for name in expected_inputs:
+ assert name in params or name in input_idx, \
+ "Unknown input to HybridBlock: %s"%name
+ for name, i in input_idx.items():
+ if name not in expected_inputs:
+ warnings.warn("The %d-th input to HybridBlock is not used by any "
+ "computation. Is this intended?"%i)
+ for name in params:
+ if name not in expected_inputs:
+ warnings.warn("Parameter %s is not used by any computation. "
+ "Is this intended?"%name)
+
+ self._cached_op_args = [(False, params[name]) if name in params
+ else (True, input_idx[name])
+ for name in out.list_inputs()]
+
+ def _finish_deferred_init(self, hybrid, *args):
+ self.infer_shape(*args)
+ self.infer_type(*args)
+ if hybrid:
+ for is_arg, i in self._cached_op_args:
+ if not is_arg:
+ i._finish_deferred_init()
+ else:
+ for _, i in self.params.items():
+ i._finish_deferred_init()
def _call_cached_op(self, *args):
if self._cached_op is None:
self._build_cache(*args)
- try:
- cargs = [i.data() if i else None for i in self._cached_params]
- except DeferredInitializationError:
- self.infer_shape(*args)
- for i in self._cached_params:
- if i is not None:
- i._finish_deferred_init()
- cargs = [i.data() if i else None for i in self._cached_params]
-
args, fmt = _flatten(args)
assert fmt == self._in_format, "Invalid input format"
- for i, j in self._in_idx:
- cargs[i] = args[j]
+ cargs = [args[i] if is_arg else i.data()
+ for is_arg, i in self._cached_op_args]
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
@@ -399,6 +413,7 @@ def _call_cached_op(self, *args):
def _clear_cached_op(self):
self._cached_graph = ()
self._cached_op = None
+ self._cached_op_args = None
def register_child(self, block):
if not isinstance(block, HybridBlock):
@@ -414,17 +429,25 @@ def hybridize(self, active=True):
self._active = active
super(HybridBlock, self).hybridize(active)
- def infer_shape(self, *args):
- """Infers shape of Parameters from inputs."""
+ def _infer_attrs(self, infer_fn, attr, *args):
+ """Generic infer attributes."""
inputs, out = self._get_graph(*args)
args, _ = _flatten(args)
- arg_shapes, _, aux_shapes = out.infer_shape(
- **{i.name: j.shape for i, j in zip(inputs, args)})
- sdict = {i: j for i, j in zip(out.list_arguments(), arg_shapes)}
- sdict.update({name : shape for name, shape in \
- zip(out.list_auxiliary_states(), aux_shapes)})
+ arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
+ **{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
+ sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)}
+ sdict.update({name : attr for name, attr in \
+ zip(out.list_auxiliary_states(), aux_attrs)})
for i in self.collect_params().values():
- i.shape = sdict[i.name]
+ setattr(i, attr, sdict[i.name])
+
+ def infer_shape(self, *args):
+ """Infers shape of Parameters from inputs."""
+ self._infer_attrs('infer_shape', 'shape', *args)
+
+ def infer_type(self, *args):
+ """Infers data type of Parameters from inputs."""
+ self._infer_attrs('infer_type', 'dtype', *args)
def export(self, path):
"""Export HybridBlock to json format that can be loaded by `mxnet.mod.Module`
@@ -462,15 +485,16 @@ def forward(self, x, *args):
:py:class:`NDArray` or :py:class:`Symbol`."""
if isinstance(x, NDArray):
with x.context as ctx:
- if self._active:
- return self._call_cached_op(x, *args)
try:
+ if self._active:
+ return self._call_cached_op(x, *args)
params = {i: j.data(ctx) for i, j in self._reg_params.items()}
except DeferredInitializationError:
- self.infer_shape(x, *args)
- for i in self.collect_params().values():
- i._finish_deferred_init()
- params = {i: j.data(ctx) for i, j in self._reg_params.items()}
+ self._finish_deferred_init(self._active, x, *args)
+
+ if self._active:
+ return self._call_cached_op(x, *args)
+ params = {i: j.data(ctx) for i, j in self._reg_params.items()}
return self.hybrid_forward(ndarray, x, *args, **params)
assert isinstance(x, Symbol), \
@@ -559,6 +583,11 @@ def __init__(self, outputs, inputs, params=None):
def forward(self, x, *args):
if isinstance(x, NDArray):
with x.context:
+ try:
+ return self._call_cached_op(x, *args)
+ except DeferredInitializationError:
+ self._finish_deferred_init(True, x, *args)
+
return self._call_cached_op(x, *args)
assert isinstance(x, Symbol), \
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 906f03ec92..8034ab8415 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -185,11 +185,11 @@ def __init__(self, units, activation=None, use_bias=True, flatten=True,
self._units = units
self._in_units = in_units
self.weight = self.params.get('weight', shape=(units, in_units),
- init=weight_initializer,
+ dtype=None, init=weight_initializer,
allow_deferred_init=True)
if use_bias:
self.bias = self.params.get('bias', shape=(units,),
- init=bias_initializer,
+ dtype=None, init=bias_initializer,
allow_deferred_init=True)
else:
self.bias = None
@@ -336,20 +336,20 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
self.in_channels = in_channels
self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
- shape=(in_channels,), init=gamma_initializer,
- allow_deferred_init=True,
+ shape=(in_channels,), dtype=None,
+ init=gamma_initializer, allow_deferred_init=True,
differentiable=scale)
self.beta = self.params.get('beta', grad_req='write' if center else 'null',
- shape=(in_channels,), init=beta_initializer,
- allow_deferred_init=True,
+ shape=(in_channels,), dtype=None,
+ init=beta_initializer, allow_deferred_init=True,
differentiable=center)
self.running_mean = self.params.get('running_mean', grad_req='null',
- shape=(in_channels,),
+ shape=(in_channels,), dtype=None,
init=running_mean_initializer,
allow_deferred_init=True,
differentiable=False)
self.running_var = self.params.get('running_var', grad_req='null',
- shape=(in_channels,),
+ shape=(in_channels,), dtype=None,
init=running_variance_initializer,
allow_deferred_init=True,
differentiable=False)
@@ -437,7 +437,7 @@ def __init__(self, input_dim, output_dim, dtype='float32',
self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
'dtype': dtype}
self.weight = self.params.get('weight', shape=(input_dim, output_dim),
- init=weight_initializer,
+ dtype=None, init=weight_initializer,
allow_deferred_init=True)
def hybrid_forward(self, F, x, weight):
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 645de98ec0..0dd70697ab 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -113,11 +113,11 @@ def __init__(self, channels, kernel_size, strides, padding, dilation,
dshape[layout.find('C')] = in_channels
wshapes = _infer_weight_shape(op_name, dshape, self._kwargs)
self.weight = self.params.get('weight', shape=wshapes[1],
- init=weight_initializer,
+ dtype=None, init=weight_initializer,
allow_deferred_init=True)
if use_bias:
self.bias = self.params.get('bias', shape=wshapes[2],
- init=bias_initializer,
+ dtype=None, init=bias_initializer,
allow_deferred_init=True)
else:
self.bias = None
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index c42fbaa1fa..27297b5329 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -306,7 +306,7 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
ctx = [ctx]
if init is None:
init = default_init if self.init is None else self.init
- if not self.shape or np.prod(self.shape) <= 0:
+ if self.dtype is None or not self.shape or np.prod(self.shape) <= 0:
if self._allow_deferred_init:
self._deferred_init = (init, ctx, default_init)
return
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index ea0e32faeb..80bb8e3fb8 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -326,16 +326,16 @@ def __init__(self, hidden_size, activation='tanh',
self._activation = activation
self._input_size = input_size
self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
- init=i2h_weight_initializer,
+ dtype=None, init=i2h_weight_initializer,
allow_deferred_init=True)
self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
- init=h2h_weight_initializer,
+ dtype=None, init=h2h_weight_initializer,
allow_deferred_init=True)
self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
- init=i2h_bias_initializer,
+ dtype=None, init=i2h_bias_initializer,
allow_deferred_init=True)
self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
- init=h2h_bias_initializer,
+ dtype=None, init=h2h_bias_initializer,
allow_deferred_init=True)
def state_info(self, batch_size=0):
@@ -434,16 +434,16 @@ def __init__(self, hidden_size,
self._hidden_size = hidden_size
self._input_size = input_size
self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
- init=i2h_weight_initializer,
+ dtype=None, init=i2h_weight_initializer,
allow_deferred_init=True)
self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
- init=h2h_weight_initializer,
+ dtype=None, init=h2h_weight_initializer,
allow_deferred_init=True)
self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
- init=i2h_bias_initializer,
+ dtype=None, init=i2h_bias_initializer,
allow_deferred_init=True)
self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
- init=h2h_bias_initializer,
+ dtype=None, init=h2h_bias_initializer,
allow_deferred_init=True)
def state_info(self, batch_size=0):
@@ -541,16 +541,16 @@ def __init__(self, hidden_size,
self._hidden_size = hidden_size
self._input_size = input_size
self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
- init=i2h_weight_initializer,
+ dtype=None, init=i2h_weight_initializer,
allow_deferred_init=True)
self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
- init=h2h_weight_initializer,
+ dtype=None, init=h2h_weight_initializer,
allow_deferred_init=True)
self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
- init=i2h_bias_initializer,
+ dtype=None, init=i2h_bias_initializer,
allow_deferred_init=True)
self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
- init=h2h_bias_initializer,
+ dtype=None, init=h2h_bias_initializer,
allow_deferred_init=True)
def state_info(self, batch_size=0):
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index eb99aabf11..60d66db485 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -62,11 +62,10 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
}
// construct backward graph
- std::vector<NodeEntry> ograd_entries;
{
- ograd_entries.reserve(fwd_graph_.outputs.size());
+ ograd_entries_.reserve(fwd_graph_.outputs.size());
for (size_t i = 0; i < fwd_graph_.outputs.size(); ++i) {
- ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0});
+ ograd_entries_.emplace_back(NodeEntry{Node::Create(), 0, 0});
}
std::vector<NodeEntry> xs;
@@ -77,7 +76,7 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
<< "There are no inputs in computation graph that require gradients.";
grad_graph_ = pass::Gradient(
- fwd_graph_, fwd_graph_.outputs, xs, ograd_entries,
+ fwd_graph_, fwd_graph_.outputs, xs, ograd_entries_,
exec::AggregateGradient, nullptr, nullptr,
zero_ops, "_copy");
}
@@ -105,12 +104,12 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
std::make_shared<dmlc::any>(std::move(full_ref_count));
size_t num_forward_inputs = num_inputs();
- for (uint32_t i = 0; i < ograd_entries.size(); ++i) {
- if (!idx.exist(ograd_entries[i].node.get())) continue;
- auto eid = idx.entry_id(ograd_entries[i]);
+ size_t num_forward_outputs = num_outputs();
+ for (uint32_t i = 0; i < ograd_entries_.size(); ++i) {
+ if (!idx.exist(ograd_entries_[i].node.get())) continue;
+ auto eid = idx.entry_id(ograd_entries_[i]);
if (ref_count[eid] > 0) {
bwd_ograd_dep_.push_back(i);
- bwd_input_eid_.push_back(eid);
}
}
save_inputs_.resize(num_forward_inputs, false);
@@ -119,16 +118,14 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
if (ref_count[eid] > 0) {
save_inputs_[i] = true;
bwd_in_dep_.push_back(i);
- bwd_input_eid_.push_back(eid);
}
}
save_outputs_.resize(idx.outputs().size(), false);
- for (uint32_t i = 0; i < idx.outputs().size(); ++i) {
+ for (uint32_t i = 0; i < num_forward_outputs; ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
if (ref_count[eid] > 0) {
save_outputs_[i] = true;
bwd_out_dep_.push_back(i);
- bwd_input_eid_.push_back(eid);
}
}
}
@@ -242,9 +239,28 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]);
}
+ bwd_input_eid_.clear();
}
const auto& idx = g.indexed_graph();
+
+ if (bwd_input_eid_.size() != inputs.size()) {
+ bwd_input_eid_.clear();
+ for (const auto& i : bwd_ograd_dep_) {
+ auto eid = idx.entry_id(ograd_entries_[i]);
+ bwd_input_eid_.push_back(eid);
+ }
+ for (const auto& i : bwd_in_dep_) {
+ auto eid = idx.entry_id(idx.input_nodes()[i], 0);
+ bwd_input_eid_.push_back(eid);
+ }
+ for (const auto& i : bwd_out_dep_) {
+ auto eid = idx.entry_id(idx.outputs()[i]);
+ bwd_input_eid_.push_back(eid);
+ }
+ CHECK_EQ(inputs.size(), bwd_input_eid_.size());
+ }
+
size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index df0af34dfe..751f1fbd96 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -567,6 +567,17 @@ def test_fill_shape_deferred():
assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1]
+def test_dtype():
+ net = mx.gluon.model_zoo.vision.resnet18_v1()
+ net.initialize()
+ net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+
+ net = mx.gluon.model_zoo.vision.resnet18_v1()
+ net.initialize()
+ net.hybridize()
+ net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+
+
def test_fill_shape_load():
ctx = mx.context.current_context()
net1 = nn.HybridSequential()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services