You are viewing a plain text version of this content; the canonical link was a hyperlink in the original HTML message and is not preserved in this plain-text rendering.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/04/12 15:20:58 UTC
[incubator-mxnet] branch v1.x updated: Fix for optimize_for
multiple subgraph properties issue (#19263) (#20142)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new dd3a6f8 Fix for optimize_for multiple subgraph properties issue (#19263) (#20142)
dd3a6f8 is described below
commit dd3a6f82eed46ec0964387575b8c493f57ba0d31
Author: bgawrych <ba...@intel.com>
AuthorDate: Mon Apr 12 17:18:17 2021 +0200
Fix for optimize_for multiple subgraph properties issue (#19263) (#20142)
* initial commit
* fixed whitespace
Co-authored-by: Ubuntu <ub...@ip-172-31-6-220.us-west-2.compute.internal>
Co-authored-by: Sam Skalicky <sa...@gmail.com>
Co-authored-by: Ubuntu <ub...@ip-172-31-6-220.us-west-2.compute.internal>
---
src/c_api/c_api_symbolic.cc | 211 +++++++++++++++++++++++---------------------
1 file changed, 110 insertions(+), 101 deletions(-)
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index ccc0547..877c8de 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -1372,122 +1372,127 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
*s = sym->Copy();
- nnvm::Graph g = Symbol2Graph(*s);
- const auto& indexed_graph = g.indexed_graph();
- const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
- std::vector<std::string> input_names = sym->ListInputNames(nnvm::Symbol::kAll);
- size_t num_forward_inputs = input_names.size();
+
+ // create a data structure from pointer array
+ std::unordered_map<std::string, std::string> options_map;
+ for (mx_uint i = 0; i < num_options; ++i)
+ options_map.emplace(keys[i], vals[i]);
NDArray ***new_args_ptr = reinterpret_cast<NDArray***>(new_args_handle);
NDArray ***new_aux_ptr = reinterpret_cast<NDArray***>(new_aux_handle);
+ NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
+ NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
- if (args_len || aux_len) {
- NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
- NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
- if (!skip_infer) {
- Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
- mxnet::ShapeVector arg_shapes(args_len + aux_len);
- nnvm::DTypeVector arg_dtypes(args_len + aux_len);
- StorageTypeVector arg_stypes(args_len + aux_len);
-
- // create the input shape, dtype and stype maps
- std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
- for (uint32_t i = 0; i < num_input_shapes; ++i) {
- input_shape_map.emplace(input_shape_names[i],
+ auto init_graph = [&](nnvm::Symbol* s) {
+ nnvm::Graph g = Symbol2Graph(*s);
+ const auto& indexed_graph = g.indexed_graph();
+ const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
+ std::vector<std::string> input_names = s->ListInputNames(nnvm::Symbol::kAll);
+ size_t num_forward_inputs = input_names.size();
+
+ if (args_len || aux_len) {
+ if (!skip_infer) {
+ Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
+ mxnet::ShapeVector arg_shapes(args_len + aux_len);
+ nnvm::DTypeVector arg_dtypes(args_len + aux_len);
+ StorageTypeVector arg_stypes(args_len + aux_len);
+
+ // create the input shape, dtype and stype maps
+ std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
+ for (uint32_t i = 0; i < num_input_shapes; ++i) {
+ input_shape_map.emplace(input_shape_names[i],
mxnet::TShape(input_shape_data + input_shape_idx[i],
input_shape_data + input_shape_idx[i+1]));
- }
- std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
- for (uint32_t i = 0; i < num_input_dtypes; ++i) {
- input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
- }
- std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
- for (uint32_t i = 0; i < num_input_stypes; ++i) {
- input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
- }
+ }
+ std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
+ for (uint32_t i = 0; i < num_input_dtypes; ++i) {
+ input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
+ }
+ std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
+ for (uint32_t i = 0; i < num_input_stypes; ++i) {
+ input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
+ }
- size_t args_top = 0, aux_top = 0;
- // loop over inputs to symbol in order and add to args/aux if mutable
- for (size_t i = 0; i < num_forward_inputs; ++i) {
- const uint32_t nid = indexed_graph.input_nodes().at(i);
- if (mutable_nodes.count(nid)) {
- CHECK_LT(aux_top, aux_len)
- << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
- if (in_aux_ptr[aux_top] != nullptr) {
- const auto &in_arg = *(in_aux_ptr[aux_top]);
- arg_shapes[i] = in_arg.shape();
- arg_dtypes[i] = in_arg.dtype();
- arg_stypes[i] = in_arg.storage_type();
- }
- aux_top++;
- } else {
- auto name = input_names[i];
- CHECK_LT(args_top, args_len)
- << "Cannot find arg '" << name << "' in provided args to optimize_for";
- if (in_args_ptr[args_top] != nullptr) {
- const auto &in_arg = *(in_args_ptr[args_top]);
- arg_shapes[i] = in_arg.shape();
- arg_dtypes[i] = in_arg.dtype();
- arg_stypes[i] = in_arg.storage_type();
- } else {
- // input_names[i] is not in args but can be in the optional
- // shape/type/stype attribute dicts.
- auto it_shape = input_shape_map.find(name);
- if (it_shape != input_shape_map.end()) {
- arg_shapes[i] = it_shape->second;
- }
- auto it_type = input_dtype_map.find(name);
- if (it_type != input_dtype_map.end()) {
- arg_dtypes[i] = it_type->second;
+ size_t args_top = 0, aux_top = 0;
+ // loop over inputs to symbol in order and add to args/aux if mutable
+ for (size_t i = 0; i < num_forward_inputs; ++i) {
+ const uint32_t nid = indexed_graph.input_nodes().at(i);
+ if (mutable_nodes.count(nid)) {
+ CHECK_LT(aux_top, aux_len)
+ << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
+ if (in_aux_ptr[aux_top] != nullptr) {
+ const auto &in_arg = *(in_aux_ptr[aux_top]);
+ arg_shapes[i] = in_arg.shape();
+ arg_dtypes[i] = in_arg.dtype();
+ arg_stypes[i] = in_arg.storage_type();
}
- it_type = input_stype_map.find(name);
- if (it_type != input_stype_map.end()) {
- arg_stypes[i] = it_type->second;
+ aux_top++;
+ } else {
+ auto name = input_names[i];
+ CHECK_LT(args_top, args_len)
+ << "Cannot find arg '" << name << "' in provided args to optimize_for";
+ if (in_args_ptr[args_top] != nullptr) {
+ const auto &in_arg = *(in_args_ptr[args_top]);
+ arg_shapes[i] = in_arg.shape();
+ arg_dtypes[i] = in_arg.dtype();
+ arg_stypes[i] = in_arg.storage_type();
+ } else {
+ // input_names[i] is not in args but can be in the optional
+ // shape/type/stype attribute dicts.
+ auto it_shape = input_shape_map.find(name);
+ if (it_shape != input_shape_map.end()) {
+ arg_shapes[i] = it_shape->second;
+ }
+ auto it_type = input_dtype_map.find(name);
+ if (it_type != input_dtype_map.end()) {
+ arg_dtypes[i] = it_type->second;
+ }
+ it_type = input_stype_map.find(name);
+ if (it_type != input_stype_map.end()) {
+ arg_stypes[i] = it_type->second;
+ }
}
+ args_top++;
}
- args_top++;
}
- }
- g.attrs["context"] = std::make_shared<nnvm::any>(
+ g.attrs["context"] = std::make_shared<nnvm::any>(
exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
- // infer shapes
- g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
- // infer dtypes
- g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
- // infer stypes
- g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+ // infer shapes
+ g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+ // infer dtypes
+ g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+ // infer stypes
+ g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+ }
+ // set args/aux as attributes on graph so that subgraph property can use them
+ std::vector<std::string> arg_names = s->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+ g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
+ g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
+
+ std::vector<std::string> aux_names = s->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+ g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
+ g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
+ } else {
+ // args/aux were not specified, so set nullptr/empty-lists
+ NDArray **in_args_ptr = static_cast<NDArray**>(nullptr);
+ std::vector<std::string> arg_names;
+ g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
+ g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
+
+ NDArray **in_aux_ptr = static_cast<NDArray**>(nullptr);
+ std::vector<std::string> aux_names;
+ g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
+ g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
}
- // set args/aux as attributes on graph so that subgraph property can use them
- std::vector<std::string> arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
- g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
- g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
-
- std::vector<std::string> aux_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
- g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
- g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
- } else {
- // args/aux were not specified, so set nullptr/empty-lists
- NDArray **in_args_ptr = static_cast<NDArray**>(nullptr);
- std::vector<std::string> arg_names;
- g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
- g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);
-
- NDArray **in_aux_ptr = static_cast<NDArray**>(nullptr);
- std::vector<std::string> aux_names;
- g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
- g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
- }
- // create a data structure from pointer array
- std::unordered_map<std::string, std::string> options_map;
- for (mx_uint i = 0; i < num_options; ++i)
- options_map.emplace(keys[i], vals[i]);
- // set dedup option as attribute on graph to enable dedup during partitioning
- if (options_map.count("dedup_subgraph") > 0 &&
- options_map.at("dedup_subgraph").compare("True") == 0)
- g.attrs["dedup_subgraph"] = std::make_shared<nnvm::any>(std::string("True"));
+ // set dedup option as attribute on graph to enable dedup during partitioning
+ if (options_map.count("dedup_subgraph") > 0 &&
+ options_map.at("dedup_subgraph").compare("True") == 0)
+ g.attrs["dedup_subgraph"] = std::make_shared<nnvm::any>(std::string("True"));
+ return g;
+ };
if (mxnet::op::SubgraphBackendRegistry::Get()->backend_map_.count(backend_name) > 0) {
// use subgraph backend
@@ -1495,14 +1500,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
::Get()->GetSubgraphBackend(backend_name);
const auto& subgraph_prop_list = backend->GetSubgraphProperties();
for (auto property : subgraph_prop_list) {
+ nnvm::Graph g = init_graph(s);
property->PrePartition(g, options_map);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
g = ApplyPass(std::move(g), "BuildSubgraph");
g.attrs.erase("subgraph_property");
property->PostPartition(g);
+ s->outputs = g.outputs;
}
} else if (dmlc::Registry<nnvm::PassFunctionReg>::Find(backend_name) != nullptr) {
// use graph pass
+ nnvm::Graph g = init_graph(s);
g.attrs["options_map"] = std::make_shared<nnvm::any>(options_map);
g.attrs["pass_name"] = std::make_shared<nnvm::any>(backend_name);
g = ApplyPass(std::move(g), backend_name);
@@ -1515,6 +1523,7 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
g.attrs.erase("new_aux");
g.attrs.erase("new_arg_names");
g.attrs.erase("new_aux_names");
+ s->outputs = g.outputs;
NDArray** new_arg_arr = new NDArray*[new_arg_names.size()];
NDArray** new_aux_arr = new NDArray*[new_aux_names.size()];
@@ -1546,7 +1555,7 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
// cannot find graph pass or subgraph backend registered in this name
LOG(ERROR) << "Error optimizing for backend '" << backend_name << "' cannot be found";
}
- s->outputs = g.outputs;
+
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}