You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/20 05:55:33 UTC
[incubator-mxnet] branch master updated: fix group2ctx with null reqs (#8717)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7fc0396 fix group2ctx with null reqs (#8717)
7fc0396 is described below
commit 7fc039639b288f80fa7fe6482de1a25e04261e5e
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Sun Nov 19 21:55:31 2017 -0800
fix group2ctx with null reqs (#8717)
---
src/executor/graph_executor.cc | 18 +++++++++++++----
tests/python/unittest/test_multi_device_exec.py | 26 +++++++++++++++++--------
2 files changed, 32 insertions(+), 12 deletions(-)
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index ade8e83..01484da 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -321,6 +321,7 @@ Graph AssignContext(Graph g,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
+ const std::vector<OpReqType>& grad_req_types,
size_t num_forward_inputs,
size_t num_forward_outputs) {
const auto& idx = g.indexed_graph();
@@ -385,9 +386,15 @@ Graph AssignContext(Graph g,
// loop through backward input nodes and populate maps and lists
// the backward input nodes is the gradient of the loss wrt the output
- for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) {
+ size_t arg_grad_offset = 0;
+ // keep an offset into the arg_grad_ctxes vector,
+ // since g.outputs exclude arg_grad whose req == null
+ CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs)
+ << "insufficient number of grad_reqs";
+ for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
+ while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
const uint32_t nid = idx.outputs()[i].node_id;
- Context ctx = arg_grad_ctxes[i - num_forward_outputs];
+ Context ctx = arg_grad_ctxes[arg_grad_offset];
if (ctx2id.count(ctx) == 0) {
ctx2id[ctx] = static_cast<int>(ctx_list.size());
ctx_list.push_back(ctx);
@@ -417,9 +424,11 @@ Graph AssignContext(Graph g,
// if the assigned device of gradient node
// corresponds to storage of grads
auto &new_idx = g.indexed_graph();
- for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) {
+ arg_grad_offset = 0;
+ for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) {
+ while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset;
const uint32_t nid = new_idx.outputs()[i].node_id;
- Context ctx = arg_grad_ctxes[i - num_forward_outputs];
+ Context ctx = arg_grad_ctxes[arg_grad_offset];
CHECK(ctx == vcontext[nid])
<< "Trying to save gradient to " << ctx
<< " while its source node \"" << new_idx[nid].source->attrs.name
@@ -1055,6 +1064,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
in_arg_ctxes,
arg_grad_ctxes,
aux_state_ctxes,
+ grad_req_types,
num_forward_inputs_,
num_forward_outputs_);
diff --git a/tests/python/unittest/test_multi_device_exec.py b/tests/python/unittest/test_multi_device_exec.py
index 0a2739d..aa279b1 100644
--- a/tests/python/unittest/test_multi_device_exec.py
+++ b/tests/python/unittest/test_multi_device_exec.py
@@ -20,6 +20,17 @@ import numpy as np
import mxnet as mx
def test_ctx_group():
+ def check_ctx_group(group2ctx, grad_req, mlp, set_stage1):
+ texec = mlp.simple_bind(mx.cpu(0),
+ group2ctx=group2ctx,
+ data=(1,200), grad_req=grad_req)
+
+ for arr, name in zip(texec.arg_arrays, mlp.list_arguments()):
+ if name in set_stage1:
+ assert arr.context == group2ctx['stage1']
+ else:
+ assert arr.context == group2ctx['stage2']
+
with mx.AttrScope(ctx_group='stage1'):
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
@@ -40,15 +51,14 @@ def test_ctx_group():
'stage2' : mx.cpu(2)
}
- texec = mlp.simple_bind(mx.cpu(0),
- group2ctx=group2ctx,
- data=(1,200))
+ # generate reqs with null
+ grad_req_with_null = {}
+ for arg in mlp.list_arguments():
+ grad_req_with_null[arg] = 'null' if arg == 'data' else 'write'
- for arr, name in zip(texec.arg_arrays, mlp.list_arguments()):
- if name in set_stage1:
- assert arr.context == group2ctx['stage1']
- else:
- assert arr.context == group2ctx['stage2']
+ grad_reqs = ['write', grad_req_with_null]
+ for grad_req in grad_reqs:
+ check_ctx_group(group2ctx, grad_req, mlp, set_stage1)
def test_ctx_group_sparse():
with mx.AttrScope(ctx_group='stage1'):
--
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.