You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/08 22:59:10 UTC

[GitHub] piiswrong closed pull request #8522: gluon with multiple data type

piiswrong closed pull request #8522: gluon with multiple data type
URL: https://github.com/apache/incubator-mxnet/pull/8522
 
 
   

This pull request was merged from a forked repository.
Because GitHub hides the original diff on merge, it is reproduced below
for the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below (GitHub would not otherwise display it):

diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index d26e86f409..88a9f4d597 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -120,6 +120,7 @@ class Imperative {
     nnvm::Graph fwd_graph_;
     nnvm::Graph grad_graph_;
     nnvm::Graph full_graph_;
+    std::vector<nnvm::NodeEntry> ograd_entries_;
     std::vector<bool> curr_grad_req_;
     std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
     std::vector<uint32_t> bwd_input_eid_;
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 73dbfc10fe..2546711246 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -21,6 +21,7 @@
 __all__ = ['Block', 'HybridBlock', 'SymbolBlock']
 
 import copy
+import warnings
 
 from .. import symbol, ndarray, initializer
 from ..symbol import Symbol
@@ -325,7 +326,7 @@ def __init__(self, prefix=None, params=None):
         self._reg_params = {}
         self._cached_graph = ()
         self._cached_op = None
-        self._cached_params = None
+        self._cached_op_args = None
         self._out_format = None
         self._in_format = None
         self._active = False
@@ -363,34 +364,47 @@ def _get_graph(self, *args):
 
     def _build_cache(self, *args):
         inputs, out = self._get_graph(*args)
+        input_idx = {var.name: i for i, var in enumerate(inputs)}
         self._cached_op = ndarray.CachedOp(out)
-
         params = dict(self.collect_params().items())
-        self._cached_params = [params.get(name, None) for name in out.list_inputs()]
-        assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \
-            "Wrong number of inputs."
 
-        name2pos = {var.name: i for i, var in enumerate(inputs)}
-        self._in_idx = [(i, name2pos[name]) for i, name in enumerate(out.list_inputs())
-                        if name not in params]
+        # verify graph inputs
+        expected_inputs = set(out.list_inputs())
+        for name in expected_inputs:
+            assert name in params or name in input_idx, \
+                "Unknown input to HybridBlock: %s"%name
+        for name, i in input_idx.items():
+            if name not in expected_inputs:
+                warnings.warn("The %d-th input to HybridBlock is not used by any "
+                              "computation. Is this intended?"%i)
+        for name in params:
+            if name not in expected_inputs:
+                warnings.warn("Parameter %s is not used by any computation. "
+                              "Is this intended?"%name)
+
+        self._cached_op_args = [(False, params[name]) if name in params
+                                else (True, input_idx[name])
+                                for name in out.list_inputs()]
+
+    def _finish_deferred_init(self, hybrid, *args):
+        self.infer_shape(*args)
+        self.infer_type(*args)
+        if hybrid:
+            for is_arg, i in self._cached_op_args:
+                if not is_arg:
+                    i._finish_deferred_init()
+        else:
+            for _, i in self.params.items():
+                i._finish_deferred_init()
 
     def _call_cached_op(self, *args):
         if self._cached_op is None:
             self._build_cache(*args)
 
-        try:
-            cargs = [i.data() if i else None for i in self._cached_params]
-        except DeferredInitializationError:
-            self.infer_shape(*args)
-            for i in self._cached_params:
-                if i is not None:
-                    i._finish_deferred_init()
-            cargs = [i.data() if i else None for i in self._cached_params]
-
         args, fmt = _flatten(args)
         assert fmt == self._in_format, "Invalid input format"
-        for i, j in self._in_idx:
-            cargs[i] = args[j]
+        cargs = [args[i] if is_arg else i.data()
+                 for is_arg, i in self._cached_op_args]
         out = self._cached_op(*cargs)
         if isinstance(out, NDArray):
             out = [out]
@@ -399,6 +413,7 @@ def _call_cached_op(self, *args):
     def _clear_cached_op(self):
         self._cached_graph = ()
         self._cached_op = None
+        self._cached_op_args = None
 
     def register_child(self, block):
         if not isinstance(block, HybridBlock):
@@ -414,17 +429,25 @@ def hybridize(self, active=True):
         self._active = active
         super(HybridBlock, self).hybridize(active)
 
-    def infer_shape(self, *args):
-        """Infers shape of Parameters from inputs."""
+    def _infer_attrs(self, infer_fn, attr, *args):
+        """Generic infer attributes."""
         inputs, out = self._get_graph(*args)
         args, _ = _flatten(args)
-        arg_shapes, _, aux_shapes = out.infer_shape(
-            **{i.name: j.shape for i, j in zip(inputs, args)})
-        sdict = {i: j for i, j in zip(out.list_arguments(), arg_shapes)}
-        sdict.update({name : shape for name, shape in \
-                      zip(out.list_auxiliary_states(), aux_shapes)})
+        arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
+            **{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
+        sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)}
+        sdict.update({name : attr for name, attr in \
+                      zip(out.list_auxiliary_states(), aux_attrs)})
         for i in self.collect_params().values():
