You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/17 18:39:03 UTC

[GitHub] zheng-da commented on a change in pull request #11306: Add subgraph storage type inference to _backward_CachedOp

zheng-da commented on a change in pull request #11306: Add subgraph storage type inference to _backward_CachedOp 
URL: https://github.com/apache/incubator-mxnet/pull/11306#discussion_r195937517
 
 

 ##########
 File path: src/imperative/cached_op.cc
 ##########
 @@ -1015,6 +1017,49 @@ void CachedOp::Backward(
   Engine::Get()->set_bulk_size(prev_bulk_size);
 }
 
+bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
+                                   const int dev_mask,
+                                   DispatchMode* dispatch_mode,
+                                   std::vector<int> *in_attrs,
+                                   std::vector<int> *out_attrs) {
+  using namespace imperative;
+  nnvm::Graph g(full_graph_);
+  const auto& idx = g.indexed_graph();
+  const auto &outputs = idx.outputs();
+  const size_t num_forward_outputs = fwd_graph_.outputs.size();
+  CHECK_EQ(outputs.size(), num_forward_outputs + out_attrs->size());
+
+  // Construct bwd_input_eid
+  std::vector<uint32_t> bwd_input_eid;
+  SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
+                      ograd_entries_, idx, &bwd_input_eid);
+  CHECK_EQ(in_attrs->size(), bwd_input_eid.size());
+
+  // Prepare stypes and contexts based on inputs
+  StorageTypeVector stypes(idx.num_node_entries(), -1);
+  for (size_t i = 0; i < in_attrs->size(); ++i) {
+    stypes[bwd_input_eid[i]] = in_attrs->at(i);
+  }
+  // Some out_attrs are known ahead of time (e.g. the grad stype is given by users).
+  // Set these before invoking storage type inference on the subgraph.
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
+    stypes[eid] = out_attrs->at(i);
+  }
+  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+
+  // Full graph storage type inference
+  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(stypes), false);
 
 Review comment:
   Should we use `exec::InferShape(std::move(g))` here instead? Is the inference guaranteed to complete in a single invocation?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services