You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/08 22:59:08 UTC

[incubator-mxnet] branch master updated: gluon with multiple data type (#8522)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8dc09ec  gluon with multiple data type (#8522)
8dc09ec is described below

commit 8dc09ecc1f39e24e50cba445fcad0a6980ea5695
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Wed Nov 8 14:59:05 2017 -0800

    gluon with multiple data type (#8522)
    
    * gluon with multiple data type
    
    * fix
    
    * fix
---
 include/mxnet/imperative.h            |  1 +
 python/mxnet/gluon/block.py           | 95 +++++++++++++++++++++++------------
 python/mxnet/gluon/nn/basic_layers.py | 18 +++----
 python/mxnet/gluon/nn/conv_layers.py  |  4 +-
 python/mxnet/gluon/parameter.py       |  2 +-
 python/mxnet/gluon/rnn/rnn_cell.py    | 24 ++++-----
 src/imperative/cached_op.cc           | 38 ++++++++++----
 tests/python/unittest/test_gluon.py   | 11 ++++
 8 files changed, 125 insertions(+), 68 deletions(-)

diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index d26e86f..88a9f4d 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -120,6 +120,7 @@ class Imperative {
     nnvm::Graph fwd_graph_;
     nnvm::Graph grad_graph_;
     nnvm::Graph full_graph_;
+    std::vector<nnvm::NodeEntry> ograd_entries_;
     std::vector<bool> curr_grad_req_;
     std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
     std::vector<uint32_t> bwd_input_eid_;
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 73dbfc1..2546711 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -21,6 +21,7 @@
 __all__ = ['Block', 'HybridBlock', 'SymbolBlock']
 
 import copy
+import warnings
 
 from .. import symbol, ndarray, initializer
 from ..symbol import Symbol
@@ -325,7 +326,7 @@ class HybridBlock(Block):
         self._reg_params = {}
         self._cached_graph = ()
         self._cached_op = None
-        self._cached_params = None
+        self._cached_op_args = None
         self._out_format = None
         self._in_format = None
         self._active = False
@@ -363,34 +364,47 @@ class HybridBlock(Block):
 
     def _build_cache(self, *args):
         inputs, out = self._get_graph(*args)
+        input_idx = {var.name: i for i, var in enumerate(inputs)}
         self._cached_op = ndarray.CachedOp(out)
-
         params = dict(self.collect_params().items())
-        self._cached_params = [params.get(name, None) for name in out.list_inputs()]
-        assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \
-            "Wrong number of inputs."
 
-        name2pos = {var.name: i for i, var in enumerate(inputs)}
-        self._in_idx = [(i, name2pos[name]) for i, name in enumerate(out.list_inputs())
-                        if name not in params]
+        # verify graph inputs
+        expected_inputs = set(out.list_inputs())
+        for name in expected_inputs:
+            assert name in params or name in input_idx, \
+                "Unknown input to HybridBlock: %s"%name
+        for name, i in input_idx.items():
+            if name not in expected_inputs:
+                warnings.warn("The %d-th input to HybridBlock is not used by any "
+                              "computation. Is this intended?"%i)
+        for name in params:
+            if name not in expected_inputs:
+                warnings.warn("Parameter %s is not used by any computation. "
+                              "Is this intended?"%name)
+
+        self._cached_op_args = [(False, params[name]) if name in params
+                                else (True, input_idx[name])
+                                for name in out.list_inputs()]
+
+    def _finish_deferred_init(self, hybrid, *args):
+        self.infer_shape(*args)
+        self.infer_type(*args)
+        if hybrid:
+            for is_arg, i in self._cached_op_args:
+                if not is_arg:
+                    i._finish_deferred_init()
+        else:
+            for _, i in self.params.items():
+                i._finish_deferred_init()
 
     def _call_cached_op(self, *args):
         if self._cached_op is None:
             self._build_cache(*args)
 
-        try:
-            cargs = [i.data() if i else None for i in self._cached_params]
-        except DeferredInitializationError:
-            self.infer_shape(*args)
-            for i in self._cached_params:
-                if i is not None:
-                    i._finish_deferred_init()
-            cargs = [i.data() if i else None for i in self._cached_params]
-
         args, fmt = _flatten(args)
         assert fmt == self._in_format, "Invalid input format"
-        for i, j in self._in_idx:
-            cargs[i] = args[j]
+        cargs = [args[i] if is_arg else i.data()
+                 for is_arg, i in self._cached_op_args]
         out = self._cached_op(*cargs)
         if isinstance(out, NDArray):
             out = [out]
@@ -399,6 +413,7 @@ class HybridBlock(Block):
     def _clear_cached_op(self):
         self._cached_graph = ()
         self._cached_op = None
+        self._cached_op_args = None
 
     def register_child(self, block):
         if not isinstance(block, HybridBlock):
@@ -414,17 +429,25 @@ class HybridBlock(Block):
         self._active = active
         super(HybridBlock, self).hybridize(active)
 
-    def infer_shape(self, *args):
-        """Infers shape of Parameters from inputs."""
+    def _infer_attrs(self, infer_fn, attr, *args):
+        """Generic infer attributes."""
         inputs, out = self._get_graph(*args)
         args, _ = _flatten(args)
-        arg_shapes, _, aux_shapes = out.infer_shape(
-            **{i.name: j.shape for i, j in zip(inputs, args)})
-        sdict = {i: j for i, j in zip(out.list_arguments(), arg_shapes)}
-        sdict.update({name : shape for name, shape in \
-                      zip(out.list_auxiliary_states(), aux_shapes)})
+        arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
+            **{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
+        sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)}
+        sdict.update({name : attr for name, attr in \
+                      zip(out.list_auxiliary_states(), aux_attrs)})
         for i in self.collect_params().values():
-            i.shape = sdict[i.name]
+            setattr(i, attr, sdict[i.name])
+
+    def infer_shape(self, *args):
+        """Infers shape of Parameters from inputs."""
+        self._infer_attrs('infer_shape', 'shape', *args)
+
+    def infer_type(self, *args):
+        """Infers data type of Parameters from inputs."""
+        self._infer_attrs('infer_type', 'dtype', *args)
 
     def export(self, path):
         """Export HybridBlock to json format that can be loaded by `mxnet.mod.Module`
@@ -462,15 +485,16 @@ class HybridBlock(Block):
         :py:class:`NDArray` or :py:class:`Symbol`."""
         if isinstance(x, NDArray):
             with x.context as ctx:
-                if self._active:
-                    return self._call_cached_op(x, *args)
                 try:
+                    if self._active:
+                        return self._call_cached_op(x, *args)
                     params = {i: j.data(ctx) for i, j in self._reg_params.items()}
                 except DeferredInitializationError:
-                    self.infer_shape(x, *args)
-                    for i in self.collect_params().values():
-                        i._finish_deferred_init()
-                    params = {i: j.data(ctx) for i, j in self._reg_params.items()}
+                    self._finish_deferred_init(self._active, x, *args)
+
+                if self._active:
+                    return self._call_cached_op(x, *args)
+                params = {i: j.data(ctx) for i, j in self._reg_params.items()}
                 return self.hybrid_forward(ndarray, x, *args, **params)
 
         assert isinstance(x, Symbol), \
@@ -559,6 +583,11 @@ class SymbolBlock(HybridBlock):
     def forward(self, x, *args):
         if isinstance(x, NDArray):
             with x.context:
+                try:
+                    return self._call_cached_op(x, *args)
+                except DeferredInitializationError:
+                    self._finish_deferred_init(True, x, *args)
+
                 return self._call_cached_op(x, *args)
 
         assert isinstance(x, Symbol), \
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 906f03e..8034ab8 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -185,11 +185,11 @@ class Dense(HybridBlock):
             self._units = units
             self._in_units = in_units
             self.weight = self.params.get('weight', shape=(units, in_units),
-                                          init=weight_initializer,
+                                          dtype=None, init=weight_initializer,
                                           allow_deferred_init=True)
             if use_bias:
                 self.bias = self.params.get('bias', shape=(units,),
-                                            init=bias_initializer,
+                                            dtype=None, init=bias_initializer,
                                             allow_deferred_init=True)
             else:
                 self.bias = None
@@ -336,20 +336,20 @@ class BatchNorm(HybridBlock):
             self.in_channels = in_channels
 
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
-                                     shape=(in_channels,), init=gamma_initializer,
-                                     allow_deferred_init=True,
+                                     shape=(in_channels,), dtype=None,
+                                     init=gamma_initializer, allow_deferred_init=True,
                                      differentiable=scale)
         self.beta = self.params.get('beta', grad_req='write' if center else 'null',
-                                    shape=(in_channels,), init=beta_initializer,
-                                    allow_deferred_init=True,
+                                    shape=(in_channels,), dtype=None,
+                                    init=beta_initializer, allow_deferred_init=True,
                                     differentiable=center)
         self.running_mean = self.params.get('running_mean', grad_req='null',
-                                            shape=(in_channels,),
+                                            shape=(in_channels,), dtype=None,
                                             init=running_mean_initializer,
                                             allow_deferred_init=True,
                                             differentiable=False)
         self.running_var = self.params.get('running_var', grad_req='null',
-                                           shape=(in_channels,),
+                                           shape=(in_channels,), dtype=None,
                                            init=running_variance_initializer,
                                            allow_deferred_init=True,
                                            differentiable=False)
@@ -437,7 +437,7 @@ class Embedding(HybridBlock):
         self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
                         'dtype': dtype}
         self.weight = self.params.get('weight', shape=(input_dim, output_dim),
-                                      init=weight_initializer,
+                                      dtype=None, init=weight_initializer,
                                       allow_deferred_init=True)
 
     def hybrid_forward(self, F, x, weight):
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 645de98..0dd7069 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -113,11 +113,11 @@ class _Conv(HybridBlock):
             dshape[layout.find('C')] = in_channels
             wshapes = _infer_weight_shape(op_name, dshape, self._kwargs)
             self.weight = self.params.get('weight', shape=wshapes[1],
-                                          init=weight_initializer,
+                                          dtype=None, init=weight_initializer,
                                           allow_deferred_init=True)
             if use_bias:
                 self.bias = self.params.get('bias', shape=wshapes[2],
-                                            init=bias_initializer,
+                                            dtype=None, init=bias_initializer,
                                             allow_deferred_init=True)
             else:
                 self.bias = None
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index c42fbaa..27297b5 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -306,7 +306,7 @@ class Parameter(object):
             ctx = [ctx]
         if init is None:
             init = default_init if self.init is None else self.init
