You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/08/07 05:37:35 UTC

[GitHub] [incubator-mxnet] samskalicky commented on a change in pull request #18779: [WIP] Support extra inputs for subgraph ops

samskalicky commented on a change in pull request #18779:
URL: https://github.com/apache/incubator-mxnet/pull/18779#discussion_r462563043



##########
File path: example/extensions/lib_subgraph/subgraph_lib.cc
##########
@@ -331,6 +332,46 @@ REGISTER_PARTITIONER(mySelect)
 .setCreateSelector("strategy1", createSelector)
 .setReviewSubgraph("strategy1", myReviewSubgraph);
 
+/*
+ * \brief a basic pass that adds a new input for subgraph ops
+ *
+ * For every node whose op is '_custom_subgraph_op' this pass sets the
+ * MX_STR_EXTRA_INPUTS attribute to "1", creates a new "null"-op node named
+ * "<node name>_input", appends it to the graph's nodes and inputs, wires it in
+ * as the node's last input, and requests a matching arg tensor from `res`.
+ *
+ * \param in_graph  graph to transform, as a JSON string
+ * \param out_graph receives a newly allocated JSON string of the modified graph
+ * \param options   not used by this pass
+ * \param args      not used by this pass
+ * \param aux       not used by this pass
+ * \param res       pass resource used to allocate the tensor for each new input
+ * \return MX_SUCCESS
+ */
+MXReturnValue addInputPass(const std::string& in_graph, const std::string** out_graph,
+                           const std::unordered_map<std::string, std::string>& options,
+                           const std::unordered_map<std::string, MXTensor>& args,
+                           const std::unordered_map<std::string, MXTensor>& aux,
+                           const PassResource& res) {
+  // convert graph from JSON string to Graph/Node data structure
+  // NOTE(review): g is never released in this function -- confirm who owns the Graph
+  Graph* g = Graph::fromString(in_graph);
+
+  // Iterate over a snapshot of the node list: new input nodes are appended to
+  // g->nodes inside the loop, and appending to a container while range-iterating
+  // over it can invalidate the loop's iterators.
+  const auto original_nodes = g->nodes;
+  for (Node* n : original_nodes) {
+    // only nodes with the '_custom_subgraph_op' op type get an extra input
+    if (n->op.compare("_custom_subgraph_op") != 0) continue;
+
+    // set extra input
+    n->attrs[MX_STR_EXTRA_INPUTS] = std::to_string(1);
+
+    // create a new input Node
+    Node* input = new Node();
+    const std::string input_name = n->name + "_input";
+    input->name = input_name;
+    input->op = "null";
+    // add the new node to the graph
+    g->nodes.push_back(input);
+    g->inputs.push_back(input);
+    // connect new input to node; the entry index is the slot the input is
+    // about to occupy in n->inputs
+    input->outputs.push_back({n, static_cast<int>(n->inputs.size())});
+    // connect node to new input
+    n->inputs.push_back({input, 0});
+    // add a corresponding tensor for this input (returned pointer is not needed here)
+    res.alloc_arg(input_name, {1}, MXContext::CPU(0), kFloat32);
+  }
+
+  // convert back to JSON string from Graph/Node
+  *out_graph = new std::string(g->toString());
+  return MX_SUCCESS;
+}
+
+REGISTER_PASS(addInputPass)

Review comment:
       Graph pass usage through optimize_for was already merged in #17885. We're not changing how graph passes are called in this PR.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org