You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by re...@apache.org on 2018/08/14 01:51:33 UTC
[incubator-mxnet] branch subgraph updated: [DO NOT REVIEW] Subgraph
API (#12104)
This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch subgraph
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/subgraph by this push:
new 4c1933e [DO NOT REVIEW] Subgraph API (#12104)
4c1933e is described below
commit 4c1933e8ca34e0a522a326e2366b9db730d915f0
Author: reminisce <wu...@gmail.com>
AuthorDate: Mon Aug 13 18:51:25 2018 -0700
[DO NOT REVIEW] Subgraph API (#12104)
* Initial commit
* Add unit tests
* Fix lint
* Fix lint
* Clean up
* Add graph partitioning to Bind
* Add property name to graph partitioning c api
* Fix unit test gpu context
* Address cr
* Move subgraph to attrs.subgraphs and fix the example
* Fix lint
* Add var version unit test
* Address cr
* Enable unit test that was flaky
---
example/subgraph_op/imagenet_inference.py | 31 ++++-
include/mxnet/c_api.h | 5 -
include/mxnet/c_api_test.h | 66 ++++++++++
include/mxnet/engine.h | 4 +-
src/c_api/c_api_symbolic.cc | 25 ----
src/c_api/c_api_test.cc | 73 +++++++++++
src/engine/naive_engine.cc | 4 -
src/engine/threaded_engine.cc | 2 +-
src/engine/threaded_engine.h | 2 +-
src/executor/graph_executor.cc | 151 ++++++++++++++++++++++
src/executor/graph_executor.h | 4 +
src/operator/subgraph/common.h | 18 +--
src/operator/subgraph/default_subgraph_op.cc | 8 +-
src/operator/subgraph/default_subgraph_op.cu | 7 +-
src/operator/subgraph/default_subgraph_op.h | 127 ------------------
src/operator/subgraph/default_subgraph_property.h | 81 ++++++++++++
src/operator/subgraph/partition_graph.cc | 4 +-
src/operator/subgraph/subgraph_property.h | 132 +++++++++++++++++++
tests/cpp/engine/threaded_engine_test.cc | 58 +++++++++
tests/python/unittest/test_gluon_trainer.py | 1 -
tests/python/unittest/test_subgraph_op.py | 114 +++++++++++++++-
21 files changed, 720 insertions(+), 197 deletions(-)
diff --git a/example/subgraph_op/imagenet_inference.py b/example/subgraph_op/imagenet_inference.py
index 8a38cff..a0f16f6 100644
--- a/example/subgraph_op/imagenet_inference.py
+++ b/example/subgraph_op/imagenet_inference.py
@@ -87,7 +87,8 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples,
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Score a model on a dataset')
- parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
+ parser.add_argument('--model', type=str, required=True,
+ choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--label-name', type=str, default='softmax_label')
@@ -107,6 +108,8 @@ if __name__ == '__main__':
help='shuffling seed, see'
' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
' for more details')
+ parser.add_argument('--subgraph-backend', type=str, default='default', help='subgraph backend name.')
+ parser.add_argument('--ctx', type=str, default='cpu')
args = parser.parse_args()
@@ -133,6 +136,15 @@ if __name__ == '__main__':
download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
logger.info('Dataset for inference: %s' % dataset)
+ subgraph_backend = args.subgraph_backend
+
+ if args.ctx == 'cpu':
+ ctx = mx.cpu()
+ elif args.ctx == 'gpu':
+ ctx = mx.gpu(0)
+ else:
+ raise ValueError('unknown ctx option, only cpu and gpu are supported')
+
# creating data iterator
data = mx.io.ImageRecordIter(path_imgrec=dataset,
label_width=1,
@@ -151,16 +163,21 @@ if __name__ == '__main__':
prefix, epoch = download_model(model_name=args.model, logger=logger)
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
op_names = ['BatchNorm', 'Convolution', 'Pooling', 'Activation']
- out = SymbolHandle()
- check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)), c_str_array(op_names),
- ctypes.byref(out)))
- psym = Symbol(out)
-
+ if subgraph_backend is not None:
+ os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
+ if subgraph_backend == 'default':
+ check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
+ c_str_array(op_names)))
# make sure that fp32 inference works on the same images as calibrated quantized model
logger.info('Skipping the first %d batches' % args.num_skipped_batches)
data = advance_data_iter(data, args.num_skipped_batches)
num_inference_images = args.num_inference_batches * batch_size
logger.info('Running model %s for inference' % args.model)
- score(psym, arg_params, aux_params, data, [mx.gpu(0)], label_name,
+ score(sym, arg_params, aux_params, data, [ctx], label_name,
max_num_examples=num_inference_images, logger=logger)
+
+ if subgraph_backend is not None:
+ del os.environ['MXNET_SUBGRAPH_BACKEND']
+ if subgraph_backend == 'default':
+ check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 2987cd7..75147cf 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1479,11 +1479,6 @@ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
const float* high_quantiles,
SymbolHandle* ret_sym_handle);
-MXNET_DLL int MXPartitionGraph(SymbolHandle sym_handle,
- const mx_uint num_ops,
- const char** op_names,
- SymbolHandle* ret_sym_handle);
-
//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
diff --git a/include/mxnet/c_api_test.h b/include/mxnet/c_api_test.h
new file mode 100644
index 0000000..fe6fc7f
--- /dev/null
+++ b/include/mxnet/c_api_test.h
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file c_api_test.h
+ * \brief C API of mxnet for ease of testing backend in Python
+ */
+#ifndef MXNET_C_API_TEST_H_
+#define MXNET_C_API_TEST_H_
+
+/*! \brief Inhibit C++ name-mangling for MXNet functions. */
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+#include <mxnet/c_api.h>
+
+/*!
+ * \brief This API partitions a graph only by the operator names
+ * provided by users. This will attach a DefaultSubgraphProperty
+ * to the input graph for partitioning. This function should be
+ * used only for the testing purpose.
+ */
+MXNET_DLL int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
+ const char* prop_name,
+ const mx_uint num_ops,
+ const char** op_names,
+ SymbolHandle* ret_sym_handle);
+
+/*!
+ * \brief Given a subgraph property name, use the provided op names
+ * as the op_names attribute for that subgraph property, instead of
+ * the predefined one. This is only for the purpose of testing.
+ */
+MXNET_DLL int MXSetSubgraphPropertyOpNames(const char* prop_name,
+ const mx_uint num_ops,
+ const char** op_names);
+
+/*!
+ * \brief Given a subgraph property name, delete the op name set
+ * in the SubgraphPropertyOpNameSet.
+ */
+MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name);
+
+#ifdef __cplusplus
+}
+#endif // __cplusplus
+
+#endif // MXNET_C_API_TEST_H_
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 2c33b6c..11e64ed 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -43,7 +43,7 @@ class Engine;
namespace engine {
/*! \brief base class of engine variables.*/
struct Var {
- virtual uint32_t version() {
+ virtual size_t version() {
return version_;
}
virtual ~Var() = default;
@@ -58,7 +58,7 @@ struct Var {
* \brief version number of the var. Every time the object it is associated with
* is modified, the version number is incremented by 1.
*/
- uint32_t version_{0};
+ size_t version_{0};
}; // struct Var
/*! \brief Internal representation of operator. */
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 7ed86ec..c27a59a 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -31,7 +31,6 @@
#include "./c_api_common.h"
#include "../operator/operator_common.h"
#include "../executor/exec_pass.h"
-#include "../operator/subgraph/default_subgraph_op.h"
namespace mxnet {
namespace op {
@@ -697,27 +696,3 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
*ret_qsym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
-
-int MXPartitionGraph(SymbolHandle sym_handle,
- const mx_uint num_ops,
- const char** op_names,
- SymbolHandle* ret_sym_handle) {
- nnvm::Symbol* s = new nnvm::Symbol();
- API_BEGIN();
- std::unordered_set<std::string> op_name_set;
- for (size_t i = 0; i < num_ops; ++i) {
- op_name_set.emplace(op_names[i]);
- }
- nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
- *s = sym->Copy();
- nnvm::Graph g = Symbol2Graph(*s);
- if (!op_name_set.empty()) {
- mxnet::op::SubgraphPropertyPtr property
- = std::make_shared<mxnet::op::DefaultSubgraphProperty>(op_name_set);
- g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
- }
- g = ApplyPass(std::move(g), "PartitionGraph");
- s->outputs = g.outputs;
- *ret_sym_handle = s;
- API_END_HANDLE_ERROR(delete s);
-}
diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc
new file mode 100644
index 0000000..2f5ad76
--- /dev/null
+++ b/src/c_api/c_api_test.cc
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file c_api_test.cc
+ * \brief C API of mxnet for the ease of testing backend in Python
+ */
+#include <mxnet/c_api_test.h>
+#include <nnvm/pass.h>
+#include "./c_api_common.h"
+#include "../operator/subgraph/default_subgraph_property.h"
+
+int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
+ const char* prop_name,
+ const mx_uint num_ops,
+ const char** op_names,
+ SymbolHandle* ret_sym_handle) {
+ nnvm::Symbol* s = new nnvm::Symbol();
+ API_BEGIN();
+ std::unordered_set<std::string> op_name_set;
+ for (size_t i = 0; i < num_ops; ++i) {
+ op_name_set.emplace(op_names[i]);
+ }
+ nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
+ *s = sym->Copy();
+ nnvm::Graph g;
+ g.outputs = s->outputs;
+ if (!op_name_set.empty()) {
+ mxnet::op::SubgraphPropertyPtr property
+ = mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+ property->SetAttr("op_names", op_name_set);
+ g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
+ }
+ g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
+ s->outputs = g.outputs;
+ *ret_sym_handle = s;
+ API_END_HANDLE_ERROR(delete s);
+}
+
+int MXSetSubgraphPropertyOpNames(const char* prop_name,
+ const mx_uint num_ops,
+ const char** op_names) {
+ API_BEGIN();
+ std::unordered_set<std::string> op_name_set;
+ for (size_t i = 0; i < num_ops; ++i) {
+ op_name_set.emplace(op_names[i]);
+ }
+ (*mxnet::op::SubgraphPropertyOpNameSet::Get())[prop_name] = op_name_set;
+ API_END();
+}
+
+int MXRemoveSubgraphPropertyOpNames(const char* prop_name) {
+ API_BEGIN();
+ mxnet::op::SubgraphPropertyOpNameSet::Get()->erase(prop_name);
+ API_END();
+}
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index e0a47fa..8adac9e 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -86,10 +86,6 @@ class NaiveEngine final : public Engine {
// new variables
VarHandle NewVariable() override {
return NaiveVar::New();
-#if 0
- size_t v = ++counter_;
- return reinterpret_cast<VarHandle>(v);
-#endif
}
OprHandle NewOperator(AsyncFn fn,
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index bd11697..3a7587f 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -199,7 +199,7 @@ inline bool ThreadedVar::ready_to_read() {
return this->is_ready_to_read();
}
-inline uint32_t ThreadedVar::version() {
+inline size_t ThreadedVar::version() {
std::lock_guard<std::mutex> lock{mutex_};
return this->version_;
}
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 7730c06..a2c1a2b 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -162,7 +162,7 @@ class ThreadedVar final
inline void SetToDelete();
/*! \return whether this variable is ready to read. */
inline bool ready_to_read();
- inline uint32_t version() override;
+ inline size_t version() override;
/*!
* \brief Cast a Var pointer to ThreadedVar pointer
* \param ptr pointer from base.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 33c6f57..4fc36b9 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -33,6 +33,7 @@
#include "../profiler/profiler.h"
#include "../common/utils.h"
#include "../common/exec_utils.h"
+#include "../operator/subgraph/subgraph_property.h"
namespace mxnet {
namespace exec {
@@ -40,6 +41,7 @@ namespace exec {
GraphExecutor::GraphExecutor() {
log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
need_grad_ = false;
+ subgraph_property_ = dmlc::GetEnv("MXNET_SUBGRAPH_BACKEND", std::string());
}
GraphExecutor::~GraphExecutor() {
@@ -1699,6 +1701,146 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
iter->c_str());
return ret;
}
+
+// Infer shapes, dtypes, stypes, contexts for the forward graph
+static nnvm::Graph InferForwardAttrs(nnvm::Graph g,
+ nnvm::ShapeVector arg_shapes,
+ nnvm::DTypeVector arg_dtypes,
+ StorageTypeVector arg_stypes,
+ const Context& default_ctx,
+ const std::map<std::string, Context>& ctx_map,
+ const std::vector<Context>& in_arg_ctxes,
+ const std::vector<Context>& aux_state_ctxes) {
+ const auto& indexed_graph = g.indexed_graph();
+ const auto num_forward_inputs = indexed_graph.input_nodes().size();
+ g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
+ aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
+ g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+ if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+ HandleInferShapeError(num_forward_inputs, indexed_graph,
+ g.GetAttr<nnvm::ShapeVector>("shape"));
+ }
+ g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+ if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+ HandleInferTypeError(num_forward_inputs, indexed_graph,
+ g.GetAttr<nnvm::DTypeVector>("dtype"));
+ }
+ g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+ if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+ HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
+ g.GetAttr<StorageTypeVector>("storage_type"));
+ }
+ return g;
+}
+
+// Given input attr arrays, partition the graph using the backend name equal to prop_name.
+// This is a common function for bind and simple_bind flows.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+ const std::string& prop_name,
+ const nnvm::ShapeVector& arg_shapes,
+ const nnvm::DTypeVector& arg_dtypes,
+ const StorageTypeVector& arg_stypes,
+ const Context& default_ctx,
+ const std::map<std::string, Context>& ctx_map,
+ const std::vector<Context>& in_arg_ctxes,
+ const std::vector<Context>& aux_state_ctxes) {
+ auto subgraph_prop = op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+ nnvm::Symbol ret = src.Copy();
+ nnvm::Graph g;
+ g.outputs = ret.outputs;
+ g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
+ ctx_map, in_arg_ctxes, aux_state_ctxes);
+ subgraph_prop->SetAttr("graph", g);
+ auto it = op::SubgraphPropertyOpNameSet::Get()->find(prop_name);
+ // assign an op name set to the subgraph property if it has been provided by users
+ if (it != op::SubgraphPropertyOpNameSet::Get()->end()) {
+ LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << prop_name
+ << " has been assigned a value. Please make sure it is initialized"
+ " only for the testing purpose.";
+ subgraph_prop->SetAttr("op_names", it->second);
+ }
+ g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(subgraph_prop));
+ g = ApplyPass(std::move(g), "PartitionGraph");
+ ret.outputs = g.outputs;
+ return ret;
+}
+
+// Given input attr dicts, partition the graph using the backend name equal to prop_name.
+// This is for simple_bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+ const std::string& prop_name,
+ const std::unordered_map<std::string, TShape>& arg_shape_map,
+ const std::unordered_map<std::string, int>& arg_dtype_map,
+ const std::unordered_map<std::string, int>& arg_stype_map,
+ const Context& default_ctx,
+ const std::map<std::string, Context>& ctx_map,
+ const std::vector<Context>& in_arg_ctxes,
+ const std::vector<Context>& aux_state_ctxes) {
+ const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
+ nnvm::ShapeVector arg_shapes(input_names.size(), TShape());
+ nnvm::DTypeVector arg_dtypes(input_names.size(), -1);
+ StorageTypeVector arg_stypes(input_names.size(), kUndefinedStorage);
+ for (size_t i = 0; i < input_names.size(); ++i) {
+ auto it1 = arg_shape_map.find(input_names[i]);
+ if (arg_shape_map.end() != it1) {
+ arg_shapes[i] = it1->second;
+ }
+ auto it2 = arg_dtype_map.find(input_names[i]);
+ if (arg_dtype_map.end() != it2) {
+ arg_dtypes[i] = it2->second;
+ }
+ auto it3 = arg_stype_map.find(input_names[i]);
+ if (arg_stype_map.end() != it3) {
+ arg_stypes[i] = it3->second;
+ }
+ }
+ return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
+ default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+}
+
+// Given input ndarrays, partition the graph using the backend name equal to prop_name.
+// This is for bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+ const std::string& prop_name,
+ const std::vector<NDArray> &in_args,
+ const std::vector<NDArray> &aux_states,
+ const Context& default_ctx,
+ const std::map<std::string, Context>& ctx_map) {
+ const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
+ const std::vector<std::string> arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+ const std::vector<std::string> aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+ CHECK_EQ(arg_names.size(), in_args.size());
+ CHECK_EQ(aux_names.size(), aux_states.size());
+ nnvm::ShapeVector arg_shapes; // all input shapes
+ arg_shapes.reserve(input_names.size());
+ nnvm::DTypeVector arg_dtypes; // all input dtypes
+ arg_dtypes.reserve(input_names.size());
+ StorageTypeVector arg_stypes; // all input stypes
+ arg_stypes.reserve(input_names.size());
+ std::vector<Context> in_arg_ctxes(in_args.size());
+ std::vector<Context> aux_state_ctxes(aux_states.size());
+
+ size_t i1 = 0, i2 = 0;
+ for (size_t i = 0; i < input_names.size(); ++i) {
+ if (i2 < aux_names.size() && aux_names[i2] == input_names[i]) {
+ arg_shapes.push_back(aux_states[i2].shape());
+ arg_dtypes.push_back(aux_states[i2].dtype());
+ arg_stypes.push_back(aux_states[i2].storage_type());
+ aux_state_ctxes[i2] = aux_states[i2].ctx();
+ ++i2;
+ } else {
+ CHECK(i1 < arg_names.size());
+ CHECK_EQ(arg_names[i1], input_names[i]);
+ arg_shapes.push_back(in_args[i1].shape());
+ arg_dtypes.push_back(in_args[i1].dtype());
+ arg_stypes.push_back(in_args[i1].storage_type());
+ in_arg_ctxes[i1] = in_args[i1].ctx();
+ ++i1;
+ }
+ }
+ return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
+ default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+}
} // namespace exec
Executor *Executor::SimpleBind(nnvm::Symbol symbol,
@@ -1718,6 +1860,11 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
std::unordered_map<std::string, NDArray>* shared_buffer,
Executor* shared_exec) {
auto exec = new exec::GraphExecutor();
+ if (!exec->subgraph_property().empty()) {
+ symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map,
+ arg_stype_map, default_ctx, group2ctx, in_arg_ctxes,
+ aux_state_ctxes);
+ }
exec->Init(symbol, default_ctx, group2ctx,
in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
arg_shape_map, arg_dtype_map, arg_stype_map,
@@ -1736,6 +1883,10 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
const std::vector<NDArray> &aux_states,
Executor* shared_exec) {
auto exec = new exec::GraphExecutor();
+ if (!exec->subgraph_property().empty()) {
+ symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), in_args, aux_states,
+ default_ctx, group2ctx);
+ }
exec->Init(symbol, default_ctx, group2ctx,
in_args, arg_grad_store, grad_req_type, aux_states,
reinterpret_cast<Executor*>(shared_exec));
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index bfc415b..b4d36b1 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -117,6 +117,8 @@ class GraphExecutor : public Executor {
std::vector<NDArray>* arg_grads,
std::vector<NDArray>* aux_states) override;
+ const std::string& subgraph_property() const { return subgraph_property_; }
+
protected:
friend class mxnet::Imperative;
// Information about operational node
@@ -256,6 +258,8 @@ class GraphExecutor : public Executor {
std::unordered_set<std::string> cached_seg_opr_names_;
// verbose logging
bool log_verbose_ = false;
+ // subgraph property name
+ std::string subgraph_property_;
};
} // namespace exec
diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h
index 472312d..bf46104 100644
--- a/src/operator/subgraph/common.h
+++ b/src/operator/subgraph/common.h
@@ -57,22 +57,22 @@ struct SimpleNode {
} // namespace sg
inline uint32_t DefaultSubgraphOpNumInputs(const nnvm::NodeAttrs& attrs) {
- const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ const nnvm::Symbol& sym = *attrs.subgraphs[0];
return sym.ListInputNames(nnvm::Symbol::kAll).size();
}
inline uint32_t DefaultSubgraphOpNumOutputs(const nnvm::NodeAttrs& attrs) {
- const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ const nnvm::Symbol& sym = *attrs.subgraphs[0];
return sym.ListOutputNames().size();
}
inline std::vector<std::string> DefaultSubgraphOpListInputs(const nnvm::NodeAttrs& attrs) {
- const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ const nnvm::Symbol& sym = *attrs.subgraphs[0];
return sym.ListInputNames(nnvm::Symbol::kAll);
}
inline std::vector<std::string> DefaultSubgraphOpListOutputs(const nnvm::NodeAttrs& attrs) {
- const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ const nnvm::Symbol& sym = *attrs.subgraphs[0];
return sym.ListOutputNames();
}
@@ -80,7 +80,7 @@ inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shapes,
std::vector<TShape> *out_shapes) {
using namespace exec;
- const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
nnvm::Graph g;
g.outputs = subgraph_sym.outputs;
const auto& idx_g = g.indexed_graph();
@@ -124,7 +124,7 @@ inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_types,
std::vector<int> *out_types) {
- const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
nnvm::Graph g;
g.outputs = subgraph_sym.outputs;
const auto& idx_g = g.indexed_graph();
@@ -169,7 +169,7 @@ inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
DispatchMode* dispatch_mode,
std::vector<int>* in_stypes,
std::vector<int>* out_stypes) {
- const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
nnvm::Graph g;
g.outputs = subgraph_sym.outputs;
const auto& idx_g = g.indexed_graph();
@@ -222,7 +222,7 @@ inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) {
}
inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
- const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
const std::vector<std::string> input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
const std::vector<std::string> immutable_input_names =
subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
@@ -245,7 +245,7 @@ inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttr
}
inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
- const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+ const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
std::set<ResourceRequest::Type> resource_types;
DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) {
diff --git a/src/operator/subgraph/default_subgraph_op.cc b/src/operator/subgraph/default_subgraph_op.cc
index 8372ae9..491d6ee 100644
--- a/src/operator/subgraph/default_subgraph_op.cc
+++ b/src/operator/subgraph/default_subgraph_op.cc
@@ -18,7 +18,7 @@
*/
#include <mxnet/ndarray.h>
-#include "./default_subgraph_op.h"
+#include "./common.h"
#include "../../imperative/imperative_utils.h"
#include "../../imperative/cached_op.h"
@@ -30,7 +30,8 @@ namespace op {
class DefaultSubgraphOperator {
public:
explicit DefaultSubgraphOperator(const Symbol& sym) : subgraph_sym_(sym) {
- subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"}}));
+ subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"},
+ {"static_shape", "true"}}));
}
void Forward(const OpContext& ctx,
@@ -79,8 +80,7 @@ OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& attrs,
Context ctx,
const std::vector<TShape>& in_shapes,
const std::vector<int>& in_types) {
- const Symbol& subgraph_sym = nnvm::get<Symbol>(attrs.parsed);
- return OpStatePtr::Create<DefaultSubgraphOperator>(subgraph_sym);
+ return OpStatePtr::Create<DefaultSubgraphOperator>(*attrs.subgraphs[0]);
}
void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
diff --git a/src/operator/subgraph/default_subgraph_op.cu b/src/operator/subgraph/default_subgraph_op.cu
index 15a76e3..008826b 100644
--- a/src/operator/subgraph/default_subgraph_op.cu
+++ b/src/operator/subgraph/default_subgraph_op.cu
@@ -19,11 +19,14 @@
/*!
* Copyright (c) 2018 by Contributors
- * \file subgraph_op.cu
+ * \file default_subgraph_op.cu
* \brief GPU Implementation of subgraph operations
*/
-#include "./default_subgraph_op.h"
+#include <mxnet/ndarray.h>
+#include "./common.h"
+#include "../../imperative/imperative_utils.h"
+#include "../../imperative/cached_op.h"
namespace mxnet {
namespace op {
diff --git a/src/operator/subgraph/default_subgraph_op.h b/src/operator/subgraph/default_subgraph_op.h
deleted file mode 100644
index 7d6624e..0000000
--- a/src/operator/subgraph/default_subgraph_op.h
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
-#define MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
-
-#include <vector>
-#include <string>
-#include "./common.h"
-
-namespace mxnet {
-namespace op {
-
-/*
- * This provides criteria for selecting nodes in a subgraph.
- * When a node is passed to this object, the selection criteria may be changed.
- * We can also specify what links we should use when traversing the neighbor
- * nodes.
- */
-class SubgraphSelector {
- public:
- virtual ~SubgraphSelector() {
- }
- // Determine if the node should be selected for a subgraph.
- virtual bool Select(const nnvm::Node &n) = 0;
- // Determine if the input node should be selected for a subgraph.
- virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
- // Determine if the output node should be selected for a subgraph.
- virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
- // Post processes pre-selected subgraph nodes. Return a list of nodes that
- // users want to keep in subgraph(s).
- virtual std::vector<nnvm::Node*> Filter(nnvm::Graph* g,
- const std::vector<nnvm::Node*>& candidates) {
- return candidates;
- }
-};
-
-using SubgraphSelectorPtr = std::shared_ptr<SubgraphSelector>;
-
-/*!
- * \brief This provides a set of properties for partitioning a graph into subgraphs,
- * reconstructing a new graph from the subgraphs and creating a subgraph
- * operator to execute the subgraph.
- */
-class SubgraphProperty {
- public:
- // the criteria of selecting the subgraph nodes.
- virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0;
- // create an nnvm node for a given subgraph. Here users can customize how to
- // execute the operators in the subgraph.
- virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s,
- const int subgraph_id = 0) const = 0;
-};
-
-using SubgraphPropertyPtr = std::shared_ptr<SubgraphProperty>;
-
-void RegisterSubgraphProperty(SubgraphPropertyPtr property);
-
-/*
- * This selects nodes for a subgraph that only contains operators
- * in a given set and it visits nodes via both input and output links.
- */
-class ContainOpSelector: public SubgraphSelector {
- std::shared_ptr<const std::unordered_set<std::string>> op_names;
-
- public:
- explicit ContainOpSelector(std::shared_ptr<const std::unordered_set<std::string>> op_names) {
- this->op_names = op_names;
- }
-
- virtual bool Select(const nnvm::Node &n) {
- return !n.is_variable() && op_names->count(n.op()->name);
- }
-
- virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) {
- return !new_node.is_variable() && op_names->count(new_node.op()->name);
- }
-
- virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) {
- return !new_node.is_variable() && op_names->count(new_node.op()->name);
- }
-};
-
-/*
- * This subgraph property finds a subgraph whose nodes have only operators
- * within a set. The operators in the subgraph will be executed by _default_subgraph_op.
- */
-class DefaultSubgraphProperty: public SubgraphProperty {
- public:
- explicit DefaultSubgraphProperty(const std::unordered_set<std::string> &op_names) :
- op_names_(std::make_shared<std::unordered_set<std::string>>(op_names)) {}
- virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
- const int subgraph_id = 0) const {
- nnvm::NodePtr n = nnvm::Node::Create();
- n->attrs.op = Op::Get("_default_subgraph_op");
- n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id);
- n->attrs.parsed = sym;
- return n;
- }
- virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
- return std::make_shared<ContainOpSelector>(op_names_);
- }
-
- private:
- std::shared_ptr<const std::unordered_set<std::string>> op_names_;
-};
-
-} // namespace op
-} // namespace mxnet
-
-#endif // MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
diff --git a/src/operator/subgraph/default_subgraph_property.h b/src/operator/subgraph/default_subgraph_property.h
new file mode 100644
index 0000000..3882247
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_property.h
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_PROPERTY_H_
+
+#include <vector>
+#include <string>
+#include "./common.h"
+#include "./subgraph_property.h"
+
+namespace mxnet {
+namespace op {
+
+/*
+ * Selects a node only if it is a non-variable node whose operator name is in a given set.
+ * The same test is applied to the candidate node passed to SelectInput and SelectOutput.
+ */
+class ContainOpSelector: public SubgraphSelector {
+ public:
+ explicit ContainOpSelector(const std::unordered_set<std::string>& op_names)
+ : op_names_(op_names) {}  // binds a reference to op_names; the set is not copied
+
+ virtual bool Select(const nnvm::Node &n) {  // true iff n is an operator node named in the set
+ return !n.is_variable() && op_names_.count(n.op()->name);
+ }
+
+ virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) {  // n is ignored
+ return !new_node.is_variable() && op_names_.count(new_node.op()->name);
+ }
+
+ virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) {  // n is ignored
+ return !new_node.is_variable() && op_names_.count(new_node.op()->name);
+ }
+ private:
+ const std::unordered_set<std::string>& op_names_;  // not owned; must outlive this selector
+};
+
+/*
+ * This subgraph property finds subgraphs whose nodes contain only operators from a set (read
+ * from the "op_names" attribute). Each subgraph node is given the op _default_subgraph_op.
+ */
+class DefaultSubgraphProperty: public SubgraphProperty {
+ public:
+ static SubgraphPropertyPtr Create() { return std::make_shared<DefaultSubgraphProperty>(); }  // factory handed to the registry
+ virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
+ const int subgraph_id = 0) const {
+ nnvm::NodePtr n = nnvm::Node::Create();
+ n->attrs.op = Op::Get("_default_subgraph_op");
+ n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id);  // op name with subgraph_id appended
+ n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym));  // appends a copy of sym to the node's subgraphs
+ return n;
+ }
+ virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
+ return std::make_shared<ContainOpSelector>(
+ this->GetAttr<std::unordered_set<std::string>>("op_names"));  // CHECK-fails if "op_names" was never set; selector keeps the returned reference
+ }
+};
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(default, DefaultSubgraphProperty);  // registers Create under the name "default"
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_PROPERTY_H_
diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc
index 9672877..e8c3069 100644
--- a/src/operator/subgraph/partition_graph.cc
+++ b/src/operator/subgraph/partition_graph.cc
@@ -29,7 +29,7 @@
#include <stack>
#include <queue>
-#include "./default_subgraph_op.h"
+#include "./subgraph_property.h"
#include "./common.h"
namespace nnvm {
@@ -408,7 +408,7 @@ void FindSubgraphs(Graph* g,
&preselected_nodes);
// filter out unqualified pre-selected nodes
- std::vector<nnvm::Node*> filtered_nodes = subgraph_selector->Filter(g, preselected_nodes);
+ std::vector<nnvm::Node*> filtered_nodes = subgraph_selector->Filter(preselected_nodes);
// make sure filtered_nodes is a subset of preselected_nodes
for (const auto n : filtered_nodes) {
diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h
new file mode 100644
index 0000000..2153a36
--- /dev/null
+++ b/src/operator/subgraph/subgraph_property.h
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
+
+#include <nnvm/node.h>
+#include <dmlc/base.h>
+#include <dmlc/thread_local.h>
+#include <unordered_map>
+#include <vector>
+#include <string>
+
+namespace mxnet {
+namespace op {
+
+/*
+ * Interface providing the criteria for selecting the nodes of a subgraph.
+ * The methods are non-const: a selector may change its criteria as nodes are passed to it.
+ * Separate methods for input and output links let an implementation decide which links
+ * are followed when traversing neighbor nodes.
+ */
+class SubgraphSelector {
+ public:
+ virtual ~SubgraphSelector() {}
+ // Return true if node n should be selected for a subgraph.
+ virtual bool Select(const nnvm::Node &n) = 0;
+ // Return true if new_node (presumably an input of n; confirm in partition_graph.cc) should be selected.
+ virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
+ // Return true if new_node (presumably an output of n; confirm in partition_graph.cc) should be selected.
+ virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
+ // Post-processes the pre-selected subgraph nodes and returns the nodes that users want
+ // to keep in subgraph(s). The default implementation keeps every candidate.
+ virtual std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) {
+ return candidates;
+ }
+};
+
+using SubgraphSelectorPtr = std::shared_ptr<SubgraphSelector>;
+
+/*!
+ * \brief This provides a set of properties for partitioning a graph into subgraphs,
+ * reconstructing a new graph from the subgraphs and creating a subgraph
+ * operator to execute the subgraph. Also carries a name -> value attribute map.
+ */
+class SubgraphProperty {  // NOTE(review): polymorphic base class declares no virtual destructor
+ public:
+ // Create the selector that defines the criteria for selecting the subgraph nodes.
+ virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0;
+ // Create an nnvm node for a given subgraph. Here users can customize how to
+ // execute the operators in the subgraph. subgraph_id defaults to 0.
+ virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s,
+ const int subgraph_id = 0) const = 0;
+ // Store a copy of value under name in the attr map, replacing any existing entry; returns *this.
+ template<typename T>
+ SubgraphProperty& SetAttr(const std::string& name, const T& value) {
+ attrs_[name] = std::make_shared<dmlc::any>(value);
+ return *this;
+ }
+ // Return the attr stored under name as a const T&; CHECK-fails if name is not in the map.
+ template<typename T>
+ const T& GetAttr(const std::string& name) const {
+ auto it = attrs_.find(name);
+ CHECK(it != attrs_.end()) << "Cannot find attribute " << name << " in SubgraphProperty";
+ return nnvm::get<T>(*it->second);  // presumably requires T to match the stored type; confirm nnvm::get
+ }
+ protected:
+ std::unordered_map<std::string, std::shared_ptr<nnvm::any>> attrs_;  // attr name -> type-erased value
+};
+
+using SubgraphPropertyPtr = std::shared_ptr<SubgraphProperty>;
+
+class SubgraphPropertyRegistry {  // maps property names to factory functions that create SubgraphProperty objects
+ public:
+ typedef SubgraphPropertyPtr (*SubgraphPropertyCreateFn)(void);  // factory: takes nothing, returns a property
+ static SubgraphPropertyRegistry* Get() {  // returns the address of a function-local static instance
+ static SubgraphPropertyRegistry inst;
+ return &inst;
+ }
+
+ SubgraphPropertyPtr CreateSubgraphProperty(const std::string& name) {  // CHECK-fails if name is not registered
+ auto it = prop_fn_map_.find(name);
+ CHECK(it != prop_fn_map_.end()) << "SubgraphProperty " << name
+ << " is not found in SubgraphPropertyRegistry";
+ return it->second();  // invoke the registered factory
+ }
+
+ SubgraphPropertyCreateFn __REGISTER__(const std::string& name, SubgraphPropertyCreateFn fn) {  // called by MXNET_REGISTER_SUBGRAPH_PROPERTY
+ CHECK_EQ(prop_fn_map_.count(name), 0U) << "Subgraph property " << name
+ << " has been registered";  // duplicate names are rejected
+ prop_fn_map_[name] = fn;
+ return prop_fn_map_[name];  // i.e. fn
+ }
+
+ private:
+ SubgraphPropertyRegistry() = default;  // private: instances are obtained only through Get()
+ SubgraphPropertyRegistry(const SubgraphPropertyRegistry&) = delete;
+ SubgraphPropertyRegistry(SubgraphPropertyRegistry&&) = delete;
+ SubgraphPropertyRegistry& operator=(const SubgraphPropertyRegistry&) = delete;
+ std::unordered_map<std::string, SubgraphPropertyCreateFn> prop_fn_map_;  // property name -> factory
+};
+
+// This thread-local map holds, for each subgraph property, the names of the operators that
+// should be grouped into subgraphs. In practice, every backend accelerator should have a
+// predefined name set; this map is used only for testing purposes.
+// key: property name, value: op name set
+typedef dmlc::ThreadLocalStore<std::unordered_map<std::string, std::unordered_set<std::string>>>
+ SubgraphPropertyOpNameSet;
+
+#define MXNET_REGISTER_SUBGRAPH_PROPERTY(Name, SubgraphPropertyType) \
+ static DMLC_ATTRIBUTE_UNUSED auto __make_ ## SubgraphPropertyType ## _ ## Name ## __ = \
+ SubgraphPropertyRegistry::Get()->__REGISTER__(#Name, &SubgraphPropertyType::Create)
+
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc
index 92d0958..6d669c1 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -275,6 +275,64 @@ TEST(Engine, basics) {
LOG(INFO) << "All pass";
}
+TEST(Engine, VarVersion) {  // expects var->version() to be unchanged by reads and bumped once per write
+ const size_t num_engines = 3;
+ std::vector<mxnet::Engine*> engines(num_engines);  // NOTE(review): these engines are never deleted in this test
+ engines[0] = mxnet::engine::CreateNaiveEngine();
+ engines[1] = mxnet::engine::CreateThreadedEnginePooled();
+ engines[2] = mxnet::engine::CreateThreadedEnginePerDevice();
+ std::string type_names[3] = {"NaiveEngine", "ThreadedEnginePooled", "ThreadedEnginePerDevice"};  // log labels, same order as engines
+ for (size_t k = 0; k < num_engines; ++k) {  // run both phases on every engine type
+ auto engine = engines[k];
+ std::vector<mxnet::Engine::OprHandle> oprs;
+
+ LOG(INFO) << "Testing var as a read dependency in " << type_names[k];
+ auto var = engine->NewVariable();
+ EXPECT_EQ(var->version(), 0U);  // a new variable starts at version 0
+ for (int i = 0; i < 10; ++i) {
+ oprs.push_back(engine->NewOperator(
+ [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+ Foo(ctx, i);
+ cb();  // signal completion to the engine
+ },
+ {var}, {}));  // var only in the first dependency list: the read case
+ engine->Push(oprs.at(i), mxnet::Context{});
+ }
+ engine->WaitForAll();
+ EXPECT_EQ(var->version(), 0U);  // ten reads are expected to leave the version at 0
+ for (auto&& i : oprs) {
+ engine->DeleteOperator(i);
+ }
+ engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);  // no-op delete callback
+ engine->WaitForAll();
+
+ LOG(INFO) << "Testing var as a write dependency in " << type_names[k];
+ var = engine->NewVariable();  // fresh variable for the write phase
+ EXPECT_EQ(var->version(), 0U);
+ oprs.clear();
+ for (int i = 0; i < 10; ++i) {
+ oprs.push_back(engine->NewOperator(
+ [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+ Foo(ctx, i);
+ cb();
+ },
+ {}, {var}));  // var only in the second dependency list: the write case
+ engine->Push(oprs.at(i), mxnet::Context{});
+ }
+ engine->WaitForAll();
+ EXPECT_EQ(var->version(), 10U);  // one version increment expected per write operation
+ for (auto&& i : oprs) {
+ engine->DeleteOperator(i);
+ }
+ engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
+ engine->WaitForAll();
+
+ var = nullptr;
+ oprs.clear();
+ LOG(INFO) << "All pass";  // logged once per engine type
+ }
+}
+
#ifdef _OPENMP
struct TestSaveAndRestoreOMPState {
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 13e8e4e..2a34400 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -175,7 +175,6 @@ def test_trainer_save_load():
# check if parameter dict is correctly associated with optimizer after load_state
assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
-@unittest.skip("temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/11353")
@with_seed()
def test_trainer_reset_kv():
def check_trainer_reset_kv(kv):
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
index f6a33c2..40d609a 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -15,26 +15,29 @@
# specific language governing permissions and limitations
# under the License.
+import os
import ctypes
import mxnet as mx
-from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array
+from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array, c_str
from mxnet.symbol import Symbol
import numpy as np
from mxnet.test_utils import assert_almost_equal
def test_subgraph_exe():
- def check_subgraph_exe(sym, op_names):
+ def _check_subgraph_exe1(sym, op_names):
+ """Use the partitioned sym to simple_bind an executor and compare the outputs
+ with those of the original executor"""
out = SymbolHandle()
- check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)),
- c_str_array(op_names), ctypes.byref(out)))
+ check_call(_LIB.MXPartitionGraphByOpNames(sym.handle, c_str('default'), mx_uint(len(op_names)),
+ c_str_array(op_names), ctypes.byref(out)))
partitioned_sym = Symbol(out)
assert partitioned_sym.list_inputs() == sym.list_inputs()
assert partitioned_sym.list_arguments() == sym.list_arguments()
assert partitioned_sym.list_auxiliary_states() == sym.list_auxiliary_states()
- exe = sym.simple_bind(ctx=mx.cpu(), grad_req='null')
- partitioned_exe = partitioned_sym.simple_bind(ctx=mx.cpu(), grad_req='null')
+ exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+ partitioned_exe = partitioned_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
input_names = sym.list_inputs()
for name in input_names:
if name in exe.arg_dict:
@@ -46,12 +49,109 @@ def test_subgraph_exe():
partitioned_exe.aux_dict[name][:] = exe.aux_dict[name]
exe.forward()
partitioned_exe.forward()
- mx.nd.waitall()
assert len(exe.outputs) == len(partitioned_exe.outputs)
for i in range(len(exe.outputs)):
assert_almost_equal((exe.outputs[i] - partitioned_exe.outputs[i]).abs().sum().asnumpy(),
np.zeros(shape=(1,)))
+ def _check_subgraph_exe2(sym, op_names):
+ """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning in simple_bind
+ and compare results of the partitioned sym and the original sym."""
+ def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None):  # simple_bind sym, fill inputs, run forward once
+ if subgraph_backend is not None:
+ os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend  # NOTE(review): not cleaned up if simple_bind/forward raises (no try/finally)
+ check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
+ c_str_array(op_names)))
+ exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+ input_names = sym.list_inputs()
+ for name in input_names:  # random inputs for the first executor; copies of original_exec's inputs otherwise
+ if name in exe.arg_dict:
+ exe.arg_dict[name][:] = mx.nd.random.uniform(shape=exe.arg_dict[name].shape)\
+ if original_exec is None else original_exec.arg_dict[name]
+ else:
+ assert name in exe.aux_dict  # every input is either an argument or an auxiliary state
+ exe.aux_dict[name][:] = mx.nd.random.uniform(shape=exe.aux_dict[name].shape)\
+ if original_exec is None else original_exec.aux_dict[name]
+ exe.forward()
+ if subgraph_backend is not None:  # undo the registration and env var set above
+ check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
+ del os.environ['MXNET_SUBGRAPH_BACKEND']  # removes the variable; a pre-existing value is not restored
+ return exe
+
+ original_exec = get_executor(sym)  # no backend set: the reference run
+ partitioned_exec = get_executor(sym, 'default', op_names, original_exec)  # same inputs, 'default' backend
+ outputs1 = original_exec.outputs
+ outputs2 = partitioned_exec.outputs
+ assert len(outputs1) == len(outputs2)
+ for i in range(len(outputs1)):
+ assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))  # summed abs difference must be ~0
+
+ def _check_subgraph_exe3(sym, op_names):
+ """Use the partitioned sym to bind an executor and compare the outputs
+ with those of the original executor"""
+ out = SymbolHandle()  # receives the handle of the partitioned symbol
+ check_call(_LIB.MXPartitionGraphByOpNames(sym.handle, c_str('default'), mx_uint(len(op_names)),
+ c_str_array(op_names), ctypes.byref(out)))
+
+ partitioned_sym = Symbol(out)
+ input_names = sym.list_inputs()
+ arg_names = sym.list_arguments()
+ aux_names = sym.list_auxiliary_states()
+ assert partitioned_sym.list_inputs() == input_names  # partitioning must keep the same inputs, in the same order
+ assert partitioned_sym.list_arguments() == arg_names
+ assert partitioned_sym.list_auxiliary_states() == aux_names
+ arg_shapes, _, aux_shapes = sym.infer_shape()
+ arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
+ aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
+ exe = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
+ partitioned_exe = partitioned_sym.bind(ctx=mx.current_context(), args=arg_array,  # same arg/aux lists as exe
+ aux_states=aux_array, grad_req='null')
+ exe.forward()
+ partitioned_exe.forward()
+ assert len(exe.outputs) == len(partitioned_exe.outputs)
+ for i in range(len(exe.outputs)):
+ assert_almost_equal((exe.outputs[i] - partitioned_exe.outputs[i]).abs().sum().asnumpy(),  # summed abs difference must be ~0
+ np.zeros(shape=(1,)))
+
+ def _check_subgraph_exe4(sym, op_names):
+ """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning in bind
+ and compare results of the partitioned sym and the original sym."""
+ def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None):  # bind sym and run forward once
+ if subgraph_backend is not None:
+ os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend  # NOTE(review): not cleaned up if bind/forward raises (no try/finally)
+ check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
+ c_str_array(op_names)))
+ arg_shapes, _, aux_shapes = sym.infer_shape()
+ if subgraph_backend is None:  # reference run: generate fresh random inputs
+ arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
+ aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
+ else:
+ arg_array = None  # unused: the bind call below takes original_exec's arrays instead
+ aux_array = None
+ exe = sym.bind(ctx=mx.current_context(),
+ args=arg_array if subgraph_backend is None else original_exec.arg_arrays,
+ aux_states=aux_array if subgraph_backend is None else original_exec.aux_arrays,
+ grad_req='null')
+ exe.forward()
+ if subgraph_backend is not None:  # undo the registration and env var set above
+ check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
+ del os.environ['MXNET_SUBGRAPH_BACKEND']  # removes the variable; a pre-existing value is not restored
+ return exe
+
+ original_exec = get_executor(sym)  # no backend set: the reference run
+ partitioned_exec = get_executor(sym, 'default', op_names, original_exec)  # reuses the reference run's arrays
+ outputs1 = original_exec.outputs
+ outputs2 = partitioned_exec.outputs
+ assert len(outputs1) == len(outputs2)
+ for i in range(len(outputs1)):
+ assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))  # summed abs difference must be ~0
+
+ def check_subgraph_exe(sym, op_names):  # run all four partitioning checks on the same symbol
+ _check_subgraph_exe1(sym, op_names)  # explicitly partitioned symbol + simple_bind
+ _check_subgraph_exe2(sym, op_names)  # MXNET_SUBGRAPH_BACKEND env var + simple_bind
+ _check_subgraph_exe3(sym, op_names)  # explicitly partitioned symbol + bind
+ _check_subgraph_exe4(sym, op_names)  # MXNET_SUBGRAPH_BACKEND env var + bind
+
def test_network_structure_1():
data1 = mx.sym.var('data1', shape=(2, 3, 10, 10))
data2 = mx.sym.var('data2')