You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/20 05:55:33 UTC

[incubator-mxnet] branch master updated: fix group2ctx with null reqs (#8717)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7fc0396  fix group2ctx with null reqs (#8717)
7fc0396 is described below

commit 7fc039639b288f80fa7fe6482de1a25e04261e5e
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Sun Nov 19 21:55:31 2017 -0800

    fix group2ctx with null reqs (#8717)
---
 src/executor/graph_executor.cc                  | 18 +++++++++++++----
 tests/python/unittest/test_multi_device_exec.py | 26 +++++++++++++++++--------
 2 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index ade8e83..01484da 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -321,6 +321,7 @@ Graph AssignContext(Graph g,
                     const std::vector<Context>& in_arg_ctxes,
                     const std::vector<Context>& arg_grad_ctxes,
                     const std::vector<Context>& aux_state_ctxes,
+                    const std::vector<OpReqType>& grad_req_types,
                     size_t num_forward_inputs,
                     size_t num_forward_outputs) {
   const auto& idx = g.indexed_graph();
@@ -385,9 +386,15 @@ Graph AssignContext(Graph g,
 
   // loop through backward input nodes and populate maps and lists
   // the backward input nodes is the gradient of the loss wrt the output
-  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) {
+  size_t arg_grad_offset = 0;
+  // keep an offset into the arg_grad_ctxes vector,
+  // since g.outputs exclude arg_grad whose req == null
+  CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs)
+           << "insufficient number of grad_reqs";
+  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
+    while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
     const uint32_t nid = idx.outputs()[i].node_id;
-    Context ctx = arg_grad_ctxes[i - num_forward_outputs];
+    Context ctx = arg_grad_ctxes[arg_grad_offset];
     if (ctx2id.count(ctx) == 0) {
       ctx2id[ctx] = static_cast<int>(ctx_list.size());
       ctx_list.push_back(ctx);
@@ -417,9 +424,11 @@ Graph AssignContext(Graph g,
   // if the assigned device of gradient node
   // corresponds to storage of grads
   auto &new_idx = g.indexed_graph();
-  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) {
+  arg_grad_offset = 0;
+  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
+    while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
     const uint32_t nid = new_idx.outputs()[i].node_id;
-    Context ctx = arg_grad_ctxes[i - num_forward_outputs];
+    Context ctx = arg_grad_ctxes[arg_grad_offset];
     CHECK(ctx == vcontext[nid])
       << "Trying to save gradient to " << ctx
       << " while its source node \"" << new_idx[nid].source->attrs.name
@@ -1055,6 +1064,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
                     in_arg_ctxes,
                     arg_grad_ctxes,
                     aux_state_ctxes,
+                    grad_req_types,
                     num_forward_inputs_,
                     num_forward_outputs_);
 
diff --git a/tests/python/unittest/test_multi_device_exec.py b/tests/python/unittest/test_multi_device_exec.py
index 0a2739d..aa279b1 100644
--- a/tests/python/unittest/test_multi_device_exec.py
+++ b/tests/python/unittest/test_multi_device_exec.py
@@ -20,6 +20,17 @@ import numpy as np
 import mxnet as mx
 
 def test_ctx_group():
+    def check_ctx_group(group2ctx, grad_req, mlp, set_stage1):
+        texec = mlp.simple_bind(mx.cpu(0),
+                                group2ctx=group2ctx,
+                                data=(1,200), grad_req=grad_req)
+
+        for arr, name in zip(texec.arg_arrays, mlp.list_arguments()):
+            if name in set_stage1:
+                assert arr.context == group2ctx['stage1']
+            else:
+                assert arr.context == group2ctx['stage2']
+
     with mx.AttrScope(ctx_group='stage1'):
         data = mx.symbol.Variable('data')
         fc1  = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
@@ -40,15 +51,14 @@ def test_ctx_group():
         'stage2' : mx.cpu(2)
     }
 
-    texec = mlp.simple_bind(mx.cpu(0),
-                            group2ctx=group2ctx,
-                            data=(1,200))
+    # generate reqs with null
+    grad_req_with_null = {}
+    for arg in mlp.list_arguments():
+        grad_req_with_null[arg] = 'null' if arg == 'data' else 'write'
 
-    for arr, name in zip(texec.arg_arrays, mlp.list_arguments()):
-        if name in set_stage1:
-            assert arr.context == group2ctx['stage1']
-        else:
-            assert arr.context == group2ctx['stage2']
+    grad_reqs = ['write', grad_req_with_null]
+    for grad_req in grad_reqs:
+        check_ctx_group(group2ctx, grad_req, mlp, set_stage1)
 
 def test_ctx_group_sparse():
     with mx.AttrScope(ctx_group='stage1'):

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.