You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/08/18 21:43:29 UTC

[GitHub] [incubator-mxnet] HahTK commented on a change in pull request #18894: [1.x] Backporting #18779 to v1.x

HahTK commented on a change in pull request #18894:
URL: https://github.com/apache/incubator-mxnet/pull/18894#discussion_r472436593



##########
File path: example/extensions/lib_subgraph/subgraph_lib.cc
##########
@@ -331,12 +290,41 @@ REGISTER_PARTITIONER(mySelect)
 .setCreateSelector("strategy1", createSelector)
 .setReviewSubgraph("strategy1", myReviewSubgraph);
 
+/* \brief a basic pass that adds a new input for subgraph ops */
+MXReturnValue addInputPass(mxnet::ext::Graph *graph,

Review comment:
       This appears to add an input to every custom subgraph op.
   Won't this cause problems if more than one type of custom subgraph op is defined?

##########
File path: example/extensions/lib_subgraph/subgraph_lib.cc
##########
@@ -176,70 +174,42 @@ REGISTER_OP(_custom_subgraph_op)
 
 const std::vector<std::string> op_names({"exp","log"});
 
-MXReturnValue mySupportedOps(const std::string& json,
+MXReturnValue mySupportedOps(const mxnet::ext::Graph* graph,
                              std::vector<int>* ids,
                              const std::unordered_map<std::string, std::string>& options) {
   for (auto kv : options) {
     std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
   }
-  //convert json string to json object
-  JsonParser parser;
-  JsonVal json_val = parser.parse_to_json(json);
-  //get nodes list
-  JsonVal nodes = json_val.map[JsonVal("nodes")];
 
   //loop over nodes
-  for(int i=0; i<nodes.list.size(); i++) {
-    JsonVal node = nodes.list[i];
-    JsonVal op = node.map[JsonVal("op")];
+  for(int i=0; i<graph->size(); i++) {
+    const mxnet::ext::Node *node = graph->getNode(i);
 
     //get shape/type if available
     std::string shape;
     int dtype = -1;
-    if(node.map.find(JsonVal("attrs")) != node.map.end()) {
-      JsonVal attrs = node.map[JsonVal("attrs")];
-      if(attrs.map.find(JsonVal("shape")) != attrs.map.end()) 
-        shape = attrs.map[JsonVal("shape")].str;
-      if(attrs.map.find(JsonVal("dtype")) != attrs.map.end())
-        dtype = std::stoi(attrs.map[JsonVal("dtype")].str);
-    }
+    if(node->attrs.count("shape") > 0)
+      shape = node->attrs.at("shape");
+    if(node->attrs.count("dtype") > 0)
+      dtype = std::stoi(node->attrs.at("dtype"));
 
     //check if op dtype is float, and if option was specified to require float types
     if((dtype == kFloat32 && options.count("reqFloat") > 0) || options.count("reqFloat") == 0) {
-      //check if op is in whitelist
-      if(std::find(op_names.begin(),op_names.end(),op.str.c_str()) != op_names.end()) {
-        // found op in whitelist, set value to -1 to include op in any subgraph
+      //check if op is in allowlist
+      if(std::find(op_names.begin(),op_names.end(),node->op.c_str()) != op_names.end()) {
+        // found op in allowlist, set value to -1 to include op in any subgraph
         ids->at(i) = -1;
       }
     }
   }
   return MX_SUCCESS;
 }
 
-MXReturnValue myReviewSubgraph(const std::string& json, int subgraph_id, bool* accept,
-                               const std::unordered_map<std::string, std::string>& options,
-                               std::unordered_map<std::string, std::string>* attrs,
-                               const std::unordered_map<std::string, MXTensor>& args,
-                               const std::unordered_map<std::string, MXTensor>& aux) {
+MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_id, bool* accept,
+                               const std::unordered_map<std::string, std::string>& options) {
   for (auto kv : options) {
     std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
   }
-  for (auto kv : args) {
-    std::cout << "arg: " << kv.first << " ==> (";
-    for (auto s : kv.second.shape)
-      std::cout << s << ",";
-    std::cout << ") [";
-    for (int i=0; i<kv.second.size(); i++)
-      std::cout << kv.second.data<float>()[i] << ", ";
-    std::cout << "]" << std::endl;
-  }
-
-  // check if option `reqArgs` was specified, and if so check if args were provided

Review comment:
       Is the concept of required args going away for some reason?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org