-            i.shape = sdict[i.name]
+            setattr(i, attr, sdict[i.name])
+
+    def infer_shape(self, *args):
+        """Infers shape of Parameters from inputs."""
+        self._infer_attrs('infer_shape', 'shape', *args)
+
+    def infer_type(self, *args):
+        """Infers data type of Parameters from inputs."""
+        self._infer_attrs('infer_type', 'dtype', *args)
 
     def export(self, path):
         """Export HybridBlock to json format that can be loaded by `mxnet.mod.Module`
@@ -462,15 +485,16 @@ def forward(self, x, *args):
         :py:class:`NDArray` or :py:class:`Symbol`."""
         if isinstance(x, NDArray):
             with x.context as ctx:
-                if self._active:
-                    return self._call_cached_op(x, *args)
                 try:
+                    if self._active:
+                        return self._call_cached_op(x, *args)
                     params = {i: j.data(ctx) for i, j in self._reg_params.items()}
                 except DeferredInitializationError:
-                    self.infer_shape(x, *args)
-                    for i in self.collect_params().values():
-                        i._finish_deferred_init()
-                    params = {i: j.data(ctx) for i, j in self._reg_params.items()}
+                    self._finish_deferred_init(self._active, x, *args)
+
+                if self._active:
+                    return self._call_cached_op(x, *args)
+                params = {i: j.data(ctx) for i, j in self._reg_params.items()}
                 return self.hybrid_forward(ndarray, x, *args, **params)
 
         assert isinstance(x, Symbol), \
@@ -559,6 +583,11 @@ def __init__(self, outputs, inputs, params=None):
     def forward(self, x, *args):
         if isinstance(x, NDArray):
             with x.context:
+                try:
+                    return self._call_cached_op(x, *args)
+                except DeferredInitializationError:
+                    self._finish_deferred_init(True, x, *args)
+
                 return self._call_cached_op(x, *args)
 
         assert isinstance(x, Symbol), \
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 906f03ec92..8034ab8415 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -185,11 +185,11 @@ def __init__(self, units, activation=None, use_bias=True, flatten=True,
             self._units = units
             self._in_units = in_units
             self.weight = self.params.get('weight', shape=(units, in_units),
-                                          init=weight_initializer,
+                                          dtype=None, init=weight_initializer,
                                           allow_deferred_init=True)
             if use_bias:
                 self.bias = self.params.get('bias', shape=(units,),
-                                            init=bias_initializer,
+                                            dtype=None, init=bias_initializer,
                                             allow_deferred_init=True)
             else:
                 self.bias = None
@@ -336,20 +336,20 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
             self.in_channels = in_channels
 
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
-                                     shape=(in_channels,), init=gamma_initializer,
-                                     allow_deferred_init=True,
+                                     shape=(in_channels,), dtype=None,
+                                     init=gamma_initializer, allow_deferred_init=True,
                                      differentiable=scale)
         self.beta = self.params.get('beta', grad_req='write' if center else 'null',
-                                    shape=(in_channels,), init=beta_initializer,
-                                    allow_deferred_init=True,
+                                    shape=(in_channels,), dtype=None,
+                                    init=beta_initializer, allow_deferred_init=True,
                                     differentiable=center)
         self.running_mean = self.params.get('running_mean', grad_req='null',
-                                            shape=(in_channels,),
+                                            shape=(in_channels,), dtype=None,
                                             init=running_mean_initializer,
                                             allow_deferred_init=True,
                                             differentiable=False)
         self.running_var = self.params.get('running_var', grad_req='null',
-                                           shape=(in_channels,),
+                                           shape=(in_channels,), dtype=None,
                                            init=running_variance_initializer,
                                            allow_deferred_init=True,
                                            differentiable=False)
