Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/15 23:25:07 UTC

[GitHub] piiswrong closed pull request #11472: Add mirror to Gluon

URL: https://github.com/apache/incubator-mxnet/pull/11472

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/python/mxnet/gluon/_memmonger.py b/python/mxnet/gluon/_memmonger.py
new file mode 100644
index 00000000000..8d48a4cf09b
--- /dev/null
+++ b/python/mxnet/gluon/_memmonger.py
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=protected-access
+
+import math
+
+from .. import cpu
+
+def prod(shape):
+    """Return the product of all entries in `shape`."""
+    ret = 1
+    for s in shape:
+        ret *= s
+    return ret
+
+def make_mirror_plan(sym, inputs, params, threshold, plan_info=None):
+    """Memory allocation planner with a given threshold.
+
+    The user passes in a network configuration, input and parameter shape
+    configurations, and a threshold that limits the memory used per stage
+    block.
+
+    Parameters
+    ----------
+    sym : Symbol
+        Input symbol configuration.
+        The user needs to pre-mark the attribute "mirror_stage" on the nodes
+        that can be book-kept as a stage.
+        The algorithm decides whether to disable mirroring on the stage nodes.
+    inputs : dict of str to NDArray
+        Input arrays, used to infer shapes.
+    params : dict of str to Parameter
+        Network parameters, used to infer shapes.
+    threshold : int
+        A tuning parameter controlling the approximate size in MB of each
+        stage block.
+    plan_info : dict, optional
+        Used to hold plan information.
+
+    Returns
+    -------
+    alloc_sym : Symbol
+        A symbol with force-mirroring attributes tagged on its nodes for
+        better allocation.
+    """
+    threshold = threshold << 20
+    sym = sym.__copy__()
+    internals = sym.get_internals()
+    input_shapes = {key: val.shape for key, val in inputs.items()}
+    input_shapes.update({key: val.shape for key, val in params.items()})
+    _, out_shapes, _ = internals.infer_shape(**input_shapes)
+    shape_dict = list(zip(internals.list_outputs(), out_shapes))
+    total_size = 0
+    local_size = 0
+    save_size = 0
+    max_size = 0
+    last_sb = None
+    last_local = 0
+    period = 1
+    last_stage = ''
+    stage_decision = ''
+
+    for idx, item in enumerate(shape_dict):
+        sb = internals[idx]
+        name, shape = item
+        if name in input_shapes:
+            continue
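+        # assume 4-byte (float32) entries when estimating activation memory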
+        total_size += prod(shape) * 4
+        local_size += prod(shape) * 4
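+        # tentatively mark every intermediate for mirroring; the stage
+        # decisions below may disable it again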
+        sb._set_attr(force_mirroring='True')
+
+        if sb.attr('mirror_stage') is not None:
+            stage = sb.attr('mirror_stage')
+            if stage == 'True' or stage != last_stage:
+                if local_size > threshold:
+                    save_size += prod(shape) * 4
+                    max_size = max(max_size, local_size)
+                    local_size = 0
+                    stage_decision = 'False'
+                    sb._set_attr(force_mirroring=stage_decision)
+                else:
+                    stage_decision = 'True'
+                last_stage = stage
+            elif stage == last_stage and stage_decision == 'False':
+                save_size += prod(shape) * 4
+                sb._set_attr(force_mirroring=stage_decision)
+
+    if plan_info is not None:
+        plan_info['max_size'] = max_size
+        plan_info['save_size'] = save_size
+    return sym
+
+
+def get_cost(sym, inputs, params):
+    """Get the memory cost in MB of the current symbolic plan by binding it
+    on CPU.
+
+    `sym` is the Symbol to measure; `inputs` and `params` supply the shapes,
+    dtypes, and gradient requests for the bind.
+    """
+    grad_reqs = {}
+    type_dict = {}
+    shape_dict = {}
+    for key, val in inputs.items():
+        type_dict[key] = val.dtype
+        shape_dict[key] = val.shape
+        if val.grad is None:
+            grad_reqs[key] = 'null'
+        else:
+            grad_reqs[key] = 'write'
+
+    for key, val in params.items():
+        type_dict[key] = val.dtype
+        shape_dict[key] = val.shape
+        grad_reqs[key] = val.grad_req
+
+    texec = sym.simple_bind(ctx=cpu(),
+                            grad_req=grad_reqs,
+                            type_dict=type_dict,
+                            **shape_dict)
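+    # the executor's debug string reports the total allocated MB near its
+    # end; pick that figure out of the third-to-last line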
+    return int(texec.debug_str().split('\n')[-3].split()[1])
+
+
+def search_plan(sym, inputs, params, ntrials=6):
+    """Quick heuristic search over possible plans to find a good memory plan.
+
+    Parameters
+    ----------
+    sym : Symbol
+        Symbolic configuration to plan for.
+    inputs : dict of str to NDArray
+        Input arrays, used to infer shapes.
+    params : dict of str to Parameter
+        Network parameters.
+    ntrials : int
+        Number of additional grid-search steps.
+    """
+    history = []
+    threshold = 0
+    min_threshold = None
+    min_cost = None
+    nbegin = 3
+
+    for k in range(nbegin):
+        info = {}
+        sym = make_mirror_plan(sym, inputs, params, threshold, info)
+        cost = get_cost(sym, inputs, params)
+        save_size = info['save_size'] >> 20
+        local_size = info['max_size'] >> 20
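+        # next threshold guess balances the memory saved against the largest
+        # per-stage block seen so far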
+        guess = int(math.sqrt(save_size * local_size / 2))
+        if min_cost is None or min_cost > cost:
+            min_cost = cost
+        if min_threshold is None or local_size < min_threshold:
+            min_threshold = local_size
+        print("Search threshold=%d MB, cost=%d MB" % (threshold, cost))
+        history.append((cost, threshold, sym))
+        threshold = guess
+
+    max_threshold = threshold * math.sqrt(2)
+    step = int((max_threshold - min_threshold) / ntrials)
+    threshold = min_threshold + step
+    if step > 0:
+        for k in range(ntrials):
+            sym = make_mirror_plan(sym, inputs, params, threshold)
+            cost = get_cost(sym, inputs, params)
+            print("Search threshold=%d MB, cost=%d MB" % (threshold, cost))
+            history.append((cost, threshold, sym))
+            threshold += step
+
+    history.sort(key=lambda x: x[0])
+    cost, threshold, sym = history[0]
+    print('Found best plan with threshold=%d MB, cost=%d MB' % (threshold, cost))
+    return sym
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 0ef28496c20..1d4bccd3e67 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -32,7 +32,7 @@
 from .. import name as _name
 from .parameter import Parameter, ParameterDict, DeferredInitializationError
 from .utils import _indent, _brief_print_list, HookHandle
-
+from . import _memmonger
 
 class _BlockScope(object):
     """Scope for collecting child `Block` s."""
@@ -702,6 +702,7 @@ def __init__(self, prefix=None, params=None):
         self._out_format = None
         self._in_format = None
         self._active = False
+        self._use_memmonger = False
         self._flags = []
 
     def __setattr__(self, name, value):
@@ -730,7 +731,7 @@ def _get_graph(self, *args):
 
     def _build_cache(self, *args):
         data, out = self._get_graph(*args)
-        data_names = {data.name : i for i, data in enumerate(data)}
+        data_names = {j.name : i for i, j in enumerate(data)}
         params = self.collect_params()
         input_names = out.list_inputs()
 
@@ -753,6 +754,10 @@ def _build_cache(self, *args):
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
+        if self._use_memmonger:
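+            # pair each input symbol name with its concrete array so the
+            # planner can infer shapes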
+            inputs = {i.name: j for i, j in zip(data, _flatten(args, "input")[0])}
+            out = _memmonger.search_plan(out, inputs, params)
+
         data_indices = []
         param_indices = []
         self._cached_op_args = []
@@ -814,6 +819,7 @@ def register_child(self, block, name=None):
 
     def hybridize(self, active=True, **kwargs):
         self._active = active
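+        # consume the flag here; the remaining kwargs become the CachedOp flags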
+        self._use_memmonger = kwargs.pop('use_memmonger', False)
         self._flags = list(kwargs.items())
         self._clear_cached_op()
         if active and self._forward_hooks or self._forward_pre_hooks:
diff --git a/python/mxnet/gluon/model_zoo/vision/resnet.py b/python/mxnet/gluon/model_zoo/vision/resnet.py
index da279b89583..f238854c285 100644
--- a/python/mxnet/gluon/model_zoo/vision/resnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/resnet.py
@@ -80,7 +80,7 @@ def hybrid_forward(self, F, x):
         if self.downsample:
             residual = self.downsample(residual)
 
-        x = F.Activation(residual+x, act_type='relu')
+        x = F.Activation(residual+x, act_type='relu', __mirror_stage__='True')
 
         return x
 
@@ -128,7 +128,7 @@ def hybrid_forward(self, F, x):
         if self.downsample:
             residual = self.downsample(residual)
 
-        x = F.Activation(x + residual, act_type='relu')
+        x = F.Activation(x + residual, act_type='relu', __mirror_stage__='True')
         return x
 
 
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 831b5f90023..a83ad4c7bee 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -275,13 +275,23 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
     }
   }
 
-  int do_mirror = dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0);
-  auto need_mirror = [do_mirror](const nnvm::Node& node) -> int {
-    if (node.is_variable()) return 0;
-    const std::string& type = node.attrs.op->name;
-    if (type == "Dropout") return false;
+  int use_mirror = dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0);
+  auto mirror_fn = [use_mirror](const nnvm::Node& node) -> int {
+    static auto& fresource = nnvm::Op::GetAttr<FResourceRequest>("FResourceRequest");
+    static auto& fresource_ex = nnvm::Op::GetAttr<FResourceRequestEx>("FResourceRequestEx");
+    static auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+
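+    // never mirror ops that hold resources other than temporary scratch
+    // space (e.g. random states), since recomputing them is not safe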
+    if (node.is_variable()) return false;
+    if (fresource_ex.count(node.op())) return false;
+    if (fresource.count(node.op())) {
+      auto reqs = fresource[node.op()](node.attrs);
+      for (const auto& req : reqs) {
+        if (req.type != ResourceRequest::kTempSpace) return false;
+      }
+    }
     if (get_node_attr(node, "__force_mirroring__", false)) return true;
-    if (do_mirror == 0) return false;
+    if (!use_mirror) return false;
+    const std::string& type = node.attrs.op->name;
     if (type == "Convolution") return false;
     if (type == "FullyConnected") return false;
     if (type == "Concat") return false;
@@ -298,7 +308,7 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
   // take gradient
   nnvm::Graph g_grad = nnvm::pass::Gradient(
       g, symbol.outputs, xs, head_grad_entry_,
-      AggregateGradient, need_mirror, nullptr,
+      AggregateGradient, mirror_fn, nullptr,
       zero_ops, "_copy");
   CHECK_EQ(g_grad.outputs.size(), xs.size());
   for (const auto &e : g_grad.outputs) {
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 5a3d44c04ce..06c81ba04fa 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -89,6 +89,22 @@ struct CachedOp::CachedOpState {
   std::multimap<size_t, NDArray> bwd_reuse_pool;
 };
 
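+// Parse a typed attribute from the node's string attribute dictionary,
+// returning default_value when the key is absent.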
+template<typename ValueType>
+inline ValueType get_node_attr(
+    const nnvm::Node& node,
+    const std::string& key, ValueType default_value) {
+  auto it = node.attrs.dict.find(key);
+  if (it == node.attrs.dict.end()) {
+    return default_value;
+  } else {
+    ValueType ret;
+    dmlc::parameter::FieldEntry<ValueType> e;
+    e.Init(key, &ret, ret);
+    e.Set(&ret, it->second);
+    return ret;
+  }
+}
+
 CachedOp::CachedOp(
     const nnvm::Symbol& sym,
     const std::vector<std::pair<std::string, std::string> >& flags) {
@@ -172,9 +188,26 @@ CachedOp::CachedOp(
     CHECK_GT(xs.size(), 0)
         << "There are no inputs in computation graph that require gradients.";
 
+    auto mirror_fn = [](const nnvm::Node& node) -> int {
+      static auto& fresource = nnvm::Op::GetAttr<FResourceRequest>("FResourceRequest");
+      static auto& fresource_ex = nnvm::Op::GetAttr<FResourceRequestEx>("FResourceRequestEx");
+      static auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+
+      if (node.is_variable()) return false;
+      if (fresource_ex.count(node.op())) return false;
+      if (fresource.count(node.op())) {
+        auto reqs = fresource[node.op()](node.attrs);
+        for (const auto& req : reqs) {
+          if (req.type != ResourceRequest::kTempSpace) return false;
+        }
+      }
+      if (get_node_attr(node, "__force_mirroring__", false)) return true;
+      return false;
+    };
+
     grad_graph_ = pass::Gradient(
         fwd_graph_, fwd_graph_.outputs, xs, ograd_entries_,
-        exec::AggregateGradient, nullptr, nullptr,
+        exec::AggregateGradient, mirror_fn, nullptr,
         zero_ops, "_copy");
   }
 
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 7fff6b8c1f5..d0357247413 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1175,7 +1175,7 @@ def test(net, x):
         assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)
 
 def test_hybrid_static_memory():
-    check_hybrid_static_memory()
+    check_hybrid_static_memory(use_memmonger=True)
     check_hybrid_static_memory(static_alloc=True)
     check_hybrid_static_memory(static_alloc=True, static_shape=True)
 
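For reviewers: the resnet.py change above illustrates the marking convention. A user-defined HybridBlock could tag its own stage boundaries the same way; a minimal sketch, relying only on the __mirror_stage__ operator attribute used in this PR (the double-underscore kwarg sets the node attribute mirror_stage that make_mirror_plan reads):

    from mxnet.gluon import HybridBlock

    class MyBlock(HybridBlock):
        def hybrid_forward(self, F, x):
            # tag this activation as a candidate stage boundary;
            # make_mirror_plan decides whether to recompute or keep it
            return F.Activation(x, act_type='relu', __mirror_stage__='True')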

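And an end-to-end usage sketch, assuming this PR is applied (use_memmonger is the new hybridize flag added in block.py above; the plan search runs inside _build_cache on the first forward call):

    import mxnet as mx
    from mxnet.gluon.model_zoo import vision

    # the resnet blocks are pre-tagged with __mirror_stage__ by this PR
    net = vision.resnet50_v1()
    net.initialize()

    # request a memory-mirroring plan when the cached graph is built
    net.hybridize(use_memmonger=True)

    x = mx.nd.random.uniform(shape=(1, 3, 224, 224))
    y = net(x)  # first call runs _memmonger.search_plan and prints the search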

 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services