You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/31 02:13:35 UTC

[GitHub] eric-haibin-lin closed pull request #12157: Subgraph API for integrating accelerators with MXNet

eric-haibin-lin closed pull request #12157: Subgraph API for integrating accelerators with MXNet
URL: https://github.com/apache/incubator-mxnet/pull/12157
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/include/mxnet/c_api_test.h b/include/mxnet/c_api_test.h
new file mode 100644
index 00000000000..fe6fc7fe9cc
--- /dev/null
+++ b/include/mxnet/c_api_test.h
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file c_api_test.h
+ * \brief C API of mxnet for ease of testing backend in Python
+ */
+#ifndef MXNET_C_API_TEST_H_
+#define MXNET_C_API_TEST_H_
+
+/*! \brief Inhibit C++ name-mangling for MXNet functions. */
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+#include <mxnet/c_api.h>
+
+/*!
+ * \brief Partitions a graph using the subgraph property registered under
+ * prop_name, with its "op_names" attribute set to the operator names given
+ * by users (no property is attached when num_ops is 0). This function
+ * should be used only for testing purposes.
+ */
+MXNET_DLL int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
+                                        const char* prop_name,
+                                        const mx_uint num_ops,
+                                        const char** op_names,
+                                        SymbolHandle* ret_sym_handle);
+
+/*!
+ * \brief Given a subgraph property name, use the provided op names
+ * as the op_names attribute for that subgraph property, instead of
+ * the predefined one. This is only for the purpose of testing.
+ */
+MXNET_DLL int MXSetSubgraphPropertyOpNames(const char* prop_name,
+                                           const mx_uint num_ops,
+                                           const char** op_names);
+
+/*!
+ * \brief Given a subgraph property name, delete the op name set
+ * in the SubgraphPropertyOpNameSet.
+ */
+MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name);
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
+
+#endif  // MXNET_C_API_TEST_H_
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index dc48bfb83fa..11e64edfcd5 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -41,8 +41,26 @@ class Engine;
 
 /*! \brief namespace of engine internal types. */
 namespace engine {
-/*! \brief Internal representation of variable. */
-struct Var;
+/*! \brief base class of engine variables.*/
+struct Var {
+  virtual size_t version() {
+    return version_;
+  }
+  virtual ~Var() = default;
+  /*!
+   * \brief cast variable to derived type T
+   * \tparam T the type we want to cast into.
+   * \return The variable cast to T.
+   */
+  template <typename T>
+  inline T* Cast();
+  /*!
+   * \brief version number of the var. Every time the object it is associated with
+   * is modified, the version number is incremented by 1.
+   */
+  size_t version_{0};
+};  // struct Var
+
 /*! \brief Internal representation of operator.  */
 struct Opr;
 /*! \brief Variable pointer type, usually hold by user used to specify dependencies. */
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index bae3ea90d5e..6141a4da78e 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -340,6 +340,10 @@ class NDArray {
   inline size_t byte_offset() const {
     return byte_offset_;
   }
+  /*! \brief return var version of the NDArray*/
+  inline size_t version() const {
+    return var()->version();
+  }
   /*!
    * \brief save the content into binary stream
    * \param strm the output stream
diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc
new file mode 100644
index 00000000000..623faa71adc
--- /dev/null
+++ b/src/c_api/c_api_test.cc
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file c_api_test.cc
+ * \brief C API of mxnet for the ease of testing backend in Python
+ */
+#include <mxnet/c_api_test.h>
+#include <nnvm/pass.h>
+#include "./c_api_common.h"
+#include "../operator/subgraph/subgraph_property.h"
+
+int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
+                              const char* prop_name,
+                              const mx_uint num_ops,
+                              const char** op_names,
+                              SymbolHandle* ret_sym_handle) {
+  nnvm::Symbol* s = new nnvm::Symbol();
+  API_BEGIN();
+  std::unordered_set<std::string> op_name_set;
+  for (size_t i = 0; i < num_ops; ++i) {
+    op_name_set.emplace(op_names[i]);
+  }
+  nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
+  *s = sym->Copy();
+  nnvm::Graph g;
+  g.outputs = s->outputs;
+  if (!op_name_set.empty()) {
+    mxnet::op::SubgraphPropertyPtr property
+        = mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+    property->SetAttr("op_names", op_name_set);
+    g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
+  }
+  g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
+  s->outputs = g.outputs;
+  *ret_sym_handle = s;
+  API_END_HANDLE_ERROR(delete s);
+}
+
+int MXSetSubgraphPropertyOpNames(const char* prop_name,
+                                 const mx_uint num_ops,
+                                 const char** op_names) {
+  API_BEGIN();
+  std::unordered_set<std::string> op_name_set;
+  for (size_t i = 0; i < num_ops; ++i) {
+    op_name_set.emplace(op_names[i]);
+  }
+  (*mxnet::op::SubgraphPropertyOpNameSet::Get())[prop_name] = op_name_set;
+  API_END();
+}
+
+int MXRemoveSubgraphPropertyOpNames(const char* prop_name) {
+  API_BEGIN();
+  mxnet::op::SubgraphPropertyOpNameSet::Get()->erase(prop_name);
+  API_END();
+}
diff --git a/src/engine/engine_impl.h b/src/engine/engine_impl.h
index b3ec34dc857..f15141f4e7a 100644
--- a/src/engine/engine_impl.h
+++ b/src/engine/engine_impl.h
@@ -33,20 +33,6 @@
 namespace mxnet {
 namespace engine {
 
-/*! \brief base class of engine variables, used for type checking */
-struct Var {
-#if ENGINE_DEBUG
-  virtual ~Var() = default;
-#endif  // ENGINE_DEBUG
-  /*!
-   * \brief cast variable to derived type T
-   * \tparam T the type we want to cast into.
-   * \return A casted variable.
-   */
-  template <typename T>
-  inline T* Cast();
-};  // struct Var
-
 /*! \brief base class of engine operators, used for type checking */
 struct Opr {
 #if ENGINE_DEBUG
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 8196af2de2f..daff5306694 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -28,10 +28,24 @@
 #include "./engine_impl.h"
 #include "../profiler/profiler.h"
 #include "./openmp.h"
+#include "../common/object_pool.h"
 
 namespace mxnet {
 namespace engine {
 
+/*!
+ * \brief var used in Naive Engine for tracking the version
+ * of the objects it is associated with.
+ */
+class NaiveVar final
+    : public Var, public common::ObjectPoolAllocatable<NaiveVar> {
+ public:
+  inline static NaiveVar* CastFromBase(Var* ptr) {
+    return ptr->Cast<NaiveVar>();
+  }
+};  // class NaiveVar
+
+
 // implement naive engine
 class NaiveEngine final : public Engine {
  public:
@@ -71,8 +85,7 @@ class NaiveEngine final : public Engine {
 
   // new variables
   VarHandle NewVariable() override {
-    size_t v = ++counter_;
-    return reinterpret_cast<VarHandle>(v);
+    return NaiveVar::New();
   }
 
   OprHandle NewOperator(AsyncFn fn,
@@ -146,6 +159,10 @@ class NaiveEngine final : public Engine {
       opr->opr_profile.reset(new profiler::ProfileOperator(opr->opr_name, attrs.release()));
       opr->opr_profile->start(exec_ctx.dev_type, exec_ctx.dev_id);
     }
+    // increment mutable var version
+    for (auto var : mutable_vars) {
+      ++var->version_;
+    }
     if (exec_ctx.dev_mask() == gpu::kDevMask) {
 #if MXNET_USE_CUDA
       size_t dev_id = static_cast<size_t>(exec_ctx.dev_id);
@@ -171,8 +188,12 @@ class NaiveEngine final : public Engine {
   }
 
   void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override {
-    this->PushSync(delete_fn, exec_ctx, {}, {var},
-                   FnProperty::kNormal, 0, "DeleteVariable");
+    NaiveVar* naive_var = NaiveVar::CastFromBase(var);
+    this->PushAsync([delete_fn, naive_var](RunContext ctx, CallbackOnComplete on_complete) mutable {
+        delete_fn(ctx);
+        NaiveVar::Delete(naive_var);
+        on_complete();
+      }, exec_ctx, {}, {var}, FnProperty::kDeleteVar, 0, "DeleteVariable");
   }
 
   void WaitForVar(VarHandle var) override {
@@ -192,8 +213,6 @@ class NaiveEngine final : public Engine {
   }
   // whether action is completed
   bool req_completed_;
-  // counter
-  std::atomic<size_t> counter_{0};
   /*! \brief whether it is during shutdown phase*/
   std::atomic<bool> shutdown_phase_{false};
   // CPU stream
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index e70cc197c0c..3a7587fef13 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -130,6 +130,9 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
     assert(pending_write_ != nullptr);
     CHECK_EQ(num_pending_reads_, kWriteTriggered);
 
+    // increment version number
+    ++version_;
+
     // really delete
     if (to_delete_) {
       VersionedVarBlock *head = pending_write_->next;
@@ -164,7 +167,7 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
   }
   // This is outside of lock scope
   // Be very carful, pending_write_ and num_pending_reads_
-  // can change now, do not reply ont the two variables.
+  // can change now, do not rely on these two variables.
   // The linked list \in [old_pending_write, end_of_read_chain)
   // is already detached from this Var.
   // So it is safe to modify these
@@ -196,6 +199,11 @@ inline bool ThreadedVar::ready_to_read() {
   return this->is_ready_to_read();
 }
 
+inline size_t ThreadedVar::version() {
+  std::lock_guard<std::mutex> lock{mutex_};
+  return this->version_;
+}
+
 // implementation of threaded engine
 ThreadedVar* ThreadedEngine::NewVariable() {
   return ThreadedVar::New(VersionedVarBlock::New());
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 428f0d8c554..a2c1a2b943a 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -162,6 +162,7 @@ class ThreadedVar final
   inline void SetToDelete();
   /*! \return whether this variable is ready to read. */
   inline bool ready_to_read();
+  inline size_t version() override;
   /*!
    * \brief Cast a Var pointer to ThreadedVar pointer
    * \param ptr pointer from base.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 32b14b8e963..265554ab391 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -33,6 +33,7 @@
 #include "../profiler/profiler.h"
 #include "../common/utils.h"
 #include "../common/exec_utils.h"
+#include "../operator/subgraph/subgraph_property.h"
 
 namespace mxnet {
 namespace exec {
@@ -42,6 +43,7 @@ using namespace mxnet::common;
 GraphExecutor::GraphExecutor() {
   log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
   need_grad_ = false;
+  subgraph_property_ = dmlc::GetEnv("MXNET_SUBGRAPH_BACKEND", std::string());
 }
 
 GraphExecutor::~GraphExecutor() {
@@ -1428,6 +1430,146 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
     iter->c_str());
   return ret;
 }
+
+// Infer shapes, dtypes, stypes, contexts for the forward graph
+static nnvm::Graph InferForwardAttrs(nnvm::Graph g,
+                                     nnvm::ShapeVector arg_shapes,
+                                     nnvm::DTypeVector arg_dtypes,
+                                     StorageTypeVector arg_stypes,
+                                     const Context& default_ctx,
+                                     const std::map<std::string, Context>& ctx_map,
+                                     const std::vector<Context>& in_arg_ctxes,
+                                     const std::vector<Context>& aux_state_ctxes) {
+  const auto& indexed_graph = g.indexed_graph();
+  const auto num_forward_inputs = indexed_graph.input_nodes().size();
+  g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
+                   aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs, indexed_graph,
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs, indexed_graph,
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+  return g;
+}
+
+// Given input attr arrays, partition the graph using the backend name equal to prop_name.
+// This is a common function for bind and simple_bind flows.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const nnvm::ShapeVector& arg_shapes,
+                                   const nnvm::DTypeVector& arg_dtypes,
+                                   const StorageTypeVector& arg_stypes,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& ctx_map,
+                                   const std::vector<Context>& in_arg_ctxes,
+                                   const std::vector<Context>& aux_state_ctxes) {
+  auto subgraph_prop = op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+  nnvm::Symbol ret = src.Copy();
+  nnvm::Graph g;
+  g.outputs = ret.outputs;
+  g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
+                        ctx_map, in_arg_ctxes, aux_state_ctxes);
+  subgraph_prop->SetAttr("graph", g);
+  auto it = op::SubgraphPropertyOpNameSet::Get()->find(prop_name);
+  // assign an op name set to the subgraph property if it has been provided by users
+  if (it != op::SubgraphPropertyOpNameSet::Get()->end()) {
+    LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << prop_name
+              << " has been assigned a value. Please make sure it is initialized"
+                 " only for the testing purpose.";
+    subgraph_prop->SetAttr("op_names", it->second);
+  }
+  g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(subgraph_prop));
+  g = ApplyPass(std::move(g), "PartitionGraph");
+  ret.outputs = g.outputs;
+  return ret;
+}
+
+// Given input attr dicts, partition the graph using the backend name equal to prop_name.
+// This is for simple_bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const std::unordered_map<std::string, TShape>& arg_shape_map,
+                                   const std::unordered_map<std::string, int>& arg_dtype_map,
+                                   const std::unordered_map<std::string, int>& arg_stype_map,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& ctx_map,
+                                   const std::vector<Context>& in_arg_ctxes,
+                                   const std::vector<Context>& aux_state_ctxes) {
+  const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
+  nnvm::ShapeVector arg_shapes(input_names.size(), TShape());
+  nnvm::DTypeVector arg_dtypes(input_names.size(), -1);
+  StorageTypeVector arg_stypes(input_names.size(), kUndefinedStorage);
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    auto it1 = arg_shape_map.find(input_names[i]);
+    if (arg_shape_map.end() != it1) {
+      arg_shapes[i] = it1->second;
+    }
+    auto it2 = arg_dtype_map.find(input_names[i]);
+    if (arg_dtype_map.end() != it2) {
+      arg_dtypes[i] = it2->second;
+    }
+    auto it3 = arg_stype_map.find(input_names[i]);
+    if (arg_stype_map.end() != it3) {
+      arg_stypes[i] = it3->second;
+    }
+  }
+  return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
+                        default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+}
+
+// Given input ndarrays, partition the graph using the backend name equal to prop_name.
+// This is for bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const std::vector<NDArray> &in_args,
+                                   const std::vector<NDArray> &aux_states,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& ctx_map) {
+  const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
+  const std::vector<std::string> arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+  const std::vector<std::string> aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+  CHECK_EQ(arg_names.size(), in_args.size());
+  CHECK_EQ(aux_names.size(), aux_states.size());
+  nnvm::ShapeVector arg_shapes;  // all input shapes
+  arg_shapes.reserve(input_names.size());
+  nnvm::DTypeVector arg_dtypes;  // all input dtypes
+  arg_dtypes.reserve(input_names.size());
+  StorageTypeVector arg_stypes;  // all input stypes
+  arg_stypes.reserve(input_names.size());
+  std::vector<Context> in_arg_ctxes(in_args.size());
+  std::vector<Context> aux_state_ctxes(aux_states.size());
+
+  size_t i1 = 0, i2 = 0;
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    if (i2 < aux_names.size() && aux_names[i2] == input_names[i]) {
+      arg_shapes.push_back(aux_states[i2].shape());
+      arg_dtypes.push_back(aux_states[i2].dtype());
+      arg_stypes.push_back(aux_states[i2].storage_type());
+      aux_state_ctxes[i2] = aux_states[i2].ctx();
+      ++i2;
+    } else {
+      CHECK(i1 < arg_names.size());
+      CHECK_EQ(arg_names[i1], input_names[i]);
+      arg_shapes.push_back(in_args[i1].shape());
+      arg_dtypes.push_back(in_args[i1].dtype());
+      arg_stypes.push_back(in_args[i1].storage_type());
+      in_arg_ctxes[i1] = in_args[i1].ctx();
+      ++i1;
+    }
+  }
+  return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
+                        default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+}
 }  // namespace exec
 
 Executor *Executor::SimpleBind(nnvm::Symbol symbol,
@@ -1447,6 +1589,11 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
                                std::unordered_map<std::string, NDArray>* shared_buffer,
                                Executor* shared_exec) {
   auto exec = new exec::GraphExecutor();
+  if (!exec->subgraph_property().empty()) {
+    symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map,
+                                  arg_stype_map, default_ctx, group2ctx, in_arg_ctxes,
+                                  aux_state_ctxes);
+  }
   exec->Init(symbol, default_ctx, group2ctx,
              in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
              arg_shape_map, arg_dtype_map, arg_stype_map,
@@ -1465,6 +1612,10 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
                          const std::vector<NDArray> &aux_states,
                          Executor* shared_exec) {
   auto exec = new exec::GraphExecutor();
+  if (!exec->subgraph_property().empty()) {
+    symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), in_args, aux_states,
+                                  default_ctx, group2ctx);
+  }
   exec->Init(symbol, default_ctx, group2ctx,
              in_args, arg_grad_store, grad_req_type, aux_states,
              reinterpret_cast<Executor*>(shared_exec));
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 7b936c30025..b94bb437778 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -117,6 +117,8 @@ class GraphExecutor : public Executor {
                     std::vector<NDArray>* arg_grads,
                     std::vector<NDArray>* aux_states) override;
 
+  const std::string& subgraph_property() const { return subgraph_property_; }
+
  protected:
   friend class mxnet::Imperative;
   // Information about operational node
@@ -256,6 +258,8 @@ class GraphExecutor : public Executor {
   std::unordered_set<std::string> cached_seg_opr_names_;
   // verbose logging
   bool log_verbose_ = false;
+  // subgraph property name
+  std::string subgraph_property_;
 };
 
 }  // namespace exec
diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h
new file mode 100644
index 00000000000..22058d556e0
--- /dev/null
+++ b/src/operator/subgraph/common.h
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_COMMON_H_
+#define MXNET_OPERATOR_SUBGRAPH_COMMON_H_
+
+#include <string>
+#include <set>
+#include <vector>
+#include "../elemwise_op_common.h"
+#include "../../executor/exec_pass.h"
+
+namespace mxnet {
+namespace op {
+
+inline uint32_t DefaultSubgraphOpNumInputs(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& sym = *attrs.subgraphs[0];
+  return sym.ListInputNames(nnvm::Symbol::kAll).size();
+}
+
+inline uint32_t DefaultSubgraphOpNumOutputs(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& sym = *attrs.subgraphs[0];
+  return sym.ListOutputNames().size();
+}
+
+inline std::vector<std::string> DefaultSubgraphOpListInputs(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& sym = *attrs.subgraphs[0];
+  return sym.ListInputNames(nnvm::Symbol::kAll);
+}
+
+inline std::vector<std::string> DefaultSubgraphOpListOutputs(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& sym = *attrs.subgraphs[0];
+  return sym.ListOutputNames();
+}
+
+inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
+                                   std::vector<TShape> *in_shapes,
+                                   std::vector<TShape> *out_shapes) {
+  using namespace exec;
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+  nnvm::Graph g;
+  g.outputs = subgraph_sym.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size());
+  CHECK_EQ(idx_g.outputs().size(), out_shapes->size());
+
+  // Put the input and output shapes to the shape vector.
+  nnvm::ShapeVector shapes(idx_g.num_node_entries());
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_shapes->size());
+  for (size_t i = 0; i < in_shapes->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    shapes[eid] = in_shapes->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_shapes->size());
+  for (size_t i = 0; i < out_shapes->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    shapes[eid] = out_shapes->at(i);
+  }
+
+  // Infer shape of the graph.
+  g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
+  g = exec::InferShape(std::move(g));
+
+  // Copy the inferred shape back to the input shapes and the output shapes.
+  shapes = g.GetAttr<nnvm::ShapeVector>("shape");
+  // assign to in_shapes
+  for (size_t i = 0; i < in_shapes->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    SHAPE_ASSIGN_CHECK(*in_shapes, i, shapes[eid]);
+  }
+  // assign to out_shapes
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]);
+  }
+  // Check if we have inferred the shapes correctly.
+  return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
+}
+
+inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
+                                  std::vector<int> *in_types,
+                                  std::vector<int> *out_types) {
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+  nnvm::Graph g;
+  g.outputs = subgraph_sym.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_types->size());
+  CHECK_EQ(idx_g.outputs().size(), out_types->size());
+
+  // Put the input and output data types to the dtype vector.
+  nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_types->size());
+  for (size_t i = 0; i < in_types->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    types[eid] = in_types->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_types->size());
+  for (size_t i = 0; i < out_types->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    types[eid] = out_types->at(i);
+  }
+
+  // Infer data type of the graph.
+  g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
+  g = exec::InferType(std::move(g));
+
+  types = g.GetAttr<nnvm::DTypeVector>("dtype");
+  // assign to in_types
+  for (size_t i = 0; i < in_types->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    TYPE_ASSIGN_CHECK(*in_types, i, types[eid]);
+  }
+  // assign to out_types
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    TYPE_ASSIGN_CHECK(*out_types, i, types[eid]);
+  }
+  // Check if we have inferred the dtypes correctly.
+  return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
+}
+
+inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int>* in_stypes,
+                                         std::vector<int>* out_stypes) {
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+  nnvm::Graph g;
+  g.outputs = subgraph_sym.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_stypes->size());
+  CHECK_EQ(idx_g.outputs().size(), out_stypes->size());
+  exec::DevMaskVector dev_masks(idx_g.num_node_entries(), dev_mask);
+
+  // Put the input and output storages to the storage vector.
+  StorageTypeVector stypes(idx_g.num_node_entries(), kUndefinedStorage);
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_stypes->size());
+  for (size_t i = 0; i < in_stypes->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    stypes[eid] = in_stypes->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_stypes->size());
+  for (size_t i = 0; i < out_stypes->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    stypes[eid] = out_stypes->at(i);
+  }
+
+  // Infer storage type of the graph.
+  bool dev_match = g.attrs.count("dev_mask") &&
+                   g.GetAttr<exec::DevMaskVector>("dev_mask") == dev_masks;
+  if (!dev_match) {
+    g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(dev_masks));
+  }
+  g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(stypes));
+  g = exec::InferStorageType(std::move(g));
+
+  stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  // assign to in_stypes
+  for (size_t i = 0; i < in_stypes->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, stypes[eid]);
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  // assign to out_stypes
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, stypes[eid]);
+  }
+  // Check if we have inferred the storages correctly.
+  return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
+}
+
+inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) {
+  return ExecType::kSubgraphExec;
+}
+
+inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+  const std::vector<std::string> input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
+  const std::vector<std::string> immutable_input_names =
+    subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+  const std::vector<std::string> mutable_input_names =
+    subgraph_sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+  CHECK_EQ(immutable_input_names.size() + mutable_input_names.size(), input_names.size());
+  std::vector<uint32_t> ret;
+  size_t i1 = 0, i2 = 0;
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    if (i1 < immutable_input_names.size() && input_names[i] == immutable_input_names[i1]) {
+      ++i1;
+    } else {
+      CHECK(i2 < mutable_input_names.size());
+      CHECK_EQ(input_names[i], mutable_input_names[i2]);
+      ++i2;
+      ret.push_back(i);
+    }
+  }
+  return ret;
+}
+
+inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+  static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
+  std::set<ResourceRequest::Type> resource_types;
+  DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) {
+    if (!node->is_variable() && fresource.count(node->op())) {
+      for (ResourceRequest& r : fresource[node->op()](node->attrs)){
+        resource_types.insert(r.type);
+      }
+    }
+  });
+  return std::vector<ResourceRequest>(resource_types.begin(), resource_types.end());
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_SUBGRAPH_COMMON_H_
diff --git a/src/operator/subgraph/default_subgraph_op.cc b/src/operator/subgraph/default_subgraph_op.cc
new file mode 100644
index 00000000000..d5fb7ee2db6
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.cc
@@ -0,0 +1,112 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+#include <mxnet/ndarray.h>
+#include "./common.h"
+#include "../../imperative/imperative_utils.h"
+#include "../../imperative/cached_op.h"
+
+namespace mxnet {
+namespace op {
+
+#define DEBUG_SUBGRAPH 0
+
+/*!
+ * \brief Operator state of _default_subgraph_op. Owns a CachedOp, created with
+ * static_alloc and static_shape enabled, that executes the subgraph symbol.
+ */
+class DefaultSubgraphOperator {
+ public:
+  explicit DefaultSubgraphOperator(const Symbol& sym)
+    : subgraph_sym_(sym),
+      subgraph_exec_(new CachedOp(sym, {{"static_alloc", "true"},
+                                        {"static_shape", "true"}})) {}
+
+  void Forward(const OpContext& ctx,
+               const std::vector<NDArray>& inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray>& outputs);
+
+  /*! \brief Backward is not supported; always aborts via LOG(FATAL). */
+  void Backward(const OpContext& ctx,
+                const std::vector<NDArray>& inputs,
+                const std::vector<OpReqType>& req,
+                const std::vector<NDArray>& outputs) {
+    LOG(FATAL) << "Not implemented";
+  }
+
+ private:
+  /*! \brief the subgraph symbol this operator was built from */
+  nnvm::Symbol subgraph_sym_;
+  /*! \brief cached executor built from the same symbol */
+  CachedOpPtr subgraph_exec_;
+};
+
+/*!
+ * \brief Runs the subgraph through the cached executor.
+ * \param ctx operator context (not used)
+ * \param inputs input arrays, handed to the CachedOp in order
+ * \param req write requests; NOTE(review): these are not forwarded to the CachedOp,
+ *        which writes the outputs itself. Confirm callers only use kWriteTo here.
+ * \param outputs output arrays the CachedOp writes into
+ */
+void DefaultSubgraphOperator::Forward(const OpContext& ctx,
+                                      const std::vector<NDArray>& inputs,
+                                      const std::vector<OpReqType>& req,
+                                      const std::vector<NDArray>& outputs) {
+  // CachedOp::Forward takes non-const NDArray pointers, so point at local copies of
+  // the const argument arrays.
+  std::vector<NDArray> tmp_inputs = inputs;
+  std::vector<NDArray*> input_ptrs;
+  input_ptrs.reserve(tmp_inputs.size());
+  for (auto& nd : tmp_inputs) {
+    input_ptrs.push_back(&nd);
+  }
+  std::vector<NDArray> tmp_outputs = outputs;
+  std::vector<NDArray*> output_ptrs;
+  output_ptrs.reserve(tmp_outputs.size());
+  for (auto& nd : tmp_outputs) {
+    output_ptrs.push_back(&nd);
+  }
+#if DEBUG_SUBGRAPH
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    LOG(INFO) << "inputs[" << i << "].version = " << inputs[i].version();
+  }
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    LOG(INFO) << "outputs[" << i << "].version = " << outputs[i].version();
+  }
+#endif
+  subgraph_exec_->Forward(subgraph_exec_, input_ptrs, output_ptrs);
+}
+
+/*!
+ * \brief FCreateOpState of _default_subgraph_op: builds a DefaultSubgraphOperator for
+ * the node's first subgraph. ctx, in_shapes and in_types are not used.
+ */
+OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& attrs,
+                                        Context ctx,
+                                        const std::vector<TShape>& in_shapes,
+                                        const std::vector<int>& in_types) {
+  const nnvm::Symbol& subgraph = *attrs.subgraphs[0];
+  return OpStatePtr::Create<DefaultSubgraphOperator>(subgraph);
+}
+
+/*!
+ * \brief FStatefulComputeEx of _default_subgraph_op: forwards to the
+ * DefaultSubgraphOperator created by CreateDefaultSubgraphOpState.
+ */
+void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+  state_ptr.get_state<DefaultSubgraphOperator>().Forward(ctx, inputs, req, outputs);
+}
+
+// _default_subgraph_op executes the subgraph symbol stored in attrs.subgraphs[0] of each
+// node. Execution is stateful: FCreateOpState builds a DefaultSubgraphOperator (which
+// wraps the subgraph in a CachedOp) and FStatefulComputeEx<cpu> calls its Forward. The
+// GPU compute function is registered separately in default_subgraph_op.cu. No gradient
+// is registered here, and DefaultSubgraphOperator::Backward is unimplemented.
+NNVM_REGISTER_OP(_default_subgraph_op)
+.describe(R"code(_default_subgraph_op)code" ADD_FILELINE)
+.set_num_inputs(DefaultSubgraphOpNumInputs)
+.set_num_outputs(DefaultSubgraphOpNumOutputs)
+.set_attr<nnvm::FListInputNames>("FListInputNames", DefaultSubgraphOpListInputs)
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", DefaultSubgraphOpListOutputs)
+.set_attr<FCreateOpState>("FCreateOpState", CreateDefaultSubgraphOpState)
+.set_attr<nnvm::FInferShape>("FInferShape", DefaultSubgraphOpShape)
+.set_attr<nnvm::FInferType>("FInferType", DefaultSubgraphOpType)
+.set_attr<FInferStorageType>("FInferStorageType", DefaultSubgraphOpStorageType)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", DefaultSubgraphOpForward)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs", DefaultSubgraphOpMutableInputs)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<FExecType>("FExecType", DefaultSubgraphOpExecType)
+.add_argument("data", "NDArray-or-Symbol[]", "input data list");
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_op.cu b/src/operator/subgraph/default_subgraph_op.cu
new file mode 100644
index 00000000000..008826b21d7
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.cu
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file default_subgraph_op.cu
+ * \brief GPU Implementation of subgraph operations
+ */
+
+#include <mxnet/ndarray.h>
+#include "./common.h"
+#include "../../imperative/imperative_utils.h"
+#include "../../imperative/cached_op.h"
+
+namespace mxnet {
+namespace op {
+
+// Forward function of _default_subgraph_op, defined in default_subgraph_op.cc.
+// The same stateful forward serves GPU as well: it only hands the inputs and outputs
+// to the subgraph's CachedOp (see DefaultSubgraphOperator::Forward).
+void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs);
+
+// Attach the GPU stateful compute function to the operator registered in the .cc file.
+NNVM_REGISTER_OP(_default_subgraph_op)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", DefaultSubgraphOpForward);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_property.cc b/src/operator/subgraph/default_subgraph_property.cc
new file mode 100644
index 00000000000..c8d3e9ffd43
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_property.cc
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <vector>
+#include <string>
+#include "./common.h"
+#include "./subgraph_property.h"
+
+namespace mxnet {
+namespace op {
+
+/*
+ * This selects nodes for a subgraph that only contains operators
+ * in a given set and it visits nodes via both input and output links.
+ * NOTE: only a reference to the operator-name set is kept, so the set must outlive
+ * the selector. DefaultSubgraphProperty passes the result of GetAttr("op_names"),
+ * presumably a reference into the property's attribute storage. Confirm that the
+ * property outlives every selector it creates.
+ */
+class ContainOpSelector: public SubgraphSelector {
+ public:
+  explicit ContainOpSelector(const std::unordered_set<std::string>& op_names)
+    : op_names_(op_names) {}
+
+  bool Select(const nnvm::Node &seed_node) override {
+    return IsContainedOp(seed_node);
+  }
+
+  bool SelectInput(const nnvm::Node &cur_node, const nnvm::Node &input_node) override {
+    return IsContainedOp(input_node);
+  }
+
+  bool SelectOutput(const nnvm::Node &cur_node, const nnvm::Node &output_node) override {
+    return IsContainedOp(output_node);
+  }
+
+ private:
+  /*! \brief true if node is an operator (not a variable) whose name is in op_names_ */
+  bool IsContainedOp(const nnvm::Node &node) const {
+    return !node.is_variable() && op_names_.count(node.op()->name);
+  }
+
+  const std::unordered_set<std::string>& op_names_;
+};
+
+/*
+ * This subgraph property finds a subgraph whose nodes have only operators
+ * within a set. The operators in the subgraph will be executed by _default_subgraph_op.
+ * The set of operator names is read from the property's "op_names" attribute.
+ */
+class DefaultSubgraphProperty: public SubgraphProperty {
+ public:
+  static SubgraphPropertyPtr Create() { return std::make_shared<DefaultSubgraphProperty>(); }
+
+  /*!
+   * \brief Creates a _default_subgraph_op node named "_default_subgraph_op<subgraph_id>"
+   * holding a copy of sym as its only subgraph.
+   */
+  nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                   const int subgraph_id = 0) const override {
+    nnvm::NodePtr n = nnvm::Node::Create();
+    n->attrs.op = Op::Get("_default_subgraph_op");
+    n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id);
+    n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym));
+    return n;
+  }
+
+  /*! \brief Creates a selector restricted to the operators listed in "op_names". */
+  SubgraphSelectorPtr CreateSubgraphSelector() const override {
+    return std::make_shared<ContainOpSelector>(
+        this->GetAttr<std::unordered_set<std::string>>("op_names"));
+  }
+};
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(default, DefaultSubgraphProperty);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc
new file mode 100644
index 00000000000..315f7eec00c
--- /dev/null
+++ b/src/operator/subgraph/partition_graph.cc
@@ -0,0 +1,774 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file partition_graph.cc
+ * \brief
+ */
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+#include <mxnet/op_attr_types.h>
+#include <unordered_set>
+#include <stack>
+#include <queue>
+
+#include "./subgraph_property.h"
+
+namespace nnvm {
+NodePtr CreateVariableNode(const std::string& name);
+}
+
+namespace mxnet {
+
+namespace op {
+
+using nnvm::Symbol;
+using nnvm::Node;
+using nnvm::NodePtr;
+using nnvm::NodeEntry;
+using nnvm::Graph;
+
+#define DEBUG_SUBGRAPH 0
+
+namespace sg {  // sg stands for subgraph
+
+struct SimpleNode;
+using SimpleNodePtr = std::shared_ptr<SimpleNode>;
+
+/*!
+ * \brief Node of an undirected helper graph that mirrors the structure of the
+ * computational graph. Besides the wrapped node's inputs, it records the node's
+ * consumers, so that subgraph search can walk edges in both directions.
+ */
+struct SimpleNode {
+  /*! \brief Allocates a node in its initial, unlabeled state. */
+  static SimpleNodePtr Create() {
+    return std::make_shared<SimpleNode>();
+  }
+  SimpleNode() = default;
+  /*! subgraph label; -1 means the node is not assigned to any subgraph */
+  int label = -1;
+  /*! the original node in the computational graph it references */
+  nnvm::Node* node = nullptr;
+  /*!
+   * \brief output nodes of the current node
+   * key is a consumer node; value lists the indices into key->inputs whose
+   * source is the current node.
+   */
+  std::unordered_map<nnvm::Node*, std::vector<size_t>> outputs;
+};  // struct SimpleNode
+
+#if DEBUG_SUBGRAPH
+/*! \brief Logs the names of the given subgraph nodes, space separated. */
+void PrintSubgraph(const std::vector<SimpleNode*>& simple_nodes) {
+  std::string op_names;
+  for (const SimpleNode* sn : simple_nodes) {
+    op_names += sn->node->attrs.name;
+    op_names += ' ';
+  }
+  LOG(INFO) << "Subgraph node names: " << op_names;
+}
+
+/*! \brief Logs the source node name, index and version of one entry. */
+void PrintNodeEntry(const nnvm::NodeEntry& entry) {
+  LOG(INFO) << "NodeEntry: node_name=" << entry.node->attrs.name
+            << ", index=" << entry.index << ", version=" << entry.version;
+}
+
+/*! \brief Logs every entry in the list. */
+void PrintNodeEntries(const std::vector<nnvm::NodeEntry*>& entries) {
+  for (const nnvm::NodeEntry* e : entries) {
+    PrintNodeEntry(*e);
+  }
+}
+#endif
+
+/*!
+ * \brief Builds the undirected helper graph of a computational graph.
+ * One SimpleNode is appended per node in DFSVisit order, and every node registers
+ * itself, together with its input slot indices, in the outputs map of its inputs.
+ * The rest of this file indexes simple_nodes by indexed_graph node id, which relies
+ * on this order matching the indexed graph's numbering.
+ * \param g the MXNet computational graph
+ * \param simple_nodes receives the nodes of the undirected graph in top sorted order
+ */
+void CreateSimpleGraph(const Graph& g,
+                       std::vector<SimpleNodePtr>* simple_nodes) {
+  const auto& idx = g.indexed_graph();
+  simple_nodes->reserve(idx.num_nodes());
+  DFSVisit(g.outputs, [&](const NodePtr& node) {
+    SimpleNodePtr snode = SimpleNode::Create();
+    snode->node = node.get();
+    const auto& inputs = snode->node->inputs;
+    for (size_t slot = 0; slot < inputs.size(); ++slot) {
+      const auto src_nid = idx.node_id(inputs[slot].node.get());
+      // the input must already have been visited
+      CHECK_LT(src_nid, simple_nodes->size());
+      // operator[] starts an empty slot list the first time this consumer is seen
+      (*simple_nodes)[src_nid]->outputs[snode->node].push_back(slot);
+    }
+    simple_nodes->emplace_back(std::move(snode));
+  });
+}
+
+/*!
+ * \brief Marks every node in subgraph_nodes as unassigned (label -1) again and
+ * empties subgraph_nodes.
+ */
+void ResetNodeLabels(const nnvm::Graph& g,
+                     const std::vector<SimpleNodePtr>& simple_nodes,
+                     std::vector<nnvm::Node*>* subgraph_nodes) {
+  const auto& idx = g.indexed_graph();
+  for (nnvm::Node* n : *subgraph_nodes) {
+    simple_nodes[idx.node_id(n)]->label = -1;
+  }
+  subgraph_nodes->clear();
+}
+
+/*!
+ * \brief This function traverses the nodes in a computation graph from a starting
+ * node following the input edges and output edges, and marks all nodes that
+ * can be accessed from the starting node. Before the function returns,
+ * it checks whether there is a cycle between the potential subgraph
+ * and the outside nodes. If so, it adds the node that should break the cycle
+ * to excluded_nodes, resets the labels of the visited nodes and returns false.
+ * Otherwise, it sorts subgraph_nodes in topological order and returns true.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether the visited node should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in the top sorted order
+ * \param subgraph_nodes all the nodes belonging to the same subgraph of seed node
+ * \param excluded_nodes set of nodes that should be excluded from the current subgraph.
+ *        It may be null, but then a detected cycle cannot be broken and is reported
+ *        through a CHECK failure.
+ * \return true if the labeled subgraph forms no cycle with the outside nodes
+ */
+bool LabelSubgraph(const Graph& g,
+                   SubgraphSelectorPtr subgraph_selector,
+                   const int label,
+                   const size_t snid,  // simple node id, this is a seed
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<nnvm::Node*>* subgraph_nodes,
+                   std::unordered_set<const nnvm::Node*>* excluded_nodes = nullptr) {
+  const auto& indexed_graph = g.indexed_graph();
+  std::queue<SimpleNode*> node_queue;
+  if (!excluded_nodes || !excluded_nodes->count(simple_nodes[snid]->node)) {
+    CHECK_EQ(simple_nodes[snid]->label, -1);
+    simple_nodes[snid]->label = label;
+    node_queue.push(simple_nodes[snid].get());
+  }
+  // key: nodes that serve as input/output nodes to the subgraph
+  // value: pair of vectors of nodes in the subgraph. The first vector contains the
+  // output nodes of the key in the subgraph, and the second vector contains the
+  // input nodes of the key in the subgraph.
+  // If a non-subgraph node has inputs from the subgraph and the other non-subgraph node
+  // has outputs to the subgraph, and the first non-subgraph node is an ancestor
+  // of the second non-subgraph node, there exists a cycle.
+  // When breaking the cycle, we want to start from removing the node with the largest node id
+  // in the subgraph.
+  std::unordered_map<const nnvm::Node*,
+    std::pair<std::vector<const nnvm::Node*>,
+              std::vector<const nnvm::Node*>>> non_subgraph_node_map;
+  while (!node_queue.empty()) {
+    SimpleNode* cur_node = node_queue.front();
+    node_queue.pop();
+    subgraph_nodes->push_back(cur_node->node);
+    // get qualified adjacent input nodes
+    for (auto& e : cur_node->node->inputs) {
+      const bool select_input = (!excluded_nodes || !excluded_nodes->count(e.node.get()))
+        && subgraph_selector->SelectInput(*cur_node->node, *e.node);
+      if (select_input) {
+        // e.node is a subgraph node
+        const auto nid = indexed_graph.node_id(e.node.get());
+        CHECK_LT(nid, simple_nodes.size());
+        // this node has not been visited yet
+        if (simple_nodes[nid]->label == -1) {
+          simple_nodes[nid]->label = label;
+          node_queue.push(simple_nodes[nid].get());
+        }
+      } else {
+        // e.node is an input node of the subgraph
+        non_subgraph_node_map[e.node.get()].first.push_back(cur_node->node);
+      }
+    }
+    // get qualified output nodes
+    for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
+      const bool select_output = (!excluded_nodes || !excluded_nodes->count(it->first))
+          && subgraph_selector->SelectOutput(*cur_node->node, *it->first);
+      if (select_output) {
+        // it->first is a subgraph node
+        const auto nid = indexed_graph.node_id(it->first);
+        CHECK_LT(nid, simple_nodes.size());
+        // this node has not been visited yet
+        if (simple_nodes[nid]->label == -1) {
+          simple_nodes[nid]->label = label;
+          node_queue.push(simple_nodes[nid].get());
+        }
+      } else {
+        // it->first is an output node of the subgraph
+        non_subgraph_node_map[it->first].second.push_back(cur_node->node);
+      }
+    }
+  }
+  // prepare to check if there is a cycle
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  std::vector<const nnvm::Node*> non_subgraph_nodes;
+  non_subgraph_nodes.reserve(non_subgraph_node_map.size());
+  for (auto& kv : non_subgraph_node_map) {
+    auto& output_nodes = kv.second.first;
+    std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
+    auto& input_nodes = kv.second.second;
+    std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
+    non_subgraph_nodes.push_back(kv.first);
+  }
+  // Subgraph membership lookup for is_ancestor. subgraph_nodes is not modified while
+  // checking for cycles, so the set is built once instead of scanning the vector for
+  // every edge visited by every is_ancestor call.
+  const std::unordered_set<const nnvm::Node*> subgraph_node_set(subgraph_nodes->begin(),
+                                                                 subgraph_nodes->end());
+  // Returns true if `ancestor` is reached by walking input edges from `descendant`
+  // without crossing any subgraph node.
+  auto is_ancestor = [&](const nnvm::Node* ancestor, const nnvm::Node* descendant) {
+    if (ancestor == descendant) return true;
+    // Nodes reachable through several paths (e.g. shortcut connections, or an op that
+    // consumes the same node twice) must be expanded only once. Without this set the
+    // walk can grow exponentially, and the pop counter below can exceed num_nodes()
+    // on a valid DAG and trigger a spurious CHECK failure.
+    std::unordered_set<const nnvm::Node*> visited{descendant};
+    std::stack<const nnvm::Node*> s;
+    s.push(descendant);
+    size_t count = 0;
+    while (!s.empty()) {
+      CHECK_LT(count, indexed_graph.num_nodes()) << "Finding ancestor failed. There is probably"
+                                                    " a loop in the graph";
+      ++count;
+      const nnvm::Node* top = s.top();
+      s.pop();
+      if (top == ancestor) {
+        return true;
+      }
+      for (const auto& entry : top->inputs) {
+        const nnvm::Node* input = entry.node.get();
+        // when searching for the ancestor, the path cannot cross any subgraph node
+        if (!subgraph_node_set.count(input) && visited.insert(input).second) {
+          s.push(input);
+        }
+      }
+    }
+    return false;
+  };
+  std::sort(non_subgraph_nodes.begin(), non_subgraph_nodes.end(), node_cmp);
+  int excluded_node_id = -1;
+  for (size_t i = 0; i < non_subgraph_nodes.size(); ++i) {
+    auto it1 = non_subgraph_node_map.find(non_subgraph_nodes[i]);
+    CHECK(it1 != non_subgraph_node_map.end());
+    auto& output_nodes = it1->second.first;  // has been top sorted
+    auto& input_nodes = it1->second.second;  // has been top sorted
+    if (!output_nodes.empty() && !input_nodes.empty()) {
+      // node i both feeds the subgraph and consumes its outputs:
+      // there is a loop between node i and the subgraph
+      const auto node_id = std::max(indexed_graph.node_id(output_nodes.back()),
+                                    indexed_graph.node_id(input_nodes.back()));
+      excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+    } else if (!input_nodes.empty()) {
+      // node i consumes outputs of the subgraph; find out if there is a node j
+      // which feeds the subgraph and is also a descendant of node i.
+      for (size_t j = i + 1; j < non_subgraph_nodes.size(); ++j) {
+        auto it2 = non_subgraph_node_map.find(non_subgraph_nodes[j]);
+        CHECK(it2 != non_subgraph_node_map.end());
+        // i is topologically before j, j might be a direct/indirect output node of i
+        CHECK_LT(indexed_graph.node_id(it1->first), indexed_graph.node_id(it2->first));
+        if (!it2->second.first.empty() && is_ancestor(it1->first, it2->first)) {
+          // found a loop
+          const auto node_id = std::max(indexed_graph.node_id(input_nodes.back()),
+                                        indexed_graph.node_id(it2->second.first.back()));
+          excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+        }
+      }
+    }
+  }
+
+  if (excluded_node_id != -1) {
+    CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
+    CHECK_NE(excluded_node_id, static_cast<int>(snid))
+      << "A cycle is found in the computational graph between nodes "
+      << simple_nodes[excluded_node_id]->node->attrs.name << " and "
+      << simple_nodes[snid]->node->attrs.name;
+    // excluded_nodes is optional in the signature, but the cycle can only be broken
+    // by recording the offending node there.
+    CHECK(excluded_nodes != nullptr)
+      << "A cycle is found between the subgraph seeded at node "
+      << simple_nodes[snid]->node->attrs.name << " and its neighbors, but no "
+      << "excluded_nodes set was provided to break it";
+    excluded_nodes->insert(simple_nodes[excluded_node_id]->node);
+    ResetNodeLabels(g, simple_nodes, subgraph_nodes);
+    return false;
+  }
+  std::sort(subgraph_nodes->begin(), subgraph_nodes->end(), node_cmp);
+  return true;
+}
+
+/*!
+ * \brief Finds all the nodes belonging to the same subgraph given a seed node.
+ * LabelSubgraph is retried with a growing set of excluded nodes for as long as it
+ * reports a cycle. If no attempt succeeds within the retry limit, the seed node alone
+ * is taken as the subgraph.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether the visited node should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in the top sorted order
+ * \param subgraph_nodes output: subgraph node candidates sorted in the topological order
+ */
+void PreSelectSubgraphNodes(const Graph& g,
+                            SubgraphSelectorPtr subgraph_selector,
+                            const int label,
+                            const size_t snid,
+                            const std::vector<SimpleNodePtr>& simple_nodes,
+                            std::vector<nnvm::Node*>* subgraph_nodes) {
+  std::unordered_set<const nnvm::Node*> excluded_nodes;
+  const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
+  size_t count = 0;
+  bool success = false;
+  while (!success && count < max_num_retry) {
+    success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes,
+                            subgraph_nodes, &excluded_nodes);
+    if (!success) {
+      CHECK(!excluded_nodes.empty());
+      std::string excluded_node_names;
+      for (auto node : excluded_nodes) {
+        excluded_node_names += node->attrs.name + ", ";
+      }
+      LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name
+                << ". Excluding nodes " << excluded_node_names << "and retrying";
+    }
+    ++count;
+  }
+  if (!success) {
+    LOG(INFO) << "Tried " << count << " times of finding subgraphs starting from node "
+              << simple_nodes[snid]->node->attrs.name << " without success because a loop "
+                  "is always found between the subgraph and some other nodes. Will treat "
+                  "seed node " << simple_nodes[snid]->node->attrs.name
+              << " as a subgraph with one node";
+    // LabelSubgraph clears subgraph_nodes whenever it reports a cycle
+    CHECK(subgraph_nodes->empty());
+    simple_nodes[snid]->label = label;
+    subgraph_nodes->push_back(simple_nodes[snid]->node);
+  }
+}
+
+/*!
+ * \brief Splits candidate nodes into connected groups, each of which becomes a subgraph.
+ * Two candidates are connected if one is an input of the other. Each group is labeled
+ * with the current value of *subgraph_id, appended to subgraphs sorted by node id, and
+ * *subgraph_id is incremented once per group.
+ */
+void PostProcessNodeCandidates(const nnvm::Graph& g,
+                               const std::vector<nnvm::Node*>& nodes,
+                               const std::vector<SimpleNodePtr>& simple_nodes,
+                               std::vector<std::vector<SimpleNode*>>* subgraphs,
+                               size_t* subgraph_id) {
+  const auto& idx = g.indexed_graph();
+  // candidates not yet assigned to any group
+  std::unordered_set<nnvm::Node*> unassigned(nodes.begin(), nodes.end());
+  auto by_node_id = [&idx](const SimpleNode* a, const SimpleNode* b) {
+    return idx.node_id(a->node) < idx.node_id(b->node);
+  };
+  for (nnvm::Node* seed : nodes) {
+    if (unassigned.count(seed) == 0) continue;  // already grouped with an earlier seed
+    subgraphs->emplace_back();
+    std::vector<SimpleNode*>& group = subgraphs->back();
+    std::queue<nnvm::Node*> frontier;
+    // Label `n` with the current subgraph id, add it to the group and enqueue it.
+    auto claim = [&](nnvm::Node* n) {
+      SimpleNode* sn = simple_nodes[idx.node_id(n)].get();
+      sn->label = *subgraph_id;
+      group.push_back(sn);
+      frontier.push(n);
+    };
+    CHECK_EQ(unassigned.erase(seed), 1U);
+    claim(seed);
+    while (!frontier.empty()) {
+      nnvm::Node* cur = frontier.front();
+      frontier.pop();
+      // neighbors reached through input edges
+      for (const auto& e : cur->inputs) {
+        auto it = unassigned.find(e.node.get());
+        if (it == unassigned.end()) continue;
+        nnvm::Node* in_node = *it;
+        unassigned.erase(it);
+        claim(in_node);
+      }
+      // neighbors reached through output edges
+      for (const auto& kv : simple_nodes[idx.node_id(cur)]->outputs) {
+        auto it = unassigned.find(kv.first);
+        if (it == unassigned.end()) continue;
+        nnvm::Node* out_node = *it;
+        unassigned.erase(it);
+        claim(out_node);
+      }
+    }
+    ++(*subgraph_id);
+    std::sort(group.begin(), group.end(), by_node_id);
+  }
+  CHECK(unassigned.empty());
+}
+
+/*!
+ * \brief Finds subgraphs with all nodes that meet certain criteria.
+ * All nodes in a subgraph are marked with the same label.
+ * \param g the whole graph
+ * \param subg_prop subgraph property; a fresh selector is created for every node visited
+ * \param simple_nodes all simple nodes in the top sorted order
+ * \param subgraph_nodes output: the nodes of every subgraph found are appended, each
+ *        subgraph sorted in topological order
+ */
+void FindSubgraphs(Graph* g,
+                   const SubgraphProperty &subg_prop,
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<std::vector<SimpleNode*>>* subgraph_nodes) {
+  const auto& indexed_graph = g->indexed_graph();
+  CHECK_EQ(indexed_graph.num_nodes(), simple_nodes.size());
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  size_t subgraph_id = 0;
+  for (size_t i = 0; i < simple_nodes.size(); ++i) {
+    nnvm::Node* node = simple_nodes[i]->node;
+    auto subgraph_selector = subg_prop.CreateSubgraphSelector();
+    if (subgraph_selector->Select(*node) && simple_nodes[i]->label == -1) {
+      // pre-select nodes that can be grouped in a subgraph
+      std::vector<nnvm::Node*> preselected_nodes;
+      PreSelectSubgraphNodes(*g, subgraph_selector, subgraph_id, i, simple_nodes,
+                             &preselected_nodes);
+
+      // filter out unqualified pre-selected nodes
+      std::vector<nnvm::Node*> filtered_nodes = subgraph_selector->Filter(preselected_nodes);
+
+      // Hash sets for the membership tests below; a linear std::find over the vectors
+      // made both checks quadratic in the subgraph size.
+      const std::unordered_set<nnvm::Node*> preselected_set(preselected_nodes.begin(),
+                                                            preselected_nodes.end());
+      // make sure filtered_nodes is a subset of preselected_nodes
+      for (const auto n : filtered_nodes) {
+        CHECK(preselected_set.count(n))
+          << "Node " << n->attrs.name << " is not found in the pre-selected subgraph nodes."
+             " Please make sure that no new nodes were added in your subgraph"
+             " selector's Filter function";
+      }
+
+      // make sure nodes are sorted
+      std::sort(filtered_nodes.begin(), filtered_nodes.end(), node_cmp);
+
+      // reset node labels that are not in filtered nodes
+      const std::unordered_set<nnvm::Node*> filtered_set(filtered_nodes.begin(),
+                                                         filtered_nodes.end());
+      for (const auto n : preselected_nodes) {
+        if (!filtered_set.count(n)) {
+          simple_nodes[indexed_graph.node_id(n)]->label = -1;
+        }
+      }
+      // find out subgraphs from the filtered nodes
+      std::vector<std::vector<SimpleNode*>> subgraphs;
+      PostProcessNodeCandidates(*g, filtered_nodes, simple_nodes, &subgraphs, &subgraph_id);
+      if (!subgraphs.empty()) {
+        subgraph_nodes->insert(subgraph_nodes->end(), subgraphs.begin(), subgraphs.end());
+      }
+    }
+  }
+}
+
+/*!
+ * \brief Orders entries by their position in the topological order of the graph.
+ * Entry ids cannot be used for this, hence the explicit position map.
+ * \param entry_top_order_map topological position of every entry that may be sorted
+ * \param entries node entries sorted in place; each must be a key of the map
+ */
+void SortEntries(const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
+                 std::vector<nnvm::NodeEntry*>* entries) {
+  // Looks up the topological position of an entry, failing if it is unknown.
+  auto position_of = [&entry_top_order_map](const nnvm::NodeEntry* entry) {
+    const auto pos_it = entry_top_order_map.find(entry);
+    CHECK(pos_it != entry_top_order_map.end());
+    return pos_it->second;
+  };
+  std::sort(entries->begin(), entries->end(),
+            [&position_of](const nnvm::NodeEntry* lhs, const nnvm::NodeEntry* rhs) {
+              const size_t lhs_pos = position_of(lhs);
+              return lhs_pos < position_of(rhs);
+            });
+}
+
+/*!
+ * \brief Given a subgraph, find the input entries of a subgraph, i.e. the input
+ * entries of subgraph nodes whose source node lies outside the subgraph.
+ * \param g the whole graph
+ * \param simple_nodes vector of simple nodes in top sorted order
+ * \param subgraph_nodes vector of pointers of simple nodes of a subgraph; all of them
+ *        must carry the same label
+ * \param entry_top_order_map mapping entry pointer to its top sorted position
+ * \param input_entries receives the input entries of the subgraph, sorted in
+ *        topological order
+ */
+void FindInputEntries(const Graph& g,
+                      const std::vector<SimpleNodePtr>& simple_nodes,
+                      const std::vector<SimpleNode*>& subgraph_nodes,
+                      const std::unordered_map<const nnvm::NodeEntry*, size_t>&
+                        entry_top_order_map,
+                      std::vector<nnvm::NodeEntry*>* input_entries) {
+  const auto& indexed_graph = g.indexed_graph();
+  int label = -1;
+  for (SimpleNode* snode : subgraph_nodes) {
+    if (label == -1) {
+      label = snode->label;
+    } else {
+      CHECK_EQ(snode->label, label);
+    }
+    for (auto& e : snode->node->inputs) {
+      // The entry is a subgraph input if its source node is not in indexed_graph
+      // (a subgraph node, i.e. two subgraphs are adjacent) or is a graph node that does
+      // not belong to this subgraph.
+      if (!indexed_graph.exist(e.node.get()) ||
+          simple_nodes[indexed_graph.node_id(e.node.get())]->label != label) {
+        input_entries->push_back(&e);
+      }
+    }
+  }
+  SortEntries(entry_top_order_map, input_entries);
+}
+
+/*!
+ * \brief Given a subgraph, find the output entries of the subgraph: entries
+ * produced by subgraph nodes and consumed by nodes outside the subgraph, plus
+ * any entries of the whole graph's outputs whose source is in the subgraph.
+ * \param g pointer to the whole graph
+ * \param simple_nodes vector of simple nodes in top sorted order
+ * \param subgraph_nodes vector of pointers to the simple nodes of one subgraph;
+ *        all of them must carry the same label (enforced by CHECK_EQ)
+ * \param entry_top_order_map mapping entry pointer to its top sorted position
+ * \param output_entries output: pointers to the output entries of the subgraph,
+ *        sorted by topological position; untouched if subgraph_nodes is empty
+ */
+void FindOutputEntries(Graph* g,
+                       const std::vector<SimpleNodePtr>& simple_nodes,
+                       const std::vector<SimpleNode*>& subgraph_nodes,
+                       const std::unordered_map<const nnvm::NodeEntry*, size_t>&
+                         entry_top_order_map,
+                       std::vector<nnvm::NodeEntry*>* output_entries) {
+  if (subgraph_nodes.empty()) {
+    return;
+  }
+  const auto& indexed_graph = g->indexed_graph();
+  int label = -1;
+  for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
+    if (label != -1) {
+      CHECK_EQ(subgraph_nodes[i]->label, label);
+    } else {
+      label = subgraph_nodes[i]->label;
+    }
+    // Each element maps a consumer node to the indices of its inputs fed by this node.
+    for (const auto& consumer : subgraph_nodes[i]->outputs) {
+      const auto& input_indices = consumer.second;
+      if (!indexed_graph.exist(consumer.first)) {
+        // The consumer is a subgraph node created earlier (two subgraphs are
+        // adjacent): every entry it takes from this node is an output.
+        for (auto idx : input_indices) {
+          output_entries->push_back(&(consumer.first->inputs[idx]));
+        }
+        continue;
+      }
+      // The consumer is an ordinary graph node; only entries feeding a node
+      // outside the current subgraph are outputs.
+      const auto nid = indexed_graph.node_id(consumer.first);
+      if (simple_nodes[nid]->label == label) {
+        continue;
+      }
+      for (auto idx : input_indices) {
+        output_entries->push_back(&(simple_nodes[nid]->node->inputs[idx]));
+      }
+    }
+  }
+  // The subgraph may also produce outputs of the whole graph. Entries whose
+  // source is no longer in the indexed graph were already rewired to an
+  // earlier subgraph node and cannot belong to the current subgraph.
+  for (auto& entry : g->outputs) {
+    const nnvm::Node* src = entry.node.get();
+    if (indexed_graph.exist(src) &&
+        simple_nodes[indexed_graph.node_id(src)]->label == label) {
+      output_entries->push_back(&entry);
+    }
+  }
+  SortEntries(entry_top_order_map, output_entries);
+}
+
+/*!
+ * \brief Cut the given input entries of a subgraph away from their source nodes.
+ * Each entry is first saved into orig_entries (at the same index) and then
+ * overwritten in place with an entry of a newly created variable node, so the
+ * subgraph no longer references nodes outside of it. Nothing is returned.
+ * The new variable's name is the output name of the original entry followed by
+ * a per-name counter starting at 0 ("<name>0", "<name>1", ...), which keeps the
+ * names unique when several entries share the same output name.
+ * \param input_entries pointers to the subgraph's input entries; modified in place
+ * \param orig_entries output: the original entries, resized to input_entries.size()
+ * \param skip_var if true, entries whose source node is a variable are left
+ *        untouched and their slot in orig_entries stays default-constructed
+ */
+void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
+                    std::vector<nnvm::NodeEntry> *orig_entries,
+                    const bool skip_var = false) {
+  orig_entries->resize(input_entries.size());
+  // Per output name, the suffix given to the most recently created variable;
+  // used to deduplicate variable names for entries with the same output name.
+  std::unordered_map<std::string, int> name_count_map;
+  for (size_t i = 0; i < input_entries.size(); ++i) {
+    nnvm::NodeEntry *e = input_entries[i];
+    // If the node is a variable itself, we may want to skip the node.
+    if (e->node->is_variable() && skip_var) {
+      continue;
+    }
+
+    orig_entries->at(i) = *e;
+    nnvm::Symbol sym;
+    sym.outputs.push_back(*e);
+    const auto output_names = sym.ListOutputNames();
+    CHECK_EQ(output_names.size(), 1U);
+    const std::string& var_name = output_names[0];
+    // First occurrence of a name gets suffix 0; each later one bumps the
+    // counter. A single emplace does both the lookup and the insertion.
+    auto inserted = name_count_map.emplace(var_name, 0);
+    if (!inserted.second) {
+      ++(inserted.first->second);
+    }
+    const int suffix = inserted.first->second;
+    nnvm::NodePtr n = nnvm::CreateVariableNode(var_name + std::to_string(suffix));
+    *e = nnvm::NodeEntry{n, 0, 0};
+  }
+}
+
+/*!
+ * \brief Replace a set of nodes belonging to the same subgraph with a subgraph node
+ * and keep the subgraph in the subgraph node. The input entries and output entries
+ * of the subgraph node are kept in the same order as the subgraph's.
+ * \param g the whole graph; the SubgraphProperty stored in its "subgraph_property"
+ *        attribute creates the new subgraph node
+ * \param simple_nodes simple nodes of the whole graph, indexed by the node ids of
+ *        g->indexed_graph()
+ * \param subgraph_nodes the simple nodes that form this subgraph
+ * \param subgraph_id id forwarded to SubgraphProperty::CreateSubgraphNode
+ * \param entry_top_order_map entry-to-topological-position map; the input entries
+ *        of the new subgraph node are added to it
+ */
+void CreateSubgraphNode(Graph* g,
+                        const std::vector<SimpleNodePtr>& simple_nodes,
+                        const std::vector<SimpleNode*>& subgraph_nodes,
+                        const size_t subgraph_id,
+                        std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
+#if DEBUG_SUBGRAPH
+  LOG(INFO) << "Searching for input entries...";
+#endif
+  std::vector<nnvm::NodeEntry*> input_entries;
+  FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries);
+  // Save the external entries in orig_input_entries and, inside the subgraph,
+  // replace them (in place) with entries of fresh variable nodes.
+  std::vector<nnvm::NodeEntry> orig_input_entries;
+  CutGraphInputs(input_entries, &orig_input_entries, false);
+#if DEBUG_SUBGRAPH
+  PrintNodeEntries(input_entries);
+  LOG(INFO) << "Searching for output entries...";
+#endif
+  std::vector<nnvm::NodeEntry*> output_entries;
+  FindOutputEntries(g, simple_nodes, subgraph_nodes, *entry_top_order_map, &output_entries);
+
+  // Create a subgraph for the subgraph node: its outputs are the subgraph's
+  // output entries, in the sorted order produced by FindOutputEntries.
+  nnvm::Symbol sym;
+  sym.outputs.resize(output_entries.size());
+  for (size_t i = 0; i < output_entries.size(); ++i) {
+    sym.outputs[i] = *output_entries[i];
+  }
+  const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
+  nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym, subgraph_id);
+
+  // Connect the external nodes to the subgraph node: every consumer entry that
+  // read the i-th subgraph output now reads output i of n.
+  for (size_t i = 0; i < output_entries.size(); ++i) {
+    *output_entries[i] = nnvm::NodeEntry{n, static_cast<uint32_t>(i), 0};
+  }
+  // n consumes the original external entries, in the same order as the subgraph's inputs.
+  n->inputs = orig_input_entries;
+  const auto& indexed_graph = g->indexed_graph();
+  for (size_t i = 0; i < n->inputs.size(); ++i) {
+    auto& e = n->inputs[i];
+    // update entry_top_order_map with newly created orig_input_entries:
+    // n->inputs[i] inherits the position recorded for the address input_entries[i]
+    auto it = entry_top_order_map->find(input_entries[i]);
+    CHECK(it != entry_top_order_map->end());
+    entry_top_order_map->emplace(&e, it->second);
+    // update input entries' source simple nodes' outputs map, so the source now
+    // feeds n instead of the individual subgraph nodes. Sources missing from the
+    // indexed graph (presumably earlier subgraph nodes) have no simple node.
+    nnvm::Node* node = e.node.get();
+    if (indexed_graph.exist(node)) {
+      const auto nid = indexed_graph.node_id(node);
+      SimpleNode* sn = simple_nodes[nid].get();
+      for (SimpleNode* dest_node : subgraph_nodes) {
+        sn->outputs.erase(dest_node->node);
+      }
+      sn->outputs[n.get()].push_back(i);
+    }
+  }
+#if DEBUG_SUBGRAPH
+  PrintNodeEntries(output_entries);
+#endif
+}
+
+}  // namespace sg
+
+/*!
+ * \brief Assign a topological position to every NodeEntry reachable from the graph
+ * outputs: the entries in g.outputs and the input entries of every reached node.
+ * This is going to be used to sort input/output entries of subgraphs to keep
+ * the topological order unchanged.
+ *
+ * The graph is walked with an iterative depth-first search started from each
+ * output entry. An entry whose source node is reached for the first time gets
+ * its position only after all input entries of that node have been processed
+ * (post-order). An entry whose source node was already visited gets its position
+ * immediately. A position is the map size at insertion time, so starting from an
+ * empty map positions are 0, 1, 2, ...; emplace never overwrites an existing key.
+ *
+ * The map is keyed by entry addresses inside g.outputs and the nodes' inputs
+ * vectors, so the keys stay meaningful only while those vectors are not reallocated.
+ *
+ * \param g the graph to traverse; its nodes must have no control dependencies
+ *        (enforced by a CHECK)
+ * \param entry_top_order_map output map from entry address to position; must not be null
+ */
+void TopSortEntries(const Graph& g,
+                    std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
+  CHECK(entry_top_order_map != nullptr);
+  // Nodes already pushed onto the stack; each node is expanded at most once.
+  std::unordered_set<const nnvm::Node*> visited;
+  // tuple: (graph node, index of node's inputs, node entry as the output of the graph node)
+  std::stack<std::tuple<nnvm::Node*, size_t, const nnvm::NodeEntry*>> s;
+  // Number of input entries to walk for a node; a null node has none.
+  auto in_degree = [] (const nnvm::Node* node)->size_t {
+    if (!node) {
+      return 0;
+    }
+    CHECK_EQ(node->control_deps.size(), 0U);
+    return node->inputs.size();
+  };
+  for (auto& e : g.outputs) {
+    nnvm::Node* node = e.node.get();
+    if (visited.count(node) == 0U) {
+      s.emplace(node, 0U, &e);
+      visited.insert(node);
+    } else {
+      // The entry's source node has been visited before.
+      // Marking the order for it.
+      entry_top_order_map->emplace(&e, entry_top_order_map->size());
+    }
+    while (!s.empty()) {
+      auto& top = s.top();
+      if (std::get<1>(top) == in_degree(std::get<0>(top))) {
+        // The node's inputs have been exhausted: the entry that led to this
+        // node receives its position now (post-order).
+        entry_top_order_map->emplace(std::get<2>(top), entry_top_order_map->size());
+        s.pop();
+      } else {
+        // The node still has input entries not visited.
+        // Advance the node's input cursor before possibly descending.
+        CHECK_LT(std::get<1>(top), std::get<0>(top)->inputs.size());
+        auto& entry = std::get<0>(top)->inputs[std::get<1>(top)++];
+        nnvm::Node* input_node = entry.node.get();
+        if (visited.count(input_node) == 0U) {
+          // The entry's source node has not been visited.
+          // Push the entry to the stack for marking order later.
+          s.emplace(input_node, 0U, &entry);
+          visited.insert(input_node);
+        } else {
+          // The entry's source node has been visited before.
+          // Marking the order for it.
+          entry_top_order_map->emplace(&entry, entry_top_order_map->size());
+        }
+      }
+    }
+  }
+}
+
+/*!
+ * \brief Graph pass that partitions a graph into subgraphs.
+ * The rules come from the SubgraphProperty stored in the graph attribute
+ * "subgraph_property". The property is handed to FindSubgraphs to group nodes,
+ * and every group found is replaced by a single node built through
+ * SubgraphProperty::CreateSubgraphNode (see sg::CreateSubgraphNode).
+ * If the attribute is missing, the graph is returned unchanged.
+ * \param g the graph to partition; consumed by the pass
+ * \return the partitioned graph
+ */
+Graph PartitionGraph(Graph&& g) {
+  if (!g.HasAttr("subgraph_property")) {  // no partitioning rules: keep the graph as is
+    LOG(INFO) << "The graph has no attribute of subgraph_property attached. "
+                 "The original graph is returned.";
+    // g is a named rvalue reference, i.e. an lvalue: a plain `return g;` would
+    // copy the graph's outputs and attribute map instead of moving them.
+    return std::move(g);
+  }
+  using namespace sg;
+  const SubgraphPropertyPtr& subg_prop = g.GetAttr<SubgraphPropertyPtr>("subgraph_property");
+  // top sort NodeEntry of all the nodes' inputs
+  std::unordered_map<const nnvm::NodeEntry*, size_t> entry_top_order_map;
+  TopSortEntries(g, &entry_top_order_map);
+
+  // Create undirected graph for ease of finding subgraphs
+  std::vector<SimpleNodePtr> simple_nodes;
+  CreateSimpleGraph(g, &simple_nodes);
+  std::vector<std::vector<SimpleNode*>> subgraph_nodes;
+  FindSubgraphs(&g, *subg_prop, simple_nodes, &subgraph_nodes);
+  for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
+#if DEBUG_SUBGRAPH
+    std::set<SimpleNode*> simple_node_set(subgraph_nodes[i].begin(), subgraph_nodes[i].end());
+    CHECK_EQ(simple_node_set.size(), subgraph_nodes[i].size());
+    PrintSubgraph(subgraph_nodes[i]);
+#endif
+    // The loop index doubles as the subgraph id.
+    CreateSubgraphNode(&g, simple_nodes, subgraph_nodes[i], i, &entry_top_order_map);
+  }
+  return std::move(g);
+}
+
+NNVM_REGISTER_PASS(PartitionGraph)
+.describe("Partition a graph according to the user defined rules "
+          "in a derived class of SubgraphProperty")
+.set_body(PartitionGraph)
+.set_change_graph(true);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h
new file mode 100644
index 00000000000..cfbc1f83733
--- /dev/null
+++ b/src/operator/subgraph/subgraph_property.h
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
+
+#include <nnvm/node.h>
+#include <dmlc/base.h>
+#include <dmlc/thread_local.h>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Criteria used by the graph partitioning algorithm to pick the nodes
+ * that make up a subgraph.
+ *
+ * The algorithm visits all nodes in topological order. A node becomes the seed
+ * of a new subgraph when it has not been selected before and Select() returns
+ * true for it. From the seed, the graph is traversed breadth-first; at every
+ * step SelectInput() / SelectOutput() decide whether a neighbor of the current
+ * node becomes a candidate. The search continues while new candidates are found
+ * and stops once no further neighbor qualifies. All candidates are then passed
+ * to Filter(), the last chance to drop nodes; by default it keeps all of them.
+ * If filtering disconnects the candidate set, the remaining nodes are
+ * automatically split into several valid subgraphs by connectivity.
+ */
+class SubgraphSelector {
+ public:
+  virtual ~SubgraphSelector() = default;
+  /*!
+   * \brief Whether a subgraph search should start from seed_node.
+   */
+  virtual bool Select(const nnvm::Node &seed_node) = 0;
+  /*!
+   * \brief Whether input_node is selected when the traversal reaches cur_node.
+   * \param cur_node the node currently being expanded
+   * \param input_node an input node of cur_node
+   */
+  virtual bool SelectInput(const nnvm::Node &cur_node, const nnvm::Node &input_node) = 0;
+  /*!
+   * \brief Whether output_node is selected when the traversal reaches cur_node.
+   * \param cur_node the node currently being expanded
+   * \param output_node an output node of cur_node
+   */
+  virtual bool SelectOutput(const nnvm::Node &cur_node, const nnvm::Node &output_node) = 0;
+  /*!
+   * \brief Post-process the pre-selected subgraph nodes.
+   * \param candidates the nodes chosen by the traversal
+   * \return the nodes to keep in the subgraph(s); the default keeps every candidate
+   */
+  virtual std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) {
+    return std::vector<nnvm::Node*>(candidates);
+  }
+};
+
+/*! \brief Shared-ownership handle for a SubgraphSelector. */
+using SubgraphSelectorPtr = std::shared_ptr<SubgraphSelector>;
+
+/*!
+ * \brief This provides a set of properties for partitioning a graph into subgraphs,
+ * reconstructing a new graph from the subgraphs and creating a subgraph
+ * operator to execute the subgraph.
+ */
+class SubgraphProperty {
+ public:
+  // Properties are abstract and owned through SubgraphPropertyPtr
+  // (std::shared_ptr<SubgraphProperty>). A virtual destructor guarantees the
+  // derived destructor runs even when ownership was taken from a base pointer.
+  virtual ~SubgraphProperty() {}
+  /*!
+   * \brief Create the selector that defines the criteria for choosing subgraph nodes.
+   */
+  virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0;
+  /*!
+   * \brief Create an nnvm node for a given subgraph. Here users can customize
+   * how to execute the operators in the subgraph.
+   * \param s symbol whose outputs are the output entries of the subgraph
+   * \param subgraph_id id of the subgraph being replaced
+   * NOTE: default arguments of virtual functions are bound by the static type of
+   * the caller, so overrides should keep the same default.
+   */
+  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s,
+                                           const int subgraph_id = 0) const = 0;
+  /*!
+   * \brief Store value under name, replacing any previous value.
+   * \return *this, so calls can be chained
+   */
+  template<typename T>
+  SubgraphProperty& SetAttr(const std::string& name, const T& value) {
+    attrs_[name] = std::make_shared<dmlc::any>(value);
+    return *this;
+  }
+  /*!
+   * \brief Get the value stored under name as type T.
+   * Fails a CHECK if no attribute with that name has been set.
+   */
+  template<typename T>
+  const T& GetAttr(const std::string& name) const {
+    auto it = attrs_.find(name);
+    CHECK(it != attrs_.end()) << "Cannot find attribute " << name << " in SubgraphProperty";
+    return nnvm::get<T>(*it->second);
+  }
+
+ protected:
+  /*! \brief Named attributes attached through SetAttr. */
+  std::unordered_map<std::string, std::shared_ptr<nnvm::any>> attrs_;
+};
+
+/*! \brief Shared-ownership handle for a SubgraphProperty. */
+using SubgraphPropertyPtr = std::shared_ptr<SubgraphProperty>;
+
+/*!
+ * \brief Registry mapping a property name to the factory function that creates
+ * the corresponding SubgraphProperty. A single instance is reached via Get().
+ */
+class SubgraphPropertyRegistry {
+ public:
+  using SubgraphPropertyCreateFn = SubgraphPropertyPtr (*)();
+
+  /*! \brief Return the single registry instance (a function-local static). */
+  static SubgraphPropertyRegistry* Get() {
+    static SubgraphPropertyRegistry registry;
+    return &registry;
+  }
+
+  /*!
+   * \brief Create a new property with the factory registered under name.
+   * Fails a CHECK if nothing is registered under that name.
+   */
+  SubgraphPropertyPtr CreateSubgraphProperty(const std::string& name) {
+    auto it = prop_fn_map_.find(name);
+    CHECK(it != prop_fn_map_.end()) << "SubgraphProperty " << name
+                                    << " is not found in SubgraphPropertyRegistry";
+    const SubgraphPropertyCreateFn create = it->second;
+    return create();
+  }
+
+  /*!
+   * \brief Register fn under name unless the name is already taken.
+   * \return the factory now registered under name: the existing one if the
+   *         name was registered before, otherwise fn
+   */
+  SubgraphPropertyCreateFn __REGISTER_OR_GET__(const std::string& name,
+                                               SubgraphPropertyCreateFn fn) {
+    const auto existing = prop_fn_map_.find(name);
+    return existing != prop_fn_map_.end() ? existing->second : __REGISTER__(name, fn);
+  }
+
+ private:
+  /*! \brief Insert fn under name; the name must not be registered yet. */
+  SubgraphPropertyCreateFn __REGISTER__(const std::string& name, SubgraphPropertyCreateFn fn) {
+    CHECK_EQ(prop_fn_map_.count(name), 0U) << "Subgraph property " << name
+                                           << " has been registered";
+    SubgraphPropertyCreateFn& slot = prop_fn_map_[name];
+    slot = fn;
+    return slot;
+  }
+
+  // Only Get() constructs the registry; it can be neither copied nor moved.
+  SubgraphPropertyRegistry() = default;
+  SubgraphPropertyRegistry(const SubgraphPropertyRegistry&) = delete;
+  SubgraphPropertyRegistry(SubgraphPropertyRegistry&&) = delete;
+  SubgraphPropertyRegistry& operator=(const SubgraphPropertyRegistry&) = delete;
+  /*! \brief Factory functions keyed by property name. */
+  std::unordered_map<std::string, SubgraphPropertyCreateFn> prop_fn_map_;
+};
+
+/*!
+ * \brief dmlc::ThreadLocalStore holding, per subgraph property name (key), the set
+ * of operator names that should be grouped into subgraphs (value).
+ * It is only used for testing; in practice every backend accelerator should
+ * have its own predefined name set.
+ */
+using SubgraphPropertyOpNameSet =
+  dmlc::ThreadLocalStore<std::unordered_map<std::string, std::unordered_set<std::string>>>;
+
+/*!
+ * \brief Register SubgraphPropertyType in SubgraphPropertyRegistry under the
+ * string name #Name. A unique, unused static variable is defined so that the
+ * registration runs when the enclosing translation unit's statics are initialized.
+ * SubgraphPropertyType must provide a static Create() compatible with
+ * SubgraphPropertyRegistry::SubgraphPropertyCreateFn. Registering an existing
+ * name keeps the factory that was registered first (__REGISTER_OR_GET__).
+ */
+#define MXNET_REGISTER_SUBGRAPH_PROPERTY(Name, SubgraphPropertyType) \
+  static DMLC_ATTRIBUTE_UNUSED auto __make_ ## SubgraphPropertyType ## _ ## Name ## __ = \
+    SubgraphPropertyRegistry::Get()->__REGISTER_OR_GET__(#Name, &SubgraphPropertyType::Create)
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_PROPERTY_H_
diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc
index 92d0958c463..6d669c19bca 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -275,6 +275,64 @@ TEST(Engine, basics) {
   LOG(INFO) << "All pass";
 }
 
+TEST(Engine, VarVersion) {
+  // For every engine type, a var's version must stay 0 while the var is only
+  // read, and grow by one for each operator that writes it.
+  const size_t num_engines = 3;
+  std::vector<mxnet::Engine*> engines(num_engines);
+  engines[0] = mxnet::engine::CreateNaiveEngine();
+  engines[1] = mxnet::engine::CreateThreadedEnginePooled();
+  engines[2] = mxnet::engine::CreateThreadedEnginePerDevice();
+  // Sized by num_engines so the names cannot drift out of sync with engines.
+  const std::string type_names[num_engines] = {"NaiveEngine", "ThreadedEnginePooled",
+                                               "ThreadedEnginePerDevice"};
+  for (size_t k = 0; k < num_engines; ++k) {
+    auto engine = engines[k];
+    std::vector<mxnet::Engine::OprHandle> oprs;
+
+    LOG(INFO) << "Testing var as a read dependency in " << type_names[k];
+    auto var = engine->NewVariable();
+    EXPECT_EQ(var->version(), 0U);
+    for (int i = 0; i < 10; ++i) {
+      oprs.push_back(engine->NewOperator(
+          [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+            Foo(ctx, i);
+            cb();
+          },
+          {var}, {}));
+      engine->Push(oprs.at(i), mxnet::Context{});
+    }
+    engine->WaitForAll();
+    // Reads alone must not change the version.
+    EXPECT_EQ(var->version(), 0U);
+    for (auto&& i : oprs) {
+      engine->DeleteOperator(i);
+    }
+    engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
+    engine->WaitForAll();
+
+    LOG(INFO) << "Testing var as a write dependency in " << type_names[k];
+    var = engine->NewVariable();
+    EXPECT_EQ(var->version(), 0U);
+    oprs.clear();
+    for (int i = 0; i < 10; ++i) {
+      oprs.push_back(engine->NewOperator(
+          [i](mxnet::RunContext ctx, mxnet::Engine::CallbackOnComplete cb) {
+            Foo(ctx, i);
+            cb();
+          },
+          {}, {var}));
+      engine->Push(oprs.at(i), mxnet::Context{});
+    }
+    engine->WaitForAll();
+    // One version bump per completed write.
+    EXPECT_EQ(var->version(), 10U);
+    for (auto&& i : oprs) {
+      engine->DeleteOperator(i);
+    }
+    engine->DeleteVariable([](mxnet::RunContext) {}, mxnet::Context{}, var);
+    engine->WaitForAll();
+
+    var = nullptr;
+    oprs.clear();
+    LOG(INFO) << "All pass";
+  }
+  // The engines were created directly through the Create*Engine factories
+  // (not Engine::Get()), so this test must release them itself.
+  for (mxnet::Engine* engine : engines) {
+    delete engine;
+  }
+}
+
 #ifdef _OPENMP
 
 struct TestSaveAndRestoreOMPState {
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 5612b0a647e..0ff33e1e409 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -41,6 +41,7 @@
 from test_sparse_ndarray import *
 from test_sparse_operator import *
 from test_ndarray import *
+from test_subgraph_op import *
 
 set_default_context(mx.gpu(0))
 del test_support_vector_machine_l1_svm  # noqa
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
new file mode 100644
index 00000000000..40d609ad354
--- /dev/null
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import ctypes
+import mxnet as mx
+from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array, c_str
+from mxnet.symbol import Symbol
+import numpy as np
+from mxnet.test_utils import assert_almost_equal
+
+
+def test_subgraph_exe():
+    def _check_subgraph_exe1(sym, op_names):
+        """Use the partitioned sym to simple_bind an executor and compare the outputs
+        with those of the original executor"""
+        out = SymbolHandle()
+        check_call(_LIB.MXPartitionGraphByOpNames(sym.handle, c_str('default'), mx_uint(len(op_names)),
+                                                  c_str_array(op_names), ctypes.byref(out)))
+
+        partitioned_sym = Symbol(out)
+        assert partitioned_sym.list_inputs() == sym.list_inputs()
+        assert partitioned_sym.list_arguments() == sym.list_arguments()
+        assert partitioned_sym.list_auxiliary_states() == sym.list_auxiliary_states()
+        exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+        partitioned_exe = partitioned_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+        input_names = sym.list_inputs()
+        for name in input_names:
+            if name in exe.arg_dict:
+                exe.arg_dict[name][:] = mx.nd.random.uniform(shape=exe.arg_dict[name].shape)
+                partitioned_exe.arg_dict[name][:] = exe.arg_dict[name]
+            else:
+                assert name in exe.aux_dict
+                exe.aux_dict[name][:] = mx.nd.random.uniform(shape=exe.aux_dict[name].shape)
+                partitioned_exe.aux_dict[name][:] = exe.aux_dict[name]
+        exe.forward()
+        partitioned_exe.forward()
+        assert len(exe.outputs) == len(partitioned_exe.outputs)
+        for i in range(len(exe.outputs)):
+            assert_almost_equal((exe.outputs[i] - partitioned_exe.outputs[i]).abs().sum().asnumpy(),
+                                np.zeros(shape=(1,)))
+
+    def _check_subgraph_exe2(sym, op_names):
+        """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning in simple_bind
+        and compare results of the partitioned sym and the original sym."""
+        def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None):
+            if subgraph_backend is not None:
+                os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
+                check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
+                                                             c_str_array(op_names)))
+            exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+            input_names = sym.list_inputs()
+            for name in input_names:
+                if name in exe.arg_dict:
+                    exe.arg_dict[name][:] = mx.nd.random.uniform(shape=exe.arg_dict[name].shape)\
+                        if original_exec is None else original_exec.arg_dict[name]
+                else:
+                    assert name in exe.aux_dict
+                    exe.aux_dict[name][:] = mx.nd.random.uniform(shape=exe.aux_dict[name].shape)\
+                        if original_exec is None else original_exec.aux_dict[name]
+            exe.forward()
+            if subgraph_backend is not None:
+                check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
+                del os.environ['MXNET_SUBGRAPH_BACKEND']
+            return exe
+
+        original_exec = get_executor(sym)
+        partitioned_exec = get_executor(sym, 'default', op_names, original_exec)
+        outputs1 = original_exec.outputs
+        outputs2 = partitioned_exec.outputs
+        assert len(outputs1) == len(outputs2)
+        for i in range(len(outputs1)):
+            assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))
+
+    def _check_subgraph_exe3(sym, op_names):
+        """Use the partitioned sym to bind an executor and compare the outputs
+        with those of the original executor"""
+        out = SymbolHandle()
+        check_call(_LIB.MXPartitionGraphByOpNames(sym.handle, c_str('default'), mx_uint(len(op_names)),
+                                                  c_str_array(op_names), ctypes.byref(out)))
+
+        partitioned_sym = Symbol(out)
+        input_names = sym.list_inputs()
+        arg_names = sym.list_arguments()
+        aux_names = sym.list_auxiliary_states()
+        assert partitioned_sym.list_inputs() == input_names
+        assert partitioned_sym.list_arguments() == arg_names
+        assert partitioned_sym.list_auxiliary_states() == aux_names
+        arg_shapes, _, aux_shapes = sym.infer_shape()
+        arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
+        aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
+        exe = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
+        partitioned_exe = partitioned_sym.bind(ctx=mx.current_context(), args=arg_array,
+                                               aux_states=aux_array, grad_req='null')
+        exe.forward()
+        partitioned_exe.forward()
+        assert len(exe.outputs) == len(partitioned_exe.outputs)
+        for i in range(len(exe.outputs)):
+            assert_almost_equal((exe.outputs[i] - partitioned_exe.outputs[i]).abs().sum().asnumpy(),
+                                np.zeros(shape=(1,)))
+
+    def _check_subgraph_exe4(sym, op_names):
+        """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning in bind
+        and compare results of the partitioned sym and the original sym."""
+        def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None):
+            if subgraph_backend is not None:
+                os.environ['MXNET_SUBGRAPH_BACKEND'] = subgraph_backend
+                check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
+                                                             c_str_array(op_names)))
+            arg_shapes, _, aux_shapes = sym.infer_shape()
+            if subgraph_backend is None:
+                arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
+                aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
+            else:
+                arg_array = None
+                aux_array = None
+            exe = sym.bind(ctx=mx.current_context(),
+                           args=arg_array if subgraph_backend is None else original_exec.arg_arrays,
+                           aux_states=aux_array if subgraph_backend is None else original_exec.aux_arrays,
+                           grad_req='null')
+            exe.forward()
+            if subgraph_backend is not None:
+                check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
+                del os.environ['MXNET_SUBGRAPH_BACKEND']
+            return exe
+
+        original_exec = get_executor(sym)
+        partitioned_exec = get_executor(sym, 'default', op_names, original_exec)
+        outputs1 = original_exec.outputs
+        outputs2 = partitioned_exec.outputs
+        assert len(outputs1) == len(outputs2)
+        for i in range(len(outputs1)):
+            assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))
+
+    def check_subgraph_exe(sym, op_names):
+        _check_subgraph_exe1(sym, op_names)
+        _check_subgraph_exe2(sym, op_names)
+        _check_subgraph_exe3(sym, op_names)
+        _check_subgraph_exe4(sym, op_names)
+
+    def test_network_structure_1():
+        data1 = mx.sym.var('data1', shape=(2, 3, 10, 10))
+        data2 = mx.sym.var('data2')
+        conv1 = mx.sym.Convolution(data=data1, weight=data2, no_bias=True, kernel=(2, 2), num_filter=1)
+        conv2 = mx.sym.Convolution(data=data2, no_bias=True, kernel=(1, 1), num_filter=1)
+        out = mx.sym.Group([conv1, conv2])
+        check_subgraph_exe(out, ['Convolution'])
+
+    def test_network_structure_2():
+        # this tests whether the partitioning algorithm can deal with cycles
+        data = mx.sym.var('data', shape=(2, 3, 10, 10))
+        ret = mx.sym.exp(data)
+        ret1 = mx.sym.cos(ret)
+        ret2 = mx.sym.sin(ret)
+        ret = ret1 + ret2
+        check_subgraph_exe(ret, ['exp', 'sin', '_Plus', 'elemwise_add', '_plus'])
+        check_subgraph_exe(ret, ['exp', 'cos', '_Plus', 'elemwise_add', '_plus'])
+
+    def test_network_structure_3():
+        # this tests whether the partitioned sym can distinguish in_args and aux_states
+        data = mx.sym.var('data', shape=(2, 3, 10, 10))
+        ret = mx.sym.exp(data)
+        ret1 = mx.sym.cos(ret)
+        ret2 = mx.sym.sin(ret)
+        ret = ret1 + ret2
+        ret = mx.sym.BatchNorm(ret)
+        ret = mx.sym.BatchNorm(ret)
+        check_subgraph_exe(ret, ['exp', 'sin', '_Plus', 'elemwise_add', '_plus'])
+        check_subgraph_exe(ret, ['exp', 'cos', '_Plus', 'elemwise_add', '_plus'])
+        check_subgraph_exe(ret, ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm'])
+        check_subgraph_exe(ret, ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm'])
+        check_subgraph_exe(ret, ['exp', 'BatchNorm'])
+        check_subgraph_exe(ret, ['BatchNorm'])
+
+    def test_network_structure_4():
+        # the last op has multiple duplicate outputs
+        data = mx.sym.var('data', shape=(2, 3, 10, 10))
+        ret = mx.sym.exp(data)
+        ret = mx.sym.Group([ret, ret, ret])
+        check_subgraph_exe(ret, ['exp'])
+
+    def test_network_structure_5():
+        # the subgraph has two duplicate input entries
+        data = mx.sym.var('data', shape=(2, 3, 10, 10))
+        ret = data + data
+        check_subgraph_exe(ret, ['_plus', '_Plus', 'elemwise_add'])
+
+    def test_network_structure_6():
+        def get_graph():
+            data1 = mx.sym.Variable('data1', shape=(3, 3, 10, 10), dtype=np.float32)
+            data2 = mx.sym.Variable('data2', shape=(1, 0, 2, 2))
+            data3 = mx.sym.sin(data2)
+            conv = mx.sym.Convolution(data=data1, weight=data3, kernel=(2, 2), num_filter=1)
+            rets = [(conv, []),
+                    (conv, [mx.sym.sin.__name__]),
+                    (conv, [mx.sym.Convolution.__name__]),
+                    (conv, [mx.sym.sin.__name__, mx.sym.Convolution.__name__])]
+            return rets
+
+        for sym, op_names in get_graph():
+            check_subgraph_exe(sym, op_names)
+
+    def test_network_structure_7():
+        # in this graph, the subgraph node and the other two external nodes form a cycle
+        data = mx.sym.Variable('data', shape=(1,))
+        ret1 = mx.sym.sin(data)
+        ret2 = mx.sym.cos(ret1)
+        for _ in range(5):
+            ret2 = mx.sym.cos(ret2)
+        ret = ret1 + ret2
+        check_subgraph_exe(ret, ['sin', 'elemwise_add', '_plus', '_Plus'])
+
+    test_network_structure_1()
+    test_network_structure_2()
+    test_network_structure_3()
+    test_network_structure_4()
+    test_network_structure_5()
+    test_network_structure_6()
+    test_network_structure_7()
+
+
+if __name__ == '__main__':
+    # Allow running this file directly: run the module's tests with nose.
+    import nose
+    nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services