Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/04/12 15:20:58 UTC

[incubator-mxnet] branch v1.x updated: Fix for optimize_for multiple subgraph properties issue (#19263) (#20142)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new dd3a6f8  Fix for optimize_for multiple subgraph properties issue (#19263) (#20142)
dd3a6f8 is described below

commit dd3a6f82eed46ec0964387575b8c493f57ba0d31
Author: bgawrych <ba...@intel.com>
AuthorDate: Mon Apr 12 17:18:17 2021 +0200

    Fix for optimize_for multiple subgraph properties issue (#19263) (#20142)
    
    * initial commit
    
    * fixed whitespace
    
    Co-authored-by: Ubuntu <ub...@ip-172-31-6-220.us-west-2.compute.internal>
    
    Co-authored-by: Sam Skalicky <sa...@gmail.com>
    Co-authored-by: Ubuntu <ub...@ip-172-31-6-220.us-west-2.compute.internal>
---
 src/c_api/c_api_symbolic.cc | 211 +++++++++++++++++++++++---------------------
 1 file changed, 110 insertions(+), 101 deletions(-)

diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index ccc0547..877c8de 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -1372,122 +1372,127 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
   API_BEGIN();
   nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
   *s = sym->Copy();
-  nnvm::Graph g = Symbol2Graph(*s);
-  const auto& indexed_graph = g.indexed_graph();
-  const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
-  std::vector<std::string> input_names = sym->ListInputNames(nnvm::Symbol::kAll);
-  size_t num_forward_inputs = input_names.size();
+
+  // create a data structure from pointer array
+  std::unordered_map<std::string, std::string> options_map;
+  for (mx_uint i = 0; i < num_options; ++i)
+    options_map.emplace(keys[i], vals[i]);
 
   NDArray ***new_args_ptr = reinterpret_cast<NDArray***>(new_args_handle);
   NDArray ***new_aux_ptr = reinterpret_cast<NDArray***>(new_aux_handle);
+  NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
+  NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
 
-  if (args_len || aux_len) {
-    NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
-    NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
-    if (!skip_infer) {
-      Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
-      mxnet::ShapeVector arg_shapes(args_len + aux_len);
-      nnvm::DTypeVector arg_dtypes(args_len + aux_len);
-      StorageTypeVector arg_stypes(args_len + aux_len);
-
-      // create the input shape, dtype and stype maps
-      std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
-      for (uint32_t i = 0; i < num_input_shapes; ++i) {
-        input_shape_map.emplace(input_shape_names[i],
+  auto init_graph = [&](nnvm::Symbol* s) {
+    nnvm::Graph g = Symbol2Graph(*s);
+    const auto& indexed_graph = g.indexed_graph();
+    const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
+    std::vector<std::string> input_names = s->ListInputNames(nnvm::Symbol::kAll);
+    size_t num_forward_inputs = input_names.size();
+
+    if (args_len || aux_len) {
+      if (!skip_infer) {
+        Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
+        mxnet::ShapeVector arg_shapes(args_len + aux_len);
+        nnvm::DTypeVector arg_dtypes(args_len + aux_len);
+        StorageTypeVector arg_stypes(args_len + aux_len);
+
+        // create the input shape, dtype and stype maps
+        std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
+        for (uint32_t i = 0; i < num_input_shapes; ++i) {
+          input_shape_map.emplace(input_shape_names[i],
                     mxnet::TShape(input_shape_data + input_shape_idx[i],
                     input_shape_data + input_shape_idx[i+1]));
-      }
-      std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
-      for (uint32_t i = 0; i < num_input_dtypes; ++i) {
-        input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
-      }
-      std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
-      for (uint32_t i = 0; i < num_input_stypes; ++i) {
-        input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
-      }
+        }
+        std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
+        for (uint32_t i = 0; i < num_input_dtypes; ++i) {
+          input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
+        }
+        std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
+        for (uint32_t i = 0; i < num_input_stypes; ++i) {
+          input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
+        }
 
-      size_t args_top = 0, aux_top = 0;
-      // loop over inputs to symbol in order and add to args/aux if mutable
-      for (size_t i = 0; i < num_forward_inputs; ++i) {
-        const uint32_t nid = indexed_graph.input_nodes().at(i);
-        if (mutable_nodes.count(nid)) {
-          CHECK_LT(aux_top, aux_len)
-            << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
-          if (in_aux_ptr[aux_top] != nullptr) {
-            const auto &in_arg = *(in_aux_ptr[aux_top]);
-            arg_shapes[i] = in_arg.shape();
-            arg_dtypes[i] = in_arg.dtype();
-            arg_stypes[i] = in_arg.storage_type();
-          }
-          aux_top++;
-        } else {
-          auto name = input_names[i];
-          CHECK_LT(args_top, args_len)
-            << "Cannot find arg '" << name << "' in provided args to optimize_for";
-          if (in_args_ptr[args_top] != nullptr) {
-            const auto &in_arg = *(in_args_ptr[args_top]);
-            arg_shapes[i] = in_arg.shape();
-            arg_dtypes[i] = in_arg.dtype();
-            arg_stypes[i] = in_arg.storage_type();
-          } else {
-            // input_names[i] is not in args but can be in the optional
-            // shape/type/stype attribute dicts.
-            auto it_shape = input_shape_map.find(name);
-            if (it_shape != input_shape_map.end()) {
-              arg_shapes[i] = it_shape->second;
-            }
-            auto it_type = input_dtype_map.find(name);
-            if (it_type != input_dtype_map.end()) {
-              arg_dtypes[i] = it_type->second;
+        size_t args_top = 0, aux_top = 0;
+        // loop over inputs to symbol in order and add to args/aux if mutable
+        for (size_t i = 0; i < num_forward_inputs; ++i) {
+          const uint32_t nid = indexed_graph.input_nodes().at(i);
+          if (mutable_nodes.count(nid)) {
+            CHECK_LT(aux_top, aux_len)
+              << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
+            if (in_aux_ptr[aux_top] != nullptr) {
+              const auto &in_arg = *(in_aux_ptr[aux_top]);
+              arg_shapes[i] = in_arg.shape();
+              arg_dtypes[i] = in_arg.dtype();
+              arg_stypes[i] = in_arg.storage_type();
             }
-            it_type = input_stype_map.find(name);
-            if (it_type != input_stype_map.end()) {
-              arg_stypes[i] = it_type->second;
+            aux_top++;
+          } else {
+            auto name = input_names[i];
+            CHECK_LT(args_top, args_len)
+              << "Cannot find arg '" << name << "' in provided args to optimize_for";
+            if (in_args_ptr[args_top] != nullptr) {
+              const auto &in_arg = *(in_args_ptr[args_top]);
+              arg_shapes[i] = in_arg.shape();
+              arg_dtypes[i] = in_arg.dtype();
+              arg_stypes[i] = in_arg.storage_type();
+            } else {
+              // input_names[i] is not in args but can be in the optional
+              // shape/type/stype attribute dicts.
+              auto it_shape = input_shape_map.find(name);
+              if (it_shape != input_shape_map.end()) {
+                arg_shapes[i] = it_shape->second;
+              }
+              auto it_type = input_dtype_map.find(name);
+              if (it_type != input_dtype_map.end()) {
+                arg_dtypes[i] = it_type->second;
+              }
+              it_type = input_stype_map.find(name);
+              if (it_type != input_stype_map.end()) {
+                arg_stypes[i] = it_type->second;
+              }
             }
+            args_top++;
           }
-          args_top++;
         }
-      }
 
-      g.attrs["context"] = std::make_shared<nnvm::any>(
+        g.attrs["context"] = std::make_shared<nnvm::any>(
           exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
 
-      // infer shapes
-      g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
-      // infer dtypes
-      g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
-      // infer stypes
-      g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+        // infer shapes
+        g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+        // infer dtypes
+        g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+        // infer stypes
+        g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+      }
+      // set args/aux as attributes on graph so that subgraph property can use them
+      std::vector<std::string> arg_names = s->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+      g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
+      g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
+
+      std::vector<std::string> aux_names = s->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+      g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
+      g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
+    } else {
+      // args/aux were not specified, so set nullptr/empty-lists
+      NDArray **in_args_ptr = static_cast<NDArray**>(nullptr);
+      std::vector<std::string> arg_names;
+      g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
+      g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
+
+      NDArray **in_aux_ptr = static_cast<NDArray**>(nullptr);
+      std::vector<std::string> aux_names;
+      g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
+      g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
     }
-    // set args/aux as attributes on graph so that subgraph property can use them
-    std::vector<std::string> arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
-    g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
-    g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
-
-    std::vector<std::string> aux_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
-    g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
-    g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
-  } else {
-    // args/aux were not specified, so set nullptr/empty-lists
-    NDArray **in_args_ptr = static_cast<NDArray**>(nullptr);
-    std::vector<std::string> arg_names;
-    g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
-    g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
-
-    NDArray **in_aux_ptr = static_cast<NDArray**>(nullptr);
-    std::vector<std::string> aux_names;
-    g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
-    g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
-  }
-  // create a data structure from pointer array
-  std::unordered_map<std::string, std::string> options_map;
-  for (mx_uint i = 0; i < num_options; ++i)
-     options_map.emplace(keys[i], vals[i]);
 
-  // set dedup option as attribute on graph to enable dedup during partitioning
-  if (options_map.count("dedup_subgraph") > 0 &&
-      options_map.at("dedup_subgraph").compare("True") == 0)
-    g.attrs["dedup_subgraph"] = std::make_shared<nnvm::any>(std::string("True"));
+    // set dedup option as attribute on graph to enable dedup during partitioning
+    if (options_map.count("dedup_subgraph") > 0 &&
+        options_map.at("dedup_subgraph").compare("True") == 0)
+      g.attrs["dedup_subgraph"] = std::make_shared<nnvm::any>(std::string("True"));
+    return g;
+  };
 
   if (mxnet::op::SubgraphBackendRegistry::Get()->backend_map_.count(backend_name) > 0) {
     // use subgraph backend
@@ -1495,14 +1500,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
                                       ::Get()->GetSubgraphBackend(backend_name);
     const auto& subgraph_prop_list = backend->GetSubgraphProperties();
     for (auto property : subgraph_prop_list) {
+      nnvm::Graph g = init_graph(s);
       property->PrePartition(g, options_map);
       g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
       g = ApplyPass(std::move(g), "BuildSubgraph");
       g.attrs.erase("subgraph_property");
       property->PostPartition(g);
+      s->outputs = g.outputs;
     }
   } else if (dmlc::Registry<nnvm::PassFunctionReg>::Find(backend_name) != nullptr) {
     // use graph pass
+    nnvm::Graph g = init_graph(s);
     g.attrs["options_map"] = std::make_shared<nnvm::any>(options_map);
     g.attrs["pass_name"] = std::make_shared<nnvm::any>(backend_name);
     g = ApplyPass(std::move(g), backend_name);
@@ -1515,6 +1523,7 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
     g.attrs.erase("new_aux");
     g.attrs.erase("new_arg_names");
     g.attrs.erase("new_aux_names");
+    s->outputs = g.outputs;
 
     NDArray** new_arg_arr = new NDArray*[new_arg_names.size()];
     NDArray** new_aux_arr = new NDArray*[new_aux_names.size()];
@@ -1546,7 +1555,7 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
     // cannot find graph pass or subgraph backend registered in this name
     LOG(ERROR) << "Error optimizing for backend '" << backend_name << "' cannot be found";
   }
-  s->outputs = g.outputs;
+
   *ret_sym_handle = s;
   API_END_HANDLE_ERROR(delete s);
 }
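
For reference, the core of this change is that the working graph is now rebuilt from the
symbol for every subgraph property (via the init_graph lambda) instead of being built once
up front, so each property partitions the result of the previous one rather than a stale
graph. Below is a minimal sketch of that pattern in isolation; the names Graph, Symbol,
Pass, BuildGraph and OptimizeFor are simplified stand-ins for illustration only, not the
actual MXNet/NNVM API.

    #include <functional>
    #include <string>
    #include <utility>
    #include <vector>

    // Simplified stand-ins for nnvm::Graph and nnvm::Symbol.
    struct Graph  { std::vector<std::string> outputs; };
    struct Symbol { std::vector<std::string> outputs; };

    // Stand-in for one subgraph property's BuildSubgraph-style pass.
    using Pass = std::function<Graph(Graph)>;

    // Plays the role of the init_graph lambda: derive a fresh graph
    // from the current state of the symbol.
    Graph BuildGraph(const Symbol& s) {
      return Graph{s.outputs};
    }

    void OptimizeFor(Symbol* s, const std::vector<Pass>& properties) {
      for (const Pass& apply : properties) {
        Graph g = BuildGraph(*s);   // fresh graph per property, not a stale copy
        g = apply(std::move(g));    // run this property's partitioning pass
        s->outputs = g.outputs;     // fold the result back into the symbol so
      }                             // the next property starts from it
    }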