@@ -437,7 +437,7 @@ def __init__(self, input_dim, output_dim, dtype='float32',
         self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
                         'dtype': dtype}
         self.weight = self.params.get('weight', shape=(input_dim, output_dim),
-                                      init=weight_initializer,
+                                      dtype=None, init=weight_initializer,
                                       allow_deferred_init=True)
 
     def hybrid_forward(self, F, x, weight):
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 645de98ec0..0dd70697ab 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -113,11 +113,11 @@ def __init__(self, channels, kernel_size, strides, padding, dilation,
             dshape[layout.find('C')] = in_channels
             wshapes = _infer_weight_shape(op_name, dshape, self._kwargs)
             self.weight = self.params.get('weight', shape=wshapes[1],
-                                          init=weight_initializer,
+                                          dtype=None, init=weight_initializer,
                                           allow_deferred_init=True)
             if use_bias:
                 self.bias = self.params.get('bias', shape=wshapes[2],
-                                            init=bias_initializer,
+                                            dtype=None, init=bias_initializer,
                                             allow_deferred_init=True)
             else:
                 self.bias = None
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index c42fbaa1fa..27297b5329 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -306,7 +306,7 @@ def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
             ctx = [ctx]
         if init is None:
             init = default_init if self.init is None else self.init
-        if not self.shape or np.prod(self.shape) <= 0:
+        if self.dtype is None or not self.shape or np.prod(self.shape) <= 0:
             if self._allow_deferred_init:
                 self._deferred_init = (init, ctx, default_init)
                 return
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index ea0e32faeb..80bb8e3fb8 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -326,16 +326,16 @@ def __init__(self, hidden_size, activation='tanh',
         self._activation = activation
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(hidden_size, input_size),
-                                          init=i2h_weight_initializer,
+                                          dtype=None, init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(hidden_size, hidden_size),
-                                          init=h2h_weight_initializer,
+                                          dtype=None, init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_size,),
-                                        init=i2h_bias_initializer,
+                                        dtype=None, init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_size,),
-                                        init=h2h_bias_initializer,
+                                        dtype=None, init=h2h_bias_initializer,
                                         allow_deferred_init=True)
 
     def state_info(self, batch_size=0):
@@ -434,16 +434,16 @@ def __init__(self, hidden_size,
         self._hidden_size = hidden_size
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
-                                          init=i2h_weight_initializer,
+                                          dtype=None, init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
-                                          init=h2h_weight_initializer,
+                                          dtype=None, init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
-                                        init=i2h_bias_initializer,
+                                        dtype=None, init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
-                                        init=h2h_bias_initializer,
+                                        dtype=None, init=h2h_bias_initializer,
                                         allow_deferred_init=True)
 
     def state_info(self, batch_size=0):
@@ -541,16 +541,16 @@ def __init__(self, hidden_size,
         self._hidden_size = hidden_size
         self._input_size = input_size
         self.i2h_weight = self.params.get('i2h_weight', shape=(3*hidden_size, input_size),
-                                          init=i2h_weight_initializer,
+                                          dtype=None, init=i2h_weight_initializer,
                                           allow_deferred_init=True)
         self.h2h_weight = self.params.get('h2h_weight', shape=(3*hidden_size, hidden_size),
-                                          init=h2h_weight_initializer,
+                                          dtype=None, init=h2h_weight_initializer,
                                           allow_deferred_init=True)
         self.i2h_bias = self.params.get('i2h_bias', shape=(3*hidden_size,),
-                                        init=i2h_bias_initializer,
+                                        dtype=None, init=i2h_bias_initializer,
                                         allow_deferred_init=True)
         self.h2h_bias = self.params.get('h2h_bias', shape=(3*hidden_size,),
-                                        init=h2h_bias_initializer,
+                                        dtype=None, init=h2h_bias_initializer,
                                         allow_deferred_init=True)
 
     def state_info(self, batch_size=0):
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index eb99aabf11..60d66db485 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -62,11 +62,10 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
   }
 
   // construct backward graph
