You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cj...@apache.org on 2017/11/05 20:58:13 UTC

[incubator-mxnet] 04/05: fix place device (#8450)

This is an automated email from the ASF dual-hosted git repository.

cjolivier01 pushed a commit to branch v0.12.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit ca3d56fa94f5569bd4af04ef27bea72099b64895
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Sat Oct 28 10:50:43 2017 -0700

    fix place device (#8450)
---
 src/imperative/imperative_utils.h | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 3758b47..85e01b1 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -571,8 +571,15 @@ inline std::vector<Context> PlaceDevice(const nnvm::IndexedGraph& idx) {
       auto fwd_nid = idx.node_id(idx[i].source->control_deps[0].get());
       CHECK_EQ(idx[fwd_nid].source->op(), _copyto);
       vctx[i] = vctx[idx[fwd_nid].inputs[0].node_id];
-    } else if (idx[i].inputs.size() && vctx[i].dev_type == -1) {
-      vctx[i] = vctx[idx[i].inputs[0].node_id];
+    } else if (idx[i].control_deps.size() &&
+               vctx[idx[i].control_deps[0]].dev_type != -1) {
+      vctx[i] = vctx[idx[i].control_deps[0]];
+    } else {
+      for (const auto& in : idx[i].inputs) {
+        if (vctx[in.node_id].dev_type == -1) continue;
+        vctx[i] = vctx[in.node_id];
+        break;
+      }
     }
   }
   // backward pass

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.