Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/20 05:55:34 UTC

[GitHub] piiswrong closed pull request #8717: fix group2ctx with null reqs

piiswrong closed pull request #8717: fix group2ctx with null reqs
URL: https://github.com/apache/incubator-mxnet/pull/8717
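
For context: when an executor is bound with group2ctx and one or more
arguments carry a 'null' gradient req, the backward graph emits no output
for those arguments. AssignContext, however, indexed arg_grad_ctxes (one
entry per argument) by backward-output position, so any 'null' req
misaligned every gradient context after it. Below is a minimal sketch of
the failing scenario, condensed from the test updated in this PR (the
symbol structure and shapes are illustrative, not the exact test network):

    import mxnet as mx

    # two context groups, as in test_ctx_group below
    with mx.AttrScope(ctx_group='stage1'):
        data = mx.symbol.Variable('data')
        fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
    with mx.AttrScope(ctx_group='stage2'):
        mlp = mx.symbol.SoftmaxOutput(data=fc1, name='softmax')

    group2ctx = {'stage1': mx.cpu(1), 'stage2': mx.cpu(2)}

    # a 'null' req for 'data' removes its gradient from the backward
    # outputs, which previously shifted the context assignment for all
    # remaining gradients
    grad_req = {name: ('null' if name == 'data' else 'write')
                for name in mlp.list_arguments()}
    texec = mlp.simple_bind(mx.cpu(0), group2ctx=group2ctx,
                            data=(1, 200), grad_req=grad_req)

Before this fix, a bind like the one above could trip the "Trying to save
gradient to ..." check in graph_executor.cc, since gradients were matched
against the wrong per-argument contexts.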

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index ade8e838fd..01484dac29 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -321,6 +321,7 @@ Graph AssignContext(Graph g,
                     const std::vector<Context>& in_arg_ctxes,
                     const std::vector<Context>& arg_grad_ctxes,
                     const std::vector<Context>& aux_state_ctxes,
+                    const std::vector<OpReqType>& grad_req_types,
                     size_t num_forward_inputs,
                     size_t num_forward_outputs) {
   const auto& idx = g.indexed_graph();
@@ -385,9 +386,15 @@ Graph AssignContext(Graph g,
 
   // loop through backward input nodes and populate maps and lists
  // the backward input nodes are the gradients of the loss wrt the outputs
-  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) {
+  size_t arg_grad_offset = 0;
+  // keep an offset into the arg_grad_ctxes vector,
+  // since g.outputs excludes arg_grads whose req == kNullOp
+  CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs)
+           << "insufficient number of grad_reqs";
+  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
+    while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
     const uint32_t nid = idx.outputs()[i].node_id;
-    Context ctx = arg_grad_ctxes[i - num_forward_outputs];
+    Context ctx = arg_grad_ctxes[arg_grad_offset];
     if (ctx2id.count(ctx) == 0) {
       ctx2id[ctx] = static_cast<int>(ctx_list.size());
       ctx_list.push_back(ctx);
@@ -417,9 +424,11 @@ Graph AssignContext(Graph g,
   // if the assigned device of gradient node
   // corresponds to storage of grads
   auto &new_idx = g.indexed_graph();
-  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) {
+  arg_grad_offset = 0;
+  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
+    while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
     const uint32_t nid = new_idx.outputs()[i].node_id;
-    Context ctx = arg_grad_ctxes[i - num_forward_outputs];
+    Context ctx = arg_grad_ctxes[arg_grad_offset];
     CHECK(ctx == vcontext[nid])
       << "Trying to save gradient to " << ctx
       << " while its source node \"" << new_idx[nid].source->attrs.name
@@ -1055,6 +1064,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
                     in_arg_ctxes,
                     arg_grad_ctxes,
                     aux_state_ctxes,
+                    grad_req_types,
                     num_forward_inputs_,
                     num_forward_outputs_);
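
In essence, the change above keeps a separate offset into arg_grad_ctxes
(which has one slot per argument) and advances it past every argument
whose req is kNullOp, since those arguments contribute no backward output
to g.outputs. A standalone sketch of the indexing logic, with plain Python
lists standing in for the C++ vectors (function and variable names are
illustrative only):

    def grad_contexts(grad_req_types, arg_grad_ctxes, num_backward_outputs):
        # one entry per argument in both input lists; the backward graph
        # has one output per argument whose req is not 'null'
        assert len(grad_req_types) >= num_backward_outputs, \
            "insufficient number of grad_reqs"
        ctxs = []
        offset = 0
        for _ in range(num_backward_outputs):
            # skip slots of arguments that produce no gradient output
            while grad_req_types[offset] == 'null':
                offset += 1
            ctxs.append(arg_grad_ctxes[offset])
            offset += 1
        return ctxs

    # e.g. with reqs ['null', 'write', 'write'] and contexts [c0, c1, c2],
    # the two backward outputs map to [c1, c2]; the old code, indexing by
    # backward-output position alone, would have picked [c0, c1]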
 
diff --git a/tests/python/unittest/test_multi_device_exec.py b/tests/python/unittest/test_multi_device_exec.py
index 0a2739d9bb..aa279b1837 100644
--- a/tests/python/unittest/test_multi_device_exec.py
+++ b/tests/python/unittest/test_multi_device_exec.py
@@ -20,6 +20,17 @@
 import mxnet as mx
 
 def test_ctx_group():
+    def check_ctx_group(group2ctx, grad_req, mlp, set_stage1):
+        texec = mlp.simple_bind(mx.cpu(0),
+                                group2ctx=group2ctx,
+                                data=(1,200), grad_req=grad_req)
+
+        for arr, name in zip(texec.arg_arrays, mlp.list_arguments()):
+            if name in set_stage1:
+                assert arr.context == group2ctx['stage1']
+            else:
+                assert arr.context == group2ctx['stage2']
+
     with mx.AttrScope(ctx_group='stage1'):
         data = mx.symbol.Variable('data')
         fc1  = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
@@ -40,15 +51,14 @@ def test_ctx_group():
         'stage2' : mx.cpu(2)
     }
 
-    texec = mlp.simple_bind(mx.cpu(0),
-                            group2ctx=group2ctx,
-                            data=(1,200))
+    # generate grad reqs, with a 'null' req for 'data'
+    grad_req_with_null = {}
+    for arg in mlp.list_arguments():
+        grad_req_with_null[arg] = 'null' if arg == 'data' else 'write'
 
-    for arr, name in zip(texec.arg_arrays, mlp.list_arguments()):
-        if name in set_stage1:
-            assert arr.context == group2ctx['stage1']
-        else:
-            assert arr.context == group2ctx['stage2']
+    grad_reqs = ['write', grad_req_with_null]
+    for grad_req in grad_reqs:
+        check_ctx_group(group2ctx, grad_req, mlp, set_stage1)
 
 def test_ctx_group_sparse():
     with mx.AttrScope(ctx_group='stage1'):
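
The test refactor moves the bind-and-assert logic into check_ctx_group and
runs it for two req configurations: the blanket 'write' (the pre-existing
case) and a dict marking 'data' as 'null', which covers the regression
fixed in graph_executor.cc. Assuming the nose-based runner the repository
used at the time, the case can be run in isolation with something like:

    nosetests tests/python/unittest/test_multi_device_exec.py:test_ctx_group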


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services