Posted to commits@mxnet.apache.org by re...@apache.org on 2018/08/14 01:51:33 UTC

[incubator-mxnet] branch subgraph updated: [DO NOT REVIEW] Subgraph API (#12104)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch subgraph
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/subgraph by this push:
     new 4c1933e  [DO NOT REVIEW] Subgraph API (#12104)
4c1933e is described below

commit 4c1933e8ca34e0a522a326e2366b9db730d915f0
Author: reminisce <wu...@gmail.com>
AuthorDate: Mon Aug 13 18:51:25 2018 -0700

    [DO NOT REVIEW] Subgraph API (#12104)
    
    * Initial commit
    
    * Add unit tests
    
    * Fix lint
    
    * Fix lint
    
    * Clean up
    
    * Add graph partitioning to Bind
    
    * Add property name to graph partitioning c api
    
    * Fix unit test gpu context
    
    * Address cr
    
    * Move subgraph to attrs.subgraphs and fix the example
    
    * Fix lint
    
    * Add var version unit test
    
    * Address cr
    
    * Enable unit test that was flaky
---
 example/subgraph_op/imagenet_inference.py         |  31 ++++-
 include/mxnet/c_api.h                             |   5 -
 include/mxnet/c_api_test.h                        |  66 ++++++++++
 include/mxnet/engine.h                            |   4 +-
 src/c_api/c_api_symbolic.cc                       |  25 ----
 src/c_api/c_api_test.cc                           |  73 +++++++++++
 src/engine/naive_engine.cc                        |   4 -
 src/engine/threaded_engine.cc                     |   2 +-
 src/engine/threaded_engine.h                      |   2 +-
 src/executor/graph_executor.cc                    | 151 ++++++++++++++++++++++
 src/executor/graph_executor.h                     |   4 +
 src/operator/subgraph/common.h                    |  18 +--
 src/operator/subgraph/default_subgraph_op.cc      |   8 +-
 src/operator/subgraph/default_subgraph_op.cu      |   7 +-
 src/operator/subgraph/default_subgraph_op.h       | 127 ------------------
 src/operator/subgraph/default_subgraph_property.h |  81 ++++++++++++
 src/operator/subgraph/partition_graph.cc          |   4 +-
 src/operator/subgraph/subgraph_property.h         | 132 +++++++++++++++++++
 tests/cpp/engine/threaded_engine_test.cc          |  58 +++++++++
 tests/python/unittest/test_gluon_trainer.py       |   1 -
 tests/python/unittest/test_subgraph_op.py         | 114 +++++++++++++++-
 21 files changed, 720 insertions(+), 197 deletions(-)
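
In short, this commit removes the MXPartitionGraph C API and instead has Executor::Bind/Executor::SimpleBind read the MXNET_SUBGRAPH_BACKEND environment variable and run the PartitionGraph pass before building the executor, while the new c_api_test.h entry points let tests override a backend's op-name set. A minimal Python sketch of the resulting flow, condensed from example/subgraph_op/imagenet_inference.py and tests/python/unittest/test_subgraph_op.py below (the symbol, shapes, and op names are purely illustrative):

    import os
    import mxnet as mx
    from mxnet.base import check_call, _LIB, mx_uint, c_str, c_str_array

    # Illustrative network; any symbol works.
    data = mx.sym.var('data')
    conv = mx.sym.Convolution(data, kernel=(3, 3), num_filter=8)
    sym = mx.sym.Activation(conv, act_type='relu')

    # Select the registered 'default' backend and, via the testing-only C API,
    # tell it which operators may be grouped into subgraphs.
    op_names = ['Convolution', 'Activation']
    os.environ['MXNET_SUBGRAPH_BACKEND'] = 'default'
    check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str('default'), mx_uint(len(op_names)),
                                                 c_str_array(op_names)))

    # simple_bind now partitions the graph before creating the executor.
    exe = sym.simple_bind(ctx=mx.cpu(), data=(1, 3, 32, 32), grad_req='null')
    exe.forward(data=mx.nd.random.uniform(shape=(1, 3, 32, 32)))

    # Clean up so later binds are unaffected.
    check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str('default')))
    del os.environ['MXNET_SUBGRAPH_BACKEND']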

diff --git a/example/subgraph_op/imagenet_inference.py b/example/subgraph_op/imagenet_inference.py
index 8a38cff..a0f16f6 100644
--- a/example/subgraph_op/imagenet_inference.py
+++ b/example/subgraph_op/imagenet_inference.py
@@ -87,7 +87,8 @@ def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples,
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Score a model on a dataset')
-    parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
+    parser.add_argument('--model', type=str, required=True,
+                        choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
                         help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
     parser.add_argument('--batch-size', type=int, default=32)
     parser.add_argument('--label-name', type=str, default='softmax_label')
@@ -107,6 +108,8 @@ if __name__ == '__main__':
                         help='shuffling seed, see'
                              ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
                              ' for more details')
+    parser.add_argument('--subgraph-backend', type=str, default='default', help='subgraph backend name.')
+    parser.add_argument('--ctx', type=str, default='cpu')
 
     args = parser.parse_args()
 
@@ -133,6 +136,15 @@ if __name__ == '__main__':
     download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
     logger.info('Dataset for inference: %s' % dataset)
 
+    subgraph_backend = args.subgraph_backend
+
+    if args.ctx == 'cpu':
+        ctx = mx.cpu()
+    elif args.ctx == 'gpu':
+        ctx = mx.gpu(0)
+    else:
+        raise ValueError('unknown ctx option, only cpu and gpu are supported')
+
     # creating data iterator
     data = mx.io.ImageRecordIter(path_imgrec=dataset,
                                  label_width=1,
@@ -151,16 +163,21 @@ if __name__ == '__main__':
     prefix, epoch = download_model(model_name=args.model, logger=logger)
     sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
     op_names = ['BatchNorm', 'Convolution', 'Pooling', 'Activation']
-    out = SymbolHandle()
-    check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)), c_str_array(op_names),
-                                     ctypes.byref(out)))
-    psym = Symbol(out)
-
+    if subgraph_backend is not None:
+        os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
+        if subgraph_backend == 'default':
+            check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
+                                                         c_str_array(op_names)))
     # make sure that fp32 inference works on the same images as calibrated quantized model
     logger.info('Skipping the first %d batches' % args.num_skipped_batches)
     data = advance_data_iter(data, args.num_skipped_batches)
 
     num_inference_images = args.num_inference_batches * batch_size
     logger.info('Running model %s for inference' % args.model)
-    score(psym, arg_params, aux_params, data, [mx.gpu(0)], label_name,
+    score(sym, arg_params, aux_params, data, [ctx], label_name,
           max_num_examples=num_inference_images, logger=logger)
+
+    if subgraph_backend is not None:
+        del os.environ['MXNET_SUBGRAPH_BACKEND']
+        if subgraph_backend == 'default':
+            check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 2987cd7..75147cf 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1479,11 +1479,6 @@ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
                                                const float* high_quantiles,
                                                SymbolHandle* ret_sym_handle);
 
-MXNET_DLL int MXPartitionGraph(SymbolHandle sym_handle,
-                               const mx_uint num_ops,
-                               const char** op_names,
-                               SymbolHandle* ret_sym_handle);
-
 //--------------------------------------------
 // Part 4: Executor interface
 //--------------------------------------------
diff --git a/include/mxnet/c_api_test.h b/include/mxnet/c_api_test.h
new file mode 100644
index 0000000..fe6fc7f
--- /dev/null
+++ b/include/mxnet/c_api_test.h
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file c_api_test.h
+ * \brief C API of mxnet for ease of testing backend in Python
+ */
+#ifndef MXNET_C_API_TEST_H_
+#define MXNET_C_API_TEST_H_
+
+/*! \brief Inhibit C++ name-mangling for MXNet functions. */
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+#include <mxnet/c_api.h>
+
+/*!
+ * \brief This API partitions a graph only by the operator names
+ * provided by users. This will attach a DefaultSubgraphProperty
+ * to the input graph for partitioning. This function should be
+ * used only for the testing purpose.
+ */
+MXNET_DLL int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
+                                        const char* prop_name,
+                                        const mx_uint num_ops,
+                                        const char** op_names,
+                                        SymbolHandle* ret_sym_handle);
+
+/*!
+ * \brief Given a subgraph property name, use the provided op names
+ * as the op_names attribute for that subgraph property, instead of
+ * the predefined one. This is only for the purpose of testing.
+ */
+MXNET_DLL int MXSetSubgraphPropertyOpNames(const char* prop_name,
+                                           const mx_uint num_ops,
+                                           const char** op_names);
+
+/*!
+ * \brief Given a subgraph property name, delete the op name set
+ * in the SubgraphPropertyOpNameSet.
+ */
+MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name);
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // MXNET_C_API_TEST_H_
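
The three entry points above are exercised from Python through ctypes in the tests further down; a hedged, self-contained sketch of how MXPartitionGraphByOpNames is called (the property name 'default', the symbol, and the op names are illustrative):

    import ctypes
    import mxnet as mx
    from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str, c_str_array
    from mxnet.symbol import Symbol

    # Illustrative single-op network.
    sym = mx.sym.Activation(mx.sym.var('data'), act_type='relu')
    op_names = ['Activation']

    # Partition sym with the 'default' property restricted to op_names and wrap
    # the returned handle back into a Python Symbol.
    out = SymbolHandle()
    check_call(_LIB.MXPartitionGraphByOpNames(sym.handle, c_str('default'),
                                              mx_uint(len(op_names)),
                                              c_str_array(op_names),
                                              ctypes.byref(out)))
    partitioned_sym = Symbol(out)
    assert partitioned_sym.list_inputs() == sym.list_inputs()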
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 2c33b6c..11e64ed 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -43,7 +43,7 @@ class Engine;
 namespace engine {
 /*! \brief base class of engine variables.*/
 struct Var {
-  virtual uint32_t version() {
+  virtual size_t version() {
     return version_;
   }
   virtual ~Var() = default;
@@ -58,7 +58,7 @@ struct Var {
    * \brief version number of the var. Every time the object it is associated with
    * is modified, the version number is incremented by 1.
    */
-  uint32_t version_{0};
+  size_t version_{0};
 };  // struct Var
 
 /*! \brief Internal representation of operator.  */
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 7ed86ec..c27a59a 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -31,7 +31,6 @@
 #include "./c_api_common.h"
 #include "../operator/operator_common.h"
 #include "../executor/exec_pass.h"
-#include "../operator/subgraph/default_subgraph_op.h"
 
 namespace mxnet {
 namespace op {
@@ -697,27 +696,3 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
   *ret_qsym_handle = s;
   API_END_HANDLE_ERROR(delete s);
 }
-
-int MXPartitionGraph(SymbolHandle sym_handle,
-                     const mx_uint num_ops,
-                     const char** op_names,
-                     SymbolHandle* ret_sym_handle) {
-  nnvm::Symbol* s = new nnvm::Symbol();
-  API_BEGIN();
-  std::unordered_set<std::string> op_name_set;
-  for (size_t i = 0; i < num_ops; ++i) {
-    op_name_set.emplace(op_names[i]);
-  }
-  nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
-  *s = sym->Copy();
-  nnvm::Graph g = Symbol2Graph(*s);
-  if (!op_name_set.empty()) {
-    mxnet::op::SubgraphPropertyPtr property
-        = std::make_shared<mxnet::op::DefaultSubgraphProperty>(op_name_set);
-    g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
-  }
-  g = ApplyPass(std::move(g), "PartitionGraph");
-  s->outputs = g.outputs;
-  *ret_sym_handle = s;
-  API_END_HANDLE_ERROR(delete s);
-}
diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc
new file mode 100644
index 0000000..2f5ad76
--- /dev/null
+++ b/src/c_api/c_api_test.cc
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file c_api_test.cc
+ * \brief C API of mxnet for the ease of testing backend in Python
+ */
+#include <mxnet/c_api_test.h>
+#include <nnvm/pass.h>
+#include "./c_api_common.h"
+#include "../operator/subgraph/default_subgraph_property.h"
+
+int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
+                              const char* prop_name,
+                              const mx_uint num_ops,
+                              const char** op_names,
+                              SymbolHandle* ret_sym_handle) {
+  nnvm::Symbol* s = new nnvm::Symbol();
+  API_BEGIN();
+  std::unordered_set<std::string> op_name_set;
+  for (size_t i = 0; i < num_ops; ++i) {
+    op_name_set.emplace(op_names[i]);
+  }
+  nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
+  *s = sym->Copy();
+  nnvm::Graph g;
+  g.outputs = s->outputs;
+  if (!op_name_set.empty()) {
+    mxnet::op::SubgraphPropertyPtr property
+        = mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+    property->SetAttr("op_names", op_name_set);
+    g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
+  }
+  g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
+  s->outputs = g.outputs;
+  *ret_sym_handle = s;
+  API_END_HANDLE_ERROR(delete s);
+}
+
+int MXSetSubgraphPropertyOpNames(const char* prop_name,
+                                 const mx_uint num_ops,
+                                 const char** op_names) {
+  API_BEGIN();
+  std::unordered_set<std::string> op_name_set;
+  for (size_t i = 0; i < num_ops; ++i) {
+    op_name_set.emplace(op_names[i]);
+  }
+  (*mxnet::op::SubgraphPropertyOpNameSet::Get())[prop_name] = op_name_set;
+  API_END();
+}
+
+int MXRemoveSubgraphPropertyOpNames(const char* prop_name) {
+  API_BEGIN();
+  mxnet::op::SubgraphPropertyOpNameSet::Get()->erase(prop_name);
+  API_END();
+}
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index e0a47fa..8adac9e 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -86,10 +86,6 @@ class NaiveEngine final : public Engine {
   // new variables
   VarHandle NewVariable() override {
     return NaiveVar::New();
-#if 0
-    size_t v = ++counter_;
-    return reinterpret_cast<VarHandle>(v);
-#endif
   }
 
   OprHandle NewOperator(AsyncFn fn,
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index bd11697..3a7587f 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -199,7 +199,7 @@ inline bool ThreadedVar::ready_to_read() {
   return this->is_ready_to_read();
 }
 
-inline uint32_t ThreadedVar::version() {
+inline size_t ThreadedVar::version() {
   std::lock_guard<std::mutex> lock{mutex_};
   return this->version_;
 }
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 7730c06..a2c1a2b 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -162,7 +162,7 @@ class ThreadedVar final
   inline void SetToDelete();
   /*! \return whether this variable is ready to read. */
   inline bool ready_to_read();
-  inline uint32_t version() override;
+  inline size_t version() override;
   /*!
    * \brief Cast a Var pointer to ThreadedVar pointer
    * \param ptr pointer from base.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 33c6f57..4fc36b9 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -33,6 +33,7 @@
 #include "../profiler/profiler.h"
 #include "../common/utils.h"
 #include "../common/exec_utils.h"
+#include "../operator/subgraph/subgraph_property.h"
 
 namespace mxnet {
 namespace exec {
@@ -40,6 +41,7 @@ namespace exec {
 GraphExecutor::GraphExecutor() {
   log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
   need_grad_ = false;
+  subgraph_property_ = dmlc::GetEnv("MXNET_SUBGRAPH_BACKEND", std::string());
 }
 
 GraphExecutor::~GraphExecutor() {
@@ -1699,6 +1701,146 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
     iter->c_str());
   return ret;
 }
+
+// Infer shapes, dtypes, stypes, contexts for the forward graph
+static nnvm::Graph InferForwardAttrs(nnvm::Graph g,
+                                     nnvm::ShapeVector arg_shapes,
+                                     nnvm::DTypeVector arg_dtypes,
+                                     StorageTypeVector arg_stypes,
+                                     const Context& default_ctx,
+                                     const std::map<std::string, Context>& ctx_map,
+                                     const std::vector<Context>& in_arg_ctxes,
+                                     const std::vector<Context>& aux_state_ctxes) {
+  const auto& indexed_graph = g.indexed_graph();
+  const auto num_forward_inputs = indexed_graph.input_nodes().size();
+  g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
+                   aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs, indexed_graph,
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs, indexed_graph,
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+  return g;
+}
+
+// Given input attr arrays, partition the graph using the backend name equal to prop_name.
+// This is a common function for bind and simple_bind flows.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const nnvm::ShapeVector& arg_shapes,
+                                   const nnvm::DTypeVector& arg_dtypes,
+                                   const StorageTypeVector& arg_stypes,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& ctx_map,
+                                   const std::vector<Context>& in_arg_ctxes,
+                                   const std::vector<Context>& aux_state_ctxes) {
+  auto subgraph_prop = op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+  nnvm::Symbol ret = src.Copy();
+  nnvm::Graph g;
+  g.outputs = ret.outputs;
+  g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
+                        ctx_map, in_arg_ctxes, aux_state_ctxes);
+  subgraph_prop->SetAttr("graph", g);
+  auto it = op::SubgraphPropertyOpNameSet::Get()->find(prop_name);
+  // assign a op name set to the subgraph property if it has been provided by users
+  if (it != op::SubgraphPropertyOpNameSet::Get()->end()) {
+    LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << prop_name
+              << " has been assigned a value. Please make sure it is initialized"
+                 " only for the testing purpose.";
+    subgraph_prop->SetAttr("op_names", it->second);
+  }
+  g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(subgraph_prop));
+  g = ApplyPass(std::move(g), "PartitionGraph");
+  ret.outputs = g.outputs;
+  return ret;
+}
+
+// Given input attr dicts, partition the graph using the backend name equal to prop_name.
+// This is for simple_bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const std::unordered_map<std::string, TShape>& arg_shape_map,
+                                   const std::unordered_map<std::string, int>& arg_dtype_map,
+                                   const std::unordered_map<std::string, int>& arg_stype_map,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& ctx_map,
+                                   const std::vector<Context>& in_arg_ctxes,
+                                   const std::vector<Context>& aux_state_ctxes) {
+  const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
+  nnvm::ShapeVector arg_shapes(input_names.size(), TShape());
+  nnvm::DTypeVector arg_dtypes(input_names.size(), -1);
+  StorageTypeVector arg_stypes(input_names.size(), kUndefinedStorage);
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    auto it1 = arg_shape_map.find(input_names[i]);
+    if (arg_shape_map.end() != it1) {
+      arg_shapes[i] = it1->second;
+    }
+    auto it2 = arg_dtype_map.find(input_names[i]);
+    if (arg_dtype_map.end() != it2) {
+      arg_dtypes[i] = it2->second;
+    }
+    auto it3 = arg_stype_map.find(input_names[i]);
+    if (arg_stype_map.end() != it3) {
+      arg_stypes[i] = it3->second;
+    }
+  }
+  return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
+                        default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+}
+
+// Given input ndarrays, partition the graph using the backend name equal to prop_name.
+// This is for bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const std::vector<NDArray> &in_args,
+                                   const std::vector<NDArray> &aux_states,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& ctx_map) {
+  const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
+  const std::vector<std::string> arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+  const std::vector<std::string> aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+  CHECK_EQ(arg_names.size(), in_args.size());
+  CHECK_EQ(aux_names.size(), aux_states.size());
+  nnvm::ShapeVector arg_shapes;  // all input shapes
+  arg_shapes.reserve(input_names.size());
+  nnvm::DTypeVector arg_dtypes;  // all input dtypes
+  arg_dtypes.reserve(input_names.size());
+  StorageTypeVector arg_stypes;  // all input stypes
+  arg_stypes.reserve(input_names.size());
+  std::vector<Context> in_arg_ctxes(in_args.size());
+  std::vector<Context> aux_state_ctxes(aux_states.size());
+
+  size_t i1 = 0, i2 = 0;
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    if (i2 < aux_names.size() && aux_names[i2] == input_names[i]) {
+      arg_shapes.push_back(aux_states[i2].shape());
+      arg_dtypes.push_back(aux_states[i2].dtype());
+      arg_stypes.push_back(aux_states[i2].storage_type());
+      aux_state_ctxes[i2] = aux_states[i2].ctx();
+      ++i2;
+    } else {
+      CHECK(i1 < arg_names.size());
+      CHECK_EQ(arg_names[i1], input_names[i]);
+      arg_shapes.push_back(in_args[i1].shape());
+      arg_dtypes.push_back(in_args[i1].dtype());
+      arg_stypes.push_back(in_args[i1].storage_type());
+      in_arg_ctxes[i1] = in_args[i1].ctx();
+      ++i1;
+    }
+  }
+  return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
+                        default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+}
 }  // namespace exec
 
 Executor *Executor::SimpleBind(nnvm::Symbol symbol,
@@ -1718,6 +1860,11 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
                                std::unordered_map<std::string, NDArray>* shared_buffer,
                                Executor* shared_exec) {
   auto exec = new exec::GraphExecutor();
+  if (!exec->subgraph_property().empty()) {
+    symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map,
+                                  arg_stype_map, default_ctx, group2ctx, in_arg_ctxes,
+                                  aux_state_ctxes);
+  }
   exec->Init(symbol, default_ctx, group2ctx,
              in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
              arg_shape_map, arg_dtype_map, arg_stype_map,
@@ -1736,6 +1883,10 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
                          const std::vector<NDArray> &aux_states,
                          Executor* shared_exec) {
   auto exec = new exec::GraphExecutor();
+  if (!exec->subgraph_property().empty()) {
+    symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), in_args, aux_states,
+                                  default_ctx, group2ctx);
+  }
   exec->Init(symbol, default_ctx, group2ctx,
              in_args, arg_grad_store, grad_req_type, aux_states,
              reinterpret_cast<Executor*>(shared_exec));
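
Both executor entry points are hooked the same way: when subgraph_property() is non-empty (i.e. MXNET_SUBGRAPH_BACKEND is set), the symbol is partitioned before Init. For the Bind flow, the shapes, dtypes, and storage types needed by InferForwardAttrs come from the supplied NDArrays. A hedged Python sketch mirroring _check_subgraph_exe4 in the tests below (the symbol, shapes, and op names are illustrative):

    import os
    import mxnet as mx
    from mxnet.base import check_call, _LIB, mx_uint, c_str, c_str_array

    # Illustrative network with a single Pooling op.
    sym = mx.sym.Pooling(mx.sym.var('data'), kernel=(2, 2), pool_type='max')
    op_names = ['Pooling']

    os.environ['MXNET_SUBGRAPH_BACKEND'] = 'default'
    check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str('default'), mx_uint(len(op_names)),
                                                 c_str_array(op_names)))

    # The in_args/aux_states passed to bind provide the shapes, dtypes and
    # contexts that PartitionGraph uses for attribute inference.
    arg_shapes, _, aux_shapes = sym.infer_shape(data=(1, 3, 8, 8))
    args = [mx.nd.random.uniform(shape=s) for s in arg_shapes]
    auxs = [mx.nd.random.uniform(shape=s) for s in aux_shapes]
    exe = sym.bind(ctx=mx.cpu(), args=args, aux_states=auxs, grad_req='null')
    exe.forward()

    check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str('default')))
    del os.environ['MXNET_SUBGRAPH_BACKEND']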
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index bfc415b..b4d36b1 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -117,6 +117,8 @@ class GraphExecutor : public Executor {
                     std::vector<NDArray>* arg_grads,
                     std::vector<NDArray>* aux_states) override;
 
+  const std::string& subgraph_property() const { return subgraph_property_; }
+
  protected:
   friend class mxnet::Imperative;
   // Information about operational node
@@ -256,6 +258,8 @@ class GraphExecutor : public Executor {
   std::unordered_set<std::string> cached_seg_opr_names_;
   // verbose logging
   bool log_verbose_ = false;
+  // subgraph property name
+  std::string subgraph_property_;
 };
 
 }  // namespace exec
diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h
index 472312d..bf46104 100644
--- a/src/operator/subgraph/common.h
+++ b/src/operator/subgraph/common.h
@@ -57,22 +57,22 @@ struct SimpleNode {
 }  // namespace sg
 
 inline uint32_t DefaultSubgraphOpNumInputs(const nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& sym = *attrs.subgraphs[0];
   return sym.ListInputNames(nnvm::Symbol::kAll).size();
 }
 
 inline uint32_t DefaultSubgraphOpNumOutputs(const nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& sym = *attrs.subgraphs[0];
   return sym.ListOutputNames().size();
 }
 
 inline std::vector<std::string> DefaultSubgraphOpListInputs(const nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& sym = *attrs.subgraphs[0];
   return sym.ListInputNames(nnvm::Symbol::kAll);
 }
 
 inline std::vector<std::string> DefaultSubgraphOpListOutputs(const nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& sym = *attrs.subgraphs[0];
   return sym.ListOutputNames();
 }
 
@@ -80,7 +80,7 @@ inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
                                    std::vector<TShape> *in_shapes,
                                    std::vector<TShape> *out_shapes) {
   using namespace exec;
-  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -124,7 +124,7 @@ inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
 inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
                                   std::vector<int> *in_types,
                                   std::vector<int> *out_types) {
-  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -169,7 +169,7 @@ inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
                                          DispatchMode* dispatch_mode,
                                          std::vector<int>* in_stypes,
                                          std::vector<int>* out_stypes) {
-  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -222,7 +222,7 @@ inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) {
 }
 
 inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
   const std::vector<std::string> input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
   const std::vector<std::string> immutable_input_names =
     subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
@@ -245,7 +245,7 @@ inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttr
 }
 
 inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
   static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
   std::set<ResourceRequest::Type> resource_types;
   DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) {
diff --git a/src/operator/subgraph/default_subgraph_op.cc b/src/operator/subgraph/default_subgraph_op.cc
index 8372ae9..491d6ee 100644
--- a/src/operator/subgraph/default_subgraph_op.cc
+++ b/src/operator/subgraph/default_subgraph_op.cc
@@ -18,7 +18,7 @@
 */
 
 #include <mxnet/ndarray.h>
-#include "./default_subgraph_op.h"
+#include "./common.h"
 #include "../../imperative/imperative_utils.h"
 #include "../../imperative/cached_op.h"
 
@@ -30,7 +30,8 @@ namespace op {
 class DefaultSubgraphOperator {
  public:
   explicit DefaultSubgraphOperator(const Symbol& sym) : subgraph_sym_(sym) {
-    subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"}}));
+    subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"},
+                                            {"static_shape", "true"}}));
   }
 
   void Forward(const OpContext& ctx,
@@ -79,8 +80,7 @@ OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& attrs,
                                         Context ctx,
                                         const std::vector<TShape>& in_shapes,
                                         const std::vector<int>& in_types) {
-  const Symbol& subgraph_sym = nnvm::get<Symbol>(attrs.parsed);
-  return OpStatePtr::Create<DefaultSubgraphOperator>(subgraph_sym);
+  return OpStatePtr::Create<DefaultSubgraphOperator>(*attrs.subgraphs[0]);
 }
 
 void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
diff --git a/src/operator/subgraph/default_subgraph_op.cu b/src/operator/subgraph/default_subgraph_op.cu
index 15a76e3..008826b 100644
--- a/src/operator/subgraph/default_subgraph_op.cu
+++ b/src/operator/subgraph/default_subgraph_op.cu
@@ -19,11 +19,14 @@
 
 /*!
  *  Copyright (c) 2018 by Contributors
- * \file subgraph_op.cu
+ * \file default_subgraph_op.cu
  * \brief GPU Implementation of subgraph operations
  */
 
-#include "./default_subgraph_op.h"
+#include <mxnet/ndarray.h>
+#include "./common.h"
+#include "../../imperative/imperative_utils.h"
+#include "../../imperative/cached_op.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/subgraph/default_subgraph_op.h b/src/operator/subgraph/default_subgraph_op.h
deleted file mode 100644
index 7d6624e..0000000
--- a/src/operator/subgraph/default_subgraph_op.h
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
-#define MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
-
-#include <vector>
-#include <string>
-#include "./common.h"
-
-namespace mxnet {
-namespace op {
-
-/*
- * This provides criteria for selecting nodes in a subgraph.
- * When a node is passed to this object, the selection criteria may be changed.
- * We can also specify what links we should use when traversing the neighbor
- * nodes.
- */
-class SubgraphSelector {
- public:
-  virtual ~SubgraphSelector() {
-  }
-  // Determine if the node should be selected for a subgraph.
-  virtual bool Select(const nnvm::Node &n) = 0;
-  // Determine if the input node should be selected for a subgraph.
-  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
-  // Determine if the output node should be selected for a subgraph.
-  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
-  // Post processes pre-selected subgraph nodes. Return a list of nodes that
-  // users want to keep in subgraph(s).
-  virtual std::vector<nnvm::Node*> Filter(nnvm::Graph* g,
-                                          const std::vector<nnvm::Node*>& candidates) {
-    return candidates;
-  }
-};
-
-using SubgraphSelectorPtr = std::shared_ptr<SubgraphSelector>;
-
-/*!
- * \brief This provides a set of properties for partitioning a graph into subgraphs,
- * reconstructing a new graph from the subgraphs and creating a subgraph
- * operator to execute the subgraph.
- */
-class SubgraphProperty {
- public:
-  // the criteria of selecting the subgraph nodes.
-  virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0;
-  // create an nnvm node for a given subgraph. Here users can customize how to
-  // execute the operators in the subgraph.
-  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s,
-                                           const int subgraph_id = 0) const = 0;
-};
-
-using SubgraphPropertyPtr = std::shared_ptr<SubgraphProperty>;
-
-void RegisterSubgraphProperty(SubgraphPropertyPtr property);
-
-/*
- * This selects nodes for a subgraph that only contains operators
- * in a given set and it visits nodes via both input and output links.
- */
-class ContainOpSelector: public SubgraphSelector {
-  std::shared_ptr<const std::unordered_set<std::string>> op_names;
-
- public:
-  explicit ContainOpSelector(std::shared_ptr<const std::unordered_set<std::string>> op_names) {
-    this->op_names = op_names;
-  }
-
-  virtual bool Select(const nnvm::Node &n) {
-    return !n.is_variable() && op_names->count(n.op()->name);
-  }
-
-  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) {
-    return !new_node.is_variable() && op_names->count(new_node.op()->name);
-  }
-
-  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) {
-    return !new_node.is_variable() && op_names->count(new_node.op()->name);
-  }
-};
-
-/*
- * This subgraph property finds a subgraph whose nodes have only operators
- * within a set. The operators in the subgraph will be executed by _default_subgraph_op.
- */
-class DefaultSubgraphProperty: public SubgraphProperty {
- public:
-  explicit DefaultSubgraphProperty(const std::unordered_set<std::string> &op_names) :
-    op_names_(std::make_shared<std::unordered_set<std::string>>(op_names)) {}
-  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
-                                           const int subgraph_id = 0) const {
-    nnvm::NodePtr n = nnvm::Node::Create();
-    n->attrs.op = Op::Get("_default_subgraph_op");
-    n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id);
-    n->attrs.parsed = sym;
-    return n;
-  }
-  virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
-    return std::make_shared<ContainOpSelector>(op_names_);
-  }
-
- private:
-  std::shared_ptr<const std::unordered_set<std::string>> op_names_;
-};
-
-}  // namespace op
-}  // namespace mxnet
-
-#endif  // MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
diff --git a/src/operator/subgraph/default_subgraph_property.h b/src/operator/subgraph/default_subgraph_property.h
new file mode 100644
index 0000000..3882247
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_property.h
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_PROPERTY_H_
+
+#include <vector>
+#include <string>
+#include "./common.h"
+#include "./subgraph_property.h"
+
+namespace mxnet {
+namespace op {
+
+/*
+ * This selects nodes for a subgraph that only contains operators
+ * in a given set and it visits nodes via both input and output links.
+ */
+class ContainOpSelector: public SubgraphSelector {
+ public:
+  explicit ContainOpSelector(const std::unordered_set<std::string>& op_names)
+    : op_names_(op_names) {}
+
+  virtual bool Select(const nnvm::Node &n) {
+    return !n.is_variable() && op_names_.count(n.op()->name);
+  }
+
+  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) {
+    return !new_node.is_variable() && op_names_.count(new_node.op()->name);
+  }
+
+  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) {
+    return !new_node.is_variable() && op_names_.count(new_node.op()->name);
+  }
+ private:
+  const std::unordered_set<std::string>& op_names_;
+};
+
+/*
+ * This subgraph property finds a subgraph whose nodes have only operators
+ * within a set. The operators in the subgraph will be executed by _default_subgraph_op.
+ */
+class DefaultSubgraphProperty: public SubgraphProperty {
+ public:
+  static SubgraphPropertyPtr Create() { return std::make_shared<DefaultSubgraphProperty>(); }
+  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                           const int subgraph_id = 0) const {
+    nnvm::NodePtr n = nnvm::Node::Create();
+    n->attrs.op = Op::Get("_default_subgraph_op");
+    n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id);
+    n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym));
+    return n;
+  }
+  virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
+    return std::make_shared<ContainOpSelector>(
+        this->GetAttr<std::unordered_set<std::string>>("op_names"));
+  }
+};
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(default, DefaultSubgraphProperty);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_PROPERTY_H_
diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc
index 9672877..e8c3069 100644
--- a/src/operator/subgraph/partition_graph.cc
+++ b/src/operator/subgraph/partition_graph.cc
@@ -29,7 +29,7 @@
 #include <stack>
 #include <queue>
 
-#include "./default_subgraph_op.h"
+#include "./subgraph_property.h"
 #include "./common.h"
 
 namespace nnvm {
@@ -408,7 +408,7 @@ void FindSubgraphs(Graph* g,
                              &preselected_nodes);
 
       // filter out unqualified pre-selected nodes
-      std::vector<nnvm::Node*> filtered_nodes = subgraph_selector->Filter(g, preselected_nodes);
+      std::vector<nnvm::Node*> filtered_nodes = subgraph_selector->Filter(preselected_nodes);
 
       // make sure filtered_nodes is a subset of preselected_nodes
       for (const auto n : filtered_nodes) {
diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h
new file mode 100644
index 0000000..2153a36
--- /dev/null
+++ b/src/operator/subgraph/subgraph_property.h
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
+
+#include <nnvm/node.h>
+#include <dmlc/base.h>
+#include <dmlc/thread_local.h>
+#include <unordered_map>
+#include <vector>
+#include <string>
+
+namespace mxnet {
+namespace op {
+
+/*
+ * This provides criteria for selecting nodes in a subgraph.
+ * When a node is passed to this object, the selection criteria may be changed.
+ * We can also specify what links we should use when traversing the neighbor
+ * nodes.
+ */
+class SubgraphSelector {
+ public:
+  virtual ~SubgraphSelector() {}
+  // Determine if the node should be selected for a subgraph.
+  virtual bool Select(const nnvm::Node &n) = 0;
+  // Determine if the input node should be selected for a subgraph.
+  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
+  // Determine if the output node should be selected for a subgraph.
+  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
+  // Post processes pre-selected subgraph nodes. Return a list of nodes that
+  // users want to keep in subgraph(s).
+  virtual std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) {
+    return candidates;
+  }
+};
+
+using SubgraphSelectorPtr = std::shared_ptr<SubgraphSelector>;
+
+/*!
+ * \brief This provides a set of properties for partitioning a graph into subgraphs,
+ * reconstructing a new graph from the subgraphs and creating a subgraph
+ * operator to execute the subgraph.
+ */
+class SubgraphProperty {
+ public:
+  // the criteria of selecting the subgraph nodes.
+  virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0;
+  // create an nnvm node for a given subgraph. Here users can customize how to
+  // execute the operators in the subgraph.
+  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s,
+                                           const int subgraph_id = 0) const = 0;
+  // set an attr with name in the attr map
+  template<typename T>
+  SubgraphProperty& SetAttr(const std::string& name, const T& value) {
+    attrs_[name] = std::make_shared<dmlc::any>(value);
+    return *this;
+  }
+  // get the attr with the name
+  template<typename T>
+  const T& GetAttr(const std::string& name) const {
+    auto it = attrs_.find(name);
+    CHECK(it != attrs_.end()) << "Cannot find attribute " << name << " in SubgraphProperty";
+    return nnvm::get<T>(*it->second);
+  }
+ protected:
+  std::unordered_map<std::string, std::shared_ptr<nnvm::any>> attrs_;
+};
+
+using SubgraphPropertyPtr = std::shared_ptr<SubgraphProperty>;
+
+class SubgraphPropertyRegistry {
+ public:
+  typedef SubgraphPropertyPtr (*SubgraphPropertyCreateFn)(void);
+  static SubgraphPropertyRegistry* Get() {
+    static SubgraphPropertyRegistry inst;
+    return &inst;
+  }
+
+  SubgraphPropertyPtr CreateSubgraphProperty(const std::string& name) {
+    auto it = prop_fn_map_.find(name);
+    CHECK(it != prop_fn_map_.end()) << "SubgraphProperty " << name
+                                    << " is not found in SubgraphPropertyRegistry";
+    return it->second();
+  }
+
+  SubgraphPropertyCreateFn __REGISTER__(const std::string& name, SubgraphPropertyCreateFn fn) {
+    CHECK_EQ(prop_fn_map_.count(name), 0U) << "Subgraph property " << name
+                                           << " has been registered";
+    prop_fn_map_[name] = fn;
+    return prop_fn_map_[name];
+  }
+
+ private:
+  SubgraphPropertyRegistry() = default;
+  SubgraphPropertyRegistry(const SubgraphPropertyRegistry&) = delete;
+  SubgraphPropertyRegistry(SubgraphPropertyRegistry&&) = delete;
+  SubgraphPropertyRegistry& operator=(const SubgraphPropertyRegistry&) = delete;
+  std::unordered_map<std::string, SubgraphPropertyCreateFn> prop_fn_map_;
+};
+
+// This op name set is for setting the names of operators that should be grouped into
+// subgraphs. In practice, every backend accelerator should have a predefined name set.
+// This set is only used for the testing purpose.
+// key: property name, value: op name set
+typedef dmlc::ThreadLocalStore<std::unordered_map<std::string, std::unordered_set<std::string>>>
+  SubgraphPropertyOpNameSet;
+
+#define MXNET_REGISTER_SUBGRAPH_PROPERTY(Name, SubgraphPropertyType) \
+  static DMLC_ATTRIBUTE_UNUSED auto __make_ ## SubgraphPropertyType ## _ ## Name ## __ = \
+    SubgraphPropertyRegistry::Get()->__REGISTER__(#Name, &SubgraphPropertyType::Create)
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc
index 92d0958..6d669c1 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -275,6 +275,64 @@ TEST(Engine, basics) {
   LOG(INFO) << "All pass";
 }
 
+TEST(Engine, VarVersion) {
+  const size_t num_engines = 3;
+  std::vector<mxnet::Engine*> engines(num_engines);
+  engines[0] = mxnet::engine::CreateNaiveEngine();
+  engines[1] = mxnet::engine::CreateThreadedEnginePooled();
+  engines[2] = mxnet::engine::CreateThreadedEnginePerDevice();
+  std::string type_names[3] = {"NaiveEngine", "ThreadedEnginePooled", "ThreadedEnginePerDevice"};
+  for (size_t k = 0; k < num_engines; ++k) {
+    auto engine = engines[k];
+    std::vector<mxnet::Engine::OprHandle> oprs;
+
+    LOG(INFO) << "Testing var as a read dependency in " << type_names[k];
+    auto var = engine->NewVariable();
+    EXPECT_EQ(var->version(), 0U);
+    for (int i = 0; i < 10; ++i) {
+      oprs.push_back(engine->NewOperator(
+          [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+            Foo(ctx, i);
+            cb();
+          },
+          {var}, {}));
+      engine->Push(oprs.at(i), mxnet::Context{});
+    }
+    engine->WaitForAll();
+    EXPECT_EQ(var->version(), 0U);
+    for (auto&& i : oprs) {
+      engine->DeleteOperator(i);
+    }
+    engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
+    engine->WaitForAll();
+
+    LOG(INFO) << "Testing var as a write dependency in " << type_names[k];
+    var = engine->NewVariable();
+    EXPECT_EQ(var->version(), 0U);
+    oprs.clear();
+    for (int i = 0; i < 10; ++i) {
+      oprs.push_back(engine->NewOperator(
+          [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+            Foo(ctx, i);
+            cb();
+          },
+          {}, {var}));
+      engine->Push(oprs.at(i), mxnet::Context{});
+    }
+    engine->WaitForAll();
+    EXPECT_EQ(var->version(), 10U);
+    for (auto&& i : oprs) {
+      engine->DeleteOperator(i);
+    }
+    engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
+    engine->WaitForAll();
+
+    var = nullptr;
+    oprs.clear();
+    LOG(INFO) << "All pass";
+  }
+}
+
 #ifdef _OPENMP
 
 struct TestSaveAndRestoreOMPState {
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 13e8e4e..2a34400 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -175,7 +175,6 @@ def test_trainer_save_load():
     # check if parameter dict is correctly associated with optimizer after load_state
     assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
 
-@unittest.skip("temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/11353")
 @with_seed()
 def test_trainer_reset_kv():
     def check_trainer_reset_kv(kv):
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
index f6a33c2..40d609a 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -15,26 +15,29 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import os
 import ctypes
 import mxnet as mx
-from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array
+from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array, c_str
 from mxnet.symbol import Symbol
 import numpy as np
 from mxnet.test_utils import assert_almost_equal
 
 
 def test_subgraph_exe():
-    def check_subgraph_exe(sym, op_names):
+    def _check_subgraph_exe1(sym, op_names):
+        """Use the partitioned sym to simple_bind an executor and compare the outputs
+        with those of the original executor"""
         out = SymbolHandle()
-        check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)),
-                                         c_str_array(op_names), ctypes.byref(out)))
+        check_call(_LIB.MXPartitionGraphByOpNames(sym.handle, c_str('default'), mx_uint(len(op_names)),
+                                                  c_str_array(op_names), ctypes.byref(out)))
 
         partitioned_sym = Symbol(out)
         assert partitioned_sym.list_inputs() == sym.list_inputs()
         assert partitioned_sym.list_arguments() == sym.list_arguments()
         assert partitioned_sym.list_auxiliary_states() == sym.list_auxiliary_states()
-        exe = sym.simple_bind(ctx=mx.cpu(), grad_req='null')
-        partitioned_exe = partitioned_sym.simple_bind(ctx=mx.cpu(), grad_req='null')
+        exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+        partitioned_exe = partitioned_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
         input_names = sym.list_inputs()
         for name in input_names:
             if name in exe.arg_dict:
@@ -46,12 +49,109 @@ def test_subgraph_exe():
                 partitioned_exe.aux_dict[name][:] = exe.aux_dict[name]
         exe.forward()
         partitioned_exe.forward()
-        mx.nd.waitall()
         assert len(exe.outputs) == len(partitioned_exe.outputs)
         for i in range(len(exe.outputs)):
             assert_almost_equal((exe.outputs[i] - partitioned_exe.outputs[i]).abs().sum().asnumpy(),
                                 np.zeros(shape=(1,)))
 
+    def _check_subgraph_exe2(sym, op_names):
+        """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning in simple_bind
+        and compare results of the partitioned sym and the original sym."""
+        def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None):
+            if subgraph_backend is not None:
+                os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
+                check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
+                                                             c_str_array(op_names)))
+            exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+            input_names = sym.list_inputs()
+            for name in input_names:
+                if name in exe.arg_dict:
+                    exe.arg_dict[name][:] = mx.nd.random.uniform(shape=exe.arg_dict[name].shape)\
+                        if original_exec is None else original_exec.arg_dict[name]
+                else:
+                    assert name in exe.aux_dict
+                    exe.aux_dict[name][:] = mx.nd.random.uniform(shape=exe.aux_dict[name].shape)\
+                        if original_exec is None else original_exec.aux_dict[name]
+            exe.forward()
+            if subgraph_backend is not None:
+                check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
+                del os.environ['MXNET_SUBGRAPH_BACKEND']
+            return exe
+
+        original_exec = get_executor(sym)
+        partitioned_exec = get_executor(sym, 'default', op_names, original_exec)
+        outputs1 = original_exec.outputs
+        outputs2 = partitioned_exec.outputs
+        assert len(outputs1) == len(outputs2)
+        for i in range(len(outputs1)):
+            assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))
+
+    def _check_subgraph_exe3(sym, op_names):
+        """Use the partitioned sym to bind an executor and compare the outputs
+        with those of the original executor"""
+        out = SymbolHandle()
+        check_call(_LIB.MXPartitionGraphByOpNames(sym.handle, c_str('default'), mx_uint(len(op_names)),
+                                                  c_str_array(op_names), ctypes.byref(out)))
+
+        partitioned_sym = Symbol(out)
+        input_names = sym.list_inputs()
+        arg_names = sym.list_arguments()
+        aux_names = sym.list_auxiliary_states()
+        assert partitioned_sym.list_inputs() == input_names
+        assert partitioned_sym.list_arguments() == arg_names
+        assert partitioned_sym.list_auxiliary_states() == aux_names
+        arg_shapes, _, aux_shapes = sym.infer_shape()
+        arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
+        aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
+        exe = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
+        partitioned_exe = partitioned_sym.bind(ctx=mx.current_context(), args=arg_array,
+                                               aux_states=aux_array, grad_req='null')
+        exe.forward()
+        partitioned_exe.forward()
+        assert len(exe.outputs) == len(partitioned_exe.outputs)
+        for i in range(len(exe.outputs)):
+            assert_almost_equal((exe.outputs[i] - partitioned_exe.outputs[i]).abs().sum().asnumpy(),
+                                np.zeros(shape=(1,)))
+
+    def _check_subgraph_exe4(sym, op_names):
+        """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning in bind
+        and compare results of the partitioned sym and the original sym."""
+        def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None):
+            if subgraph_backend is not None:
+                os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
+                check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
+                                                             c_str_array(op_names)))
+            arg_shapes, _, aux_shapes = sym.infer_shape()
+            if subgraph_backend is None:
+                arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
+                aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
+            else:
+                arg_array = None
+                aux_array = None
+            exe = sym.bind(ctx=mx.current_context(),
+                           args=arg_array if subgraph_backend is None else original_exec.arg_arrays,
+                           aux_states=aux_array if subgraph_backend is None else original_exec.aux_arrays,
+                           grad_req='null')
+            exe.forward()
+            if subgraph_backend is not None:
+                check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
+                del os.environ['MXNET_SUBGRAPH_BACKEND']
+            return exe
+
+        original_exec = get_executor(sym)
+        partitioned_exec = get_executor(sym, 'default', op_names, original_exec)
+        outputs1 = original_exec.outputs
+        outputs2 = partitioned_exec.outputs
+        assert len(outputs1) == len(outputs2)
+        for i in range(len(outputs1)):
+            assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))
+
+    def check_subgraph_exe(sym, op_names):
+        _check_subgraph_exe1(sym, op_names)
+        _check_subgraph_exe2(sym, op_names)
+        _check_subgraph_exe3(sym, op_names)
+        _check_subgraph_exe4(sym, op_names)
+
     def test_network_structure_1():
         data1 = mx.sym.var('data1', shape=(2, 3, 10, 10))
         data2 = mx.sym.var('data2')