You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/09 21:04:26 UTC

[GitHub] zheng-da commented on a change in pull request #12104: [DO NOT REVIEW] Subgraph API

zheng-da commented on a change in pull request #12104: [DO NOT REVIEW] Subgraph API
URL: https://github.com/apache/incubator-mxnet/pull/12104#discussion_r209075473
 
 

 ##########
 File path: src/executor/graph_executor.cc
 ##########
 @@ -1699,6 +1701,146 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
     iter->c_str());
   return ret;
 }
+
+// Infer shapes, dtypes, stypes, contexts for the forward graph.
+//
+// Runs, in order: AssignContext, InferShape, InferType, InferStorageType on `g`,
+// and returns the resulting graph.
+//
+// g               - graph to annotate; taken by value, reassigned by each pass,
+//                   and returned.
+// arg_shapes      - shape hints, moved into InferShape.
+// arg_dtypes      - dtype hints, moved into InferType.
+// arg_stypes      - storage-type hints, moved into InferStorageType.
+//                   (presumably one entry per graph input for all three -- confirm
+//                   against the nnvm pass signatures)
+// default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes
+//                 - forwarded unchanged to AssignContext.
+//
+// After each inference pass, if its "*_num_unknown_nodes" graph attribute is
+// non-zero, the matching Handle*Error helper is called with the partially
+// inferred attribute vector. Presumably those helpers report the unresolved
+// inputs and raise -- confirm in their definitions.
+static nnvm::Graph InferForwardAttrs(nnvm::Graph g,
+                                     nnvm::ShapeVector arg_shapes,
+                                     nnvm::DTypeVector arg_dtypes,
+                                     StorageTypeVector arg_stypes,
+                                     const Context& default_ctx,
+                                     const std::map<std::string, Context>& ctx_map,
+                                     const std::vector<Context>& in_arg_ctxes,
+                                     const std::vector<Context>& aux_state_ctxes) {
+  // NOTE(review): `indexed_graph` is a reference into `g`, but `g` is reassigned
+  // by every pass below and this reference is still passed to the error
+  // handlers afterwards. This relies on the passes carrying the same
+  // IndexedGraph object through -- confirm, or re-fetch it from `g` before use.
+  const auto& indexed_graph = g.indexed_graph();
+  // Counted once, up front, as a plain value (not affected by the reassignments).
+  const auto num_forward_inputs = indexed_graph.input_nodes().size();
+  // The two `{}` arguments are passed empty here; presumably they are the
+  // gradient-side contexts / requirements that only apply to a backward graph
+  // -- confirm against AssignContext's parameter list.
+  g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
+                   aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
+  // The string argument ("__shape__" / "__dtype__" / "__storage_type__") is
+  // presumably the node attribute each pass reads as a per-node hint.
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs, indexed_graph,
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs, indexed_graph,
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+  return g;
+}
+
+// Given input attr arrays, partition the graph using the backend name equal to prop_name.
+// This is a common function for bind and simple_bind flows.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const nnvm::ShapeVector& arg_shapes,
+                                   const nnvm::DTypeVector& arg_dtypes,
+                                   const StorageTypeVector arg_stypes,
 
 Review comment:
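  use a const reference for `arg_stypes` (`const StorageTypeVector& arg_stypes`), like `arg_shapes` and `arg_dtypes`; taking it by `const` value only adds a copy.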

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services