-        if not self.shape or np.prod(self.shape) <= 0:
+        if self.dtype is None or not self.shape or np.prod(self.shape) <= 0:
             if self._allow_deferred_init:
                 self._deferred_init = (init, ctx, default_init)
                 return
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index ea0e32f..80bb8e3 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -326,16 +326,16 @@ class RNNCell(HybridRecurrentCell):
         self._activation = activation
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
-                                          init=i2h_weight_initializer,
+                                          dtype=None, init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
-                                          init=h2h_weight_initializer,
+                                          dtype=None, init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
-                                        init=i2h_bias_initializer,
+                                        dtype=None, init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
-                                        init=h2h_bias_initializer,
+                                        dtype=None, init=h2h_bias_initializer,
                                         allow_deferred_init=True)
 
     def state_info(self, batch_size=0):
@@ -434,16 +434,16 @@ class LSTMCell(HybridRecurrentCell):
         self._hidden_size = hidden_size
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
-                                          init=i2h_weight_initializer,
+                                          dtype=None, init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
-                                          init=h2h_weight_initializer,
+                                          dtype=None, init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
-                                        init=i2h_bias_initializer,
+                                        dtype=None, init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
-                                        init=h2h_bias_initializer,
+                                        dtype=None, init=h2h_bias_initializer,
                                         allow_deferred_init=True)
 
     def state_info(self, batch_size=0):
