You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sa...@apache.org on 2020/10/06 19:29:26 UTC

[incubator-mxnet] branch v1.8.x updated: [1.8.x] Backporting: Fixed setting attributes in reviewSubgraph (#19278)

This is an automated email from the ASF dual-hosted git repository.

samskalicky pushed a commit to branch v1.8.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.8.x by this push:
     new cc4b8ec  [1.8.x] Backporting: Fixed setting attributes in reviewSubgraph (#19278)
cc4b8ec is described below

commit cc4b8ec68b6ec9dae73046e9c34ac97439efda83
Author: Sam Skalicky <sa...@gmail.com>
AuthorDate: Tue Oct 6 12:24:40 2020 -0700

    [1.8.x] Backporting: Fixed setting attributes in reviewSubgraph (#19278)
    
    * initial commit
    
    * fixed mapping from top-level parameter names to subgraph input names
    
    * fixed sanity
    
    * support backslash-escaped double quotes when parsing JSON strings
    
    * changed string allocation from new to malloc to match free
    
    * add node to graph nodes array
    
    * fixed add nodes
    
    Co-authored-by: Ubuntu <ub...@ip-172-31-6-220.us-west-2.compute.internal>
---
 example/extensions/lib_subgraph/subgraph_lib.cc | 10 ++++++-
 include/mxnet/lib_api.h                         |  8 +++--
 src/lib_api.cc                                  | 40 +++++++++++++++----------
 3 files changed, 38 insertions(+), 20 deletions(-)

diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc b/example/extensions/lib_subgraph/subgraph_lib.cc
index f471093..98508c7 100644
--- a/example/extensions/lib_subgraph/subgraph_lib.cc
+++ b/example/extensions/lib_subgraph/subgraph_lib.cc
@@ -209,11 +209,16 @@ MXReturnValue mySupportedOps(const mxnet::ext::Graph* graph,
 }
 
 MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_id, bool* accept,
-                               const std::unordered_map<std::string, std::string>& options) {
+                               const std::unordered_map<std::string, std::string>& options,
+                               std::unordered_map<std::string, std::string>* attrs) {
   for (auto kv : options) {
     std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
   }
 
+  std::string sg = subgraph->toString();
+  std::cout << "subgraph " << subgraph_id << ": " << std::endl;
+  std::cout << sg << std::endl;
+
   // check if option `reject` was specified, and if so check if value is 'True'
   if(options.count("reject") > 0 && options.at("reject").compare("True") == 0) {
     // if specified, reject the subgraph. this is only used for testing
@@ -223,6 +228,9 @@ MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_i
     *accept = true;
     std::cout << "accepting subgraph" << std::endl;
   }
+
+  attrs->emplace("myKey","myVal");
+
   return MX_SUCCESS;
 }
 
diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h
index 0213557..db93dbe 100644
--- a/include/mxnet/lib_api.h
+++ b/include/mxnet/lib_api.h
@@ -594,10 +594,10 @@ class Graph {
   static Graph* fromJson(JsonVal val);
 
   /* \brief convert graph object back to JSON object */
-  JsonVal toJson();
+  JsonVal toJson() const;
 
   /* \brief convert graph object to JSON string */
-  std::string toString();
+  std::string toString() const;
 
   /* \brief visits a node "n" */
   void _dfs_util(Node* n, std::unordered_set<Node*>* to_visit,
@@ -819,7 +819,9 @@ typedef MXReturnValue (*createSelector_t)(const mxnet::ext::Graph *graph,
 typedef MXReturnValue (*reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id,
                                           bool* accept,
                                           const std::unordered_map<std::string,
-                                                                   std::string>& options);
+                                                                   std::string>& options,
+                                          std::unordered_map<std::string,
+                                                             std::string>* attrs);
 
 /*!
  * \brief An abstract class for subgraph property
diff --git a/src/lib_api.cc b/src/lib_api.cc
index 20ae280..c273678 100644
--- a/src/lib_api.cc
+++ b/src/lib_api.cc
@@ -348,7 +348,8 @@ mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json) {
 mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_string(const std::string& json, unsigned int* idx) {
   JsonVal ret(STR);
   while (*idx < json.size()) {
-    if (json[*idx] == '"') {
+    if (json[*idx] == '"' && (ret.str.size() == 0 ||
+                              (ret.str.size() > 0 && ret.str.back() != '\\'))) {
       ++(*idx);
       return ret;
     } else {
@@ -561,7 +562,7 @@ mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) {
 }
 
 /* \brief convert graph object back to JSON object */
-mxnet::ext::JsonVal mxnet::ext::Graph::toJson() {
+mxnet::ext::JsonVal mxnet::ext::Graph::toJson() const {
   // top level object is a map
   JsonVal val(MAP);
 
@@ -646,7 +647,7 @@ mxnet::ext::JsonVal mxnet::ext::Graph::toJson() {
 }
 
 /* \brief convert graph object to JSON string */
-std::string mxnet::ext::Graph::toString() {
+std::string mxnet::ext::Graph::toString() const {
   return toJson().dump();
 }
 
@@ -725,6 +726,7 @@ void mxnet::ext::Graph::print(int indent) const {
 /* \brief add a new node to this graph */
 mxnet::ext::Node* mxnet::ext::Graph::addNode(const std::string& name, const std::string& op) {
   Node* n = new Node();
+  nodes.push_back(n);
   n->name = name;
   n->op = op;
   if (res)
@@ -766,10 +768,14 @@ void mxnet::ext::Graph::_setParams(std::unordered_map<std::string, mxnet::ext::M
                                    std::unordered_map<std::string, mxnet::ext::MXTensor>* aux) {
   // set params for each input node
   for (Node* node : inputs) {
-    if (args->count(node->name) > 0)
-      node->tensor = &args->at(node->name);
-    else if (aux->count(node->name) > 0)
-      node->tensor = &aux->at(node->name);
+    std::string name = node->name;
+    if (node->attrs.count("isArg") > 0 && node->attrs["isArg"].compare("True") == 0)
+      // mapping name back to original node name from subgraph input name
+      name = node->attrs["argName"];
+    if (args->count(name) > 0)
+      node->tensor = &args->at(name);
+    else if (aux->count(name) > 0)
+      node->tensor = &aux->at(name);
   }
 }
 
@@ -1494,26 +1500,27 @@ MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph,
   }
 
   subgraph->_setParams(&args, &aux);
+
+  std::unordered_map<std::string, std::string> attrs;
   mxnet::ext::MXReturnValue retval = reviewSubgraph(subgraph, subgraph_id, &accept_bool,
-                                                    opts);
+                                                    opts, &attrs);
   if (!retval) return retval;
 
   *accept = accept_bool;
 
-  if (subgraph->attrs.size() > 0) {
-    *num_attrs = subgraph->attrs.size();
+  if (attrs.size() > 0) {
+    *num_attrs = attrs.size();
     // allocate space for attributes
     *attr_keys = static_cast<char**>(malloc (*num_attrs * sizeof(char*)));  // NOLINT
     *attr_vals = static_cast<char**>(malloc (*num_attrs * sizeof(char*)));  // NOLINT
 
     // copy attributes
     int i = 0;
-    for (auto kv : subgraph->attrs) {
+    for (auto kv : attrs) {
       (*attr_keys)[i] = static_cast<char*>(malloc ((kv.first.size()+1) * sizeof(char)));  // NOLINT
-      std::string val = kv.second.dump();  // convert JsonVal back to string
-      (*attr_vals)[i] = static_cast<char*>(malloc ((val.size()+1) * sizeof(char)));  // NOLINT
+      (*attr_vals)[i] = static_cast<char*>(malloc ((kv.second.size()+1) * sizeof(char)));  // NOLINT
       snprintf((*attr_keys)[i], kv.first.size()+1, "%s", kv.first.c_str());
-      snprintf((*attr_vals)[i], val.size()+1, "%s", val.c_str());
+      snprintf((*attr_vals)[i], kv.second.size()+1, "%s", kv.second.c_str());
       i++;
     }
   }
@@ -1587,8 +1594,9 @@ MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *jso
   mxnet::ext::MXReturnValue retval = graphPass(graph, opts);
   if (!retval) return retval;
 
-  std::string *tmp = new std::string(graph->toString());
-  *out_graph = const_cast<char*>(tmp->c_str());
+  std::string tmp = graph->toString();
+  *out_graph = static_cast<char*>(malloc ((tmp.size()+1) * sizeof(char)));  // NOLINT
+  snprintf((*out_graph), tmp.size()+1, "%s", tmp.c_str());
   return retval;
 }