-  std::vector<NodeEntry> ograd_entries;
   {
-    ograd_entries.reserve(fwd_graph_.outputs.size());
+    ograd_entries_.reserve(fwd_graph_.outputs.size());
     for (size_t i = 0; i < fwd_graph_.outputs.size(); ++i) {
-      ograd_entries.emplace_back(NodeEntry{Node::Create(), 0, 0});
+      ograd_entries_.emplace_back(NodeEntry{Node::Create(), 0, 0});
     }
 
     std::vector<NodeEntry> xs;
@@ -77,7 +76,7 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
         << "There are no inputs in computation graph that require gradients.";
 
     grad_graph_ = pass::Gradient(
-        fwd_graph_, fwd_graph_.outputs, xs, ograd_entries,
+        fwd_graph_, fwd_graph_.outputs, xs, ograd_entries_,
         exec::AggregateGradient, nullptr, nullptr,
         zero_ops, "_copy");
   }
@@ -105,12 +104,12 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
         std::make_shared<dmlc::any>(std::move(full_ref_count));
 
     size_t num_forward_inputs = num_inputs();
-    for (uint32_t i = 0; i < ograd_entries.size(); ++i) {
-      if (!idx.exist(ograd_entries[i].node.get())) continue;
-      auto eid = idx.entry_id(ograd_entries[i]);
+    size_t num_forward_outputs = num_outputs();
+    for (uint32_t i = 0; i < ograd_entries_.size(); ++i) {
+      if (!idx.exist(ograd_entries_[i].node.get())) continue;
+      auto eid = idx.entry_id(ograd_entries_[i]);
       if (ref_count[eid] > 0) {
         bwd_ograd_dep_.push_back(i);
-        bwd_input_eid_.push_back(eid);
       }
     }
     save_inputs_.resize(num_forward_inputs, false);
@@ -119,16 +118,14 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
       if (ref_count[eid] > 0) {
         save_inputs_[i] = true;
         bwd_in_dep_.push_back(i);
-        bwd_input_eid_.push_back(eid);
       }
     }
     save_outputs_.resize(idx.outputs().size(), false);
-    for (uint32_t i = 0; i < idx.outputs().size(); ++i) {
+    for (uint32_t i = 0; i < num_forward_outputs; ++i) {
       auto eid = idx.entry_id(idx.outputs()[i]);
       if (ref_count[eid] > 0) {
         save_outputs_[i] = true;
         bwd_out_dep_.push_back(i);
-        bwd_input_eid_.push_back(eid);
       }
     }
   }
@@ -242,9 +239,28 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
     for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
       if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]);
     }
+    bwd_input_eid_.clear();
   }
 
   const auto& idx = g.indexed_graph();
+
+  if (bwd_input_eid_.size() != inputs.size()) {
+    bwd_input_eid_.clear();
+    for (const auto& i : bwd_ograd_dep_) {
+      auto eid = idx.entry_id(ograd_entries_[i]);
+      bwd_input_eid_.push_back(eid);
+    }
+    for (const auto& i : bwd_in_dep_) {
+      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
+      bwd_input_eid_.push_back(eid);
+    }
+    for (const auto& i : bwd_out_dep_) {
+      auto eid = idx.entry_id(idx.outputs()[i]);
+      bwd_input_eid_.push_back(eid);
+    }
+    CHECK_EQ(inputs.size(), bwd_input_eid_.size());
+  }
+
   size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
   size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();
 
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index df0af34dfe..751f1fbd96 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -567,6 +567,17 @@ def test_fill_shape_deferred():
     assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1]
 
 
+def test_dtype():
+    net = mx.gluon.model_zoo.vision.resnet18_v1()
+    net.initialize()
+    net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+
+    net = mx.gluon.model_zoo.vision.resnet18_v1()
+    net.initialize()
+    net.hybridize()
+    net(mx.nd.ones((16, 3, 32, 32), dtype='float64')).wait_to_read()
+
+
 def test_fill_shape_load():
     ctx = mx.context.current_context()
     net1 = nn.HybridSequential()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services