@@ -541,16 +541,16 @@ class GRUCell(HybridRecurrentCell):
         self._hidden_size = hidden_size
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
-                                          init=i2h_weight_initializer,
+                                          dtype=None, init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
-                                          init=h2h_weight_initializer,
+                                          dtype=None, init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
-                                        init=i2h_bias_initializer,
+                                        dtype=None, init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
-                                        init=h2h_bias_initializer,
+                                        dtype=None, init=h2h_bias_initializer,
                                         allow_deferred_init=True)
 
     def state_info(self, batch_size=0):
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index eb99aab..60d66db 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -62,11 +62,10 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
   }
 
   // construct backward graph
-  std::vector<NodeEntry> ograd_entries;
   {
-    ograd_entries.reserve(fwd_graph_.outputs.size());
+    ograd_entries_.reserve(fwd_graph_.outputs.size());
     for (size_t i = 0; i < fwd_graph_.outputs.size(); ++i) {
-      ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0});
+      ograd_entries_.emplace_back(NodeEntry{Node::Create(), 0, 0});
     }
 
     std::vector<NodeEntry> xs;
@@ -77,7 +76,7 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
         << "There are no inputs in computation graph that require gradients.";
 
     grad_graph_ = pass::Gradient(
-        fwd_graph_, fwd_graph_.outputs, xs, ograd_entries,
+        fwd_graph_, fwd_graph_.outputs, xs, ograd_entries_,
         exec::AggregateGradient, nullptr, nullptr,
         zero_ops, "_copy");
   }
@@ -105,12 +104,12 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
         std::make_shared<dmlc::any>(std::move(full_ref_count));
 
     size_t num_forward_inputs = num_inputs();
-    for (uint32_t i = 0; i < ograd_entries.size(); ++i) {
-      if (!idx.exist(ograd_entries[i].node.get())) continue;
-      auto eid = idx.entry_id(ograd_entries[i]);
+    size_t num_forward_outputs = num_outputs();
+    for (uint32_t i = 0; i < ograd_entries_.size(); ++i) {
+      if (!idx.exist(ograd_entries_[i].node.get())) continue;
+      auto eid = idx.entry_id(ograd_entries_[i]);
       if (ref_count[eid] > 0) {
         bwd_ograd_dep_.push_back(i);
-        bwd_input_eid_.push_back(eid);
       }
     }
     save_inputs_.resize(num_forward_inputs, false);
@@ -119,16 +118,14 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
       if (ref_count[eid] > 0) {
         save_inputs_[i] = true;
         bwd_in_dep_.push_back(i);
-        bwd_input_eid_.push_back(eid);
       }
     }
     save_outputs_.resize(idx.outputs().size(), false);
-    for (uint32_t i = 0; i < idx.outputs().size(); ++i) {
+    for (uint32_t i = 0; i < num_forward_outputs; ++i) {
       auto eid = idx.entry_id(idx.outputs()[i]);
       if (ref_count[eid] > 0) {
         save_outputs_[i] = true;
         bwd_out_dep_.push_back(i);
-        bwd_input_eid_.push_back(eid);
       }
     }
   }
@@ -242,9 +239,28 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
     for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
       if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]);
     }
+    bwd_input_eid_.clear();
   }
 
   const auto& idx = g.indexed_graph();
+
+  if (bwd_input_eid_.size() != inputs.size()) {
+    bwd_input_eid_.clear();
+    for (const auto& i : bwd_ograd_dep_) {
+      auto eid = idx.entry_id(ograd_entries_[i]);
+      bwd_input_eid_.push_back(eid);
+    }
+    for (const auto& i : bwd_in_dep_) {
+      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
+      bwd_input_eid_.push_back(eid);
+    }
+    for (const auto& i : bwd_out_dep_) {
+      auto eid = idx.entry_id(idx.outputs()[i]);
+      bwd_input_eid_.push_back(eid);
+    }
+    CHECK_EQ(inputs.size(), bwd_input_eid_.size());
+  }
+
   size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
   size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();
 
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index df0af34..751f1fb 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -567,6 +567,17 @@ def test_fill_shape_deferred():
     assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1]
 
 
+def test_dtype():
+    net = mx.gluon.model_zoo.vision.resnet18_v1()
+    net.initialize()
+    net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+
+    net = mx.gluon.model_zoo.vision.resnet18_v1()
+    net.initialize()
+    net.hybridize()
+    net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+
+
 def test_fill_shape_load():
     ctx = mx.context.current_context()
     net1 = nn.HybridSequential()

-- 
To stop receiving notification emails like this one, please contact commits@mxnet.apache.org.