Posted to commits@mxnet.apache.org by re...@apache.org on 2018/06/21 18:40:39 UTC

[incubator-mxnet] branch subgraph updated: Graph partitioner and subgraph op (#11251)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch subgraph
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/subgraph by this push:
     new 9567bd7  Graph partitioner and subgraph op (#11251)
9567bd7 is described below

commit 9567bd7bcb70bf0ccfdbb24a7212d723200e18bd
Author: reminisce <wu...@gmail.com>
AuthorDate: Thu Jun 21 11:40:31 2018 -0700

    Graph partitioner and subgraph op (#11251)
    
    Graph partitioner and subgraph op
---
 example/subgraph_op/common                   |   1 +
 example/subgraph_op/imagenet_inference.py    | 166 +++++++
 include/mxnet/c_api.h                        |   5 +
 include/mxnet/engine.h                       |  22 +-
 include/mxnet/ndarray.h                      |   4 +
 include/mxnet/op_attr_types.h                |   7 +-
 src/c_api/c_api_symbolic.cc                  |  25 +
 src/engine/engine_impl.h                     |  10 +
 src/engine/naive_engine.cc                   |  31 +-
 src/engine/threaded_engine.cc                |  10 +-
 src/engine/threaded_engine.h                 |   1 +
 src/executor/graph_executor.cc               |   3 +
 src/imperative/imperative_utils.h            |  21 +-
 src/ndarray/ndarray.cc                       |   2 +
 src/operator/subgraph/common.h               | 270 +++++++++++
 src/operator/subgraph/default_subgraph_op.cc | 113 +++++
 src/operator/subgraph/default_subgraph_op.cu |  41 ++
 src/operator/subgraph/default_subgraph_op.h  | 127 +++++
 src/operator/subgraph/partition_graph.cc     | 687 +++++++++++++++++++++++++++
 tests/python/gpu/test_operator_gpu.py        |   1 +
 tests/python/unittest/test_gluon_trainer.py  |   1 +
 tests/python/unittest/test_subgraph_op.py    | 135 ++++++
 22 files changed, 1672 insertions(+), 11 deletions(-)

diff --git a/example/subgraph_op/common b/example/subgraph_op/common
new file mode 120000
index 0000000..cafb914
--- /dev/null
+++ b/example/subgraph_op/common
@@ -0,0 +1 @@
+../image-classification/common
\ No newline at end of file
diff --git a/example/subgraph_op/imagenet_inference.py b/example/subgraph_op/imagenet_inference.py
new file mode 100644
index 0000000..8a38cff
--- /dev/null
+++ b/example/subgraph_op/imagenet_inference.py
@@ -0,0 +1,166 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse
+import logging
+import os
+import time
+import mxnet as mx
+from common import modelzoo
+from mxnet import nd
+from mxnet.contrib.quantization import *
+from mxnet.base import _LIB
+
+
+def download_dataset(dataset_url, dataset_dir, logger=None):
+    if logger is not None:
+        logger.info('Downloading dataset for inference from %s to %s' % (dataset_url, dataset_dir))
+    mx.test_utils.download(dataset_url, dataset_dir)
+
+
+def download_model(model_name, logger=None):
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    model_path = os.path.join(dir_path, 'model')
+    if logger is not None:
+        logger.info('Downloading model %s... into path %s' % (model_name, model_path))
+    return modelzoo.download_model(model_name, model_path)
+
+
+def advance_data_iter(data_iter, n):
+    assert n >= 0
+    if n == 0:
+        return data_iter
+    has_next_batch = True
+    while has_next_batch:
+        try:
+            data_iter.next()
+            n -= 1
+            if n == 0:
+                return data_iter
+        except StopIteration:
+            has_next_batch = False
+
+
+def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples, logger=None):
+    metrics = [mx.metric.create('acc'),
+               mx.metric.create('top_k_accuracy', top_k=5)]
+    if not isinstance(metrics, list):
+        metrics = [metrics, ]
+    mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name, ])
+    mod.bind(for_training=False,
+             data_shapes=data.provide_data,
+             label_shapes=data.provide_label)
+    mod.set_params(arg_params, aux_params)
+
+    tic = time.time()
+    num = 0
+    for batch in data:
+        mod.forward(batch, is_train=False)
+        for m in metrics:
+            mod.update_metric(m, batch.label)
+        num += batch_size
+        if max_num_examples is not None and num >= max_num_examples:
+            break
+
+    speed = num / (time.time() - tic)
+
+    if logger is not None:
+        logger.info('Finished inference with %d images' % num)
+        logger.info('Finished with %f images per second', speed)
+        for m in metrics:
+            logger.info(m.get())
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Score a model on a dataset')
+    parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
+                        help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
+    parser.add_argument('--batch-size', type=int, default=32)
+    parser.add_argument('--label-name', type=str, default='softmax_label')
+    parser.add_argument('--dataset', type=str, required=True, help='dataset path')
+    parser.add_argument('--rgb-mean', type=str, default='0,0,0')
+    parser.add_argument('--image-shape', type=str, default='3,224,224')
+    parser.add_argument('--data-nthreads', type=int, default=60, help='number of threads for data decoding')
+    parser.add_argument('--num-skipped-batches', type=int, default=0, help='number of batches to skip before running inference')
+    parser.add_argument('--num-inference-batches', type=int, required=True, help='number of batches used for inference')
+    parser.add_argument('--shuffle-dataset', action='store_true', default=True,
+                        help='shuffle the calibration dataset')
+    parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304,
+                        help='shuffling chunk seed, see'
+                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
+                             ' for more details')
+    parser.add_argument('--shuffle-seed', type=int, default=48564309,
+                        help='shuffling seed, see'
+                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
+                             ' for more details')
+
+    args = parser.parse_args()
+
+    logging.basicConfig()
+    logger = logging.getLogger('logger')
+    logger.setLevel(logging.INFO)
+    data_nthreads = args.data_nthreads
+    batch_size = args.batch_size
+    logger.info('batch size = %d for inference' % batch_size)
+
+    rgb_mean = args.rgb_mean
+    logger.info('rgb_mean = %s' % rgb_mean)
+    rgb_mean = [float(i) for i in rgb_mean.split(',')]
+    mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}
+
+    label_name = args.label_name
+    logger.info('label_name = %s' % label_name)
+
+    image_shape = args.image_shape
+    data_shape = tuple([int(i) for i in image_shape.split(',')])
+    logger.info('Input data shape = %s' % str(data_shape))
+
+    dataset = args.dataset
+    download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
+    logger.info('Dataset for inference: %s' % dataset)
+
+    # creating data iterator
+    data = mx.io.ImageRecordIter(path_imgrec=dataset,
+                                 label_width=1,
+                                 preprocess_threads=data_nthreads,
+                                 batch_size=batch_size,
+                                 data_shape=data_shape,
+                                 label_name=label_name,
+                                 rand_crop=False,
+                                 rand_mirror=False,
+                                 shuffle=args.shuffle_dataset,
+                                 shuffle_chunk_seed=args.shuffle_chunk_seed,
+                                 seed=args.shuffle_seed,
+                                 **mean_args)
+
+    # download model
+    prefix, epoch = download_model(model_name=args.model, logger=logger)
+    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+    op_names = ['BatchNorm', 'Convolution', 'Pooling', 'Activation']
+    out = SymbolHandle()
+    check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)), c_str_array(op_names),
+                                     ctypes.byref(out)))
+    psym = Symbol(out)
+
+    # make sure that fp32 inference works on the same images as calibrated quantized model
+    logger.info('Skipping the first %d batches' % args.num_skipped_batches)
+    data = advance_data_iter(data, args.num_skipped_batches)
+
+    num_inference_images = args.num_inference_batches * batch_size
+    logger.info('Running model %s for inference' % args.model)
+    score(psym, arg_params, aux_params, data, [mx.gpu(0)], label_name,
+          max_num_examples=num_inference_images, logger=logger)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 4dd858a..8a714a9 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1457,6 +1457,11 @@ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
                                                const float* high_quantiles,
                                                SymbolHandle* ret_sym_handle);
 
+MXNET_DLL int MXPartitionGraph(SymbolHandle sym_handle,
+                               const mx_uint num_ops,
+                               const char** op_names,
+                               SymbolHandle* ret_sym_handle);
+
 //--------------------------------------------
 // Part 4: Executor interface
 //--------------------------------------------
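
The new C API entry point can be exercised directly from C++. A minimal
sketch follows; it assumes a valid SymbolHandle obtained elsewhere (e.g.
from MXSymbolCreateFromFile), and PartitionForOps is a hypothetical helper,
not part of this commit:

    #include <mxnet/c_api.h>
    #include <cstdio>

    // Hypothetical helper: group subgraphs made of the listed operators
    // under _default_subgraph_op nodes.
    int PartitionForOps(SymbolHandle sym, SymbolHandle* out) {
      const char* op_names[] = {"Convolution", "BatchNorm", "Activation", "Pooling"};
      if (MXPartitionGraph(sym, 4, op_names, out) != 0) {
        std::fprintf(stderr, "%s\n", MXGetLastError());
        return -1;
      }
      return 0;
    }
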
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index fd1fe89..2424a67 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -41,8 +41,26 @@ class Engine;
 
 /*! \brief namespace of engine internal types. */
 namespace engine {
-/*! \brief Internal representation of variable. */
-struct Var;
+/*! \brief base class of engine variables.*/
+struct Var {
+  virtual uint32_t version() {
+    return version_;
+  }
+  virtual ~Var() = default;
+  /*!
+   * \brief cast variable to derived type T
+   * \tparam T the type we want to cast into.
+   * \return A casted variable.
+   */
+  template <typename T>
+  inline T* Cast();
+  /*!
+   * \brief version number of the var. Every time the object it is associated with
+   * is modified, the version number is incremented by 1.
+   */
+  uint32_t version_{0};
+};  // struct Var
+
 /*! \brief Internal representation of operator.  */
 struct Opr;
 /*! \brief Variable pointer type, usually hold by user used to specify dependencies. */
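
The contract for version_ is that the engine increments it exactly once per
completed write on the variable. A toy illustration of that invariant
(ToyVar is hypothetical, and real code never mutates version_ outside the
engine):

    #include <mxnet/engine.h>
    #include <cassert>
    #include <cstdint>

    struct ToyVar : mxnet::engine::Var {};

    void VersionInvariantSketch() {
      ToyVar v;
      const uint32_t before = v.version();  // 0 on creation
      ++v.version_;  // what the engine does when a write on the var completes
      assert(v.version() == before + 1);
    }
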
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index faffe1b..f73a6ed 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -338,6 +338,10 @@ class NDArray {
   inline size_t byte_offset() const {
     return byte_offset_;
   }
+  /*! \brief return var version of the NDArray*/
+  inline uint32_t version() const {
+    return var()->version();
+  }
   /*!
    * \brief save the content into binary stream
    * \param strm the output stream
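
A hypothetical usage sketch of the new accessor (not part of this commit):
snapshot the version, schedule an in-place write, and observe the engine
bump it once the write completes:

    #include <mxnet/ndarray.h>
    #include <cstdint>

    void VersionSketch() {
      mxnet::NDArray a(mxnet::TShape({2, 2}), mxnet::Context::CPU());
      const uint32_t v0 = a.version();
      a += 1.0f;                  // schedules an in-place write on a's var
      a.WaitToRead();             // block until the write has completed
      CHECK_GT(a.version(), v0);  // the engine incremented the version
    }
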
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index f4694ef..ebe8249 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -98,7 +98,12 @@ enum class ExecType {
    *  In current implementation, copy operator is specially handled by executor.
    *  This flag is used for special case treatment and future extension of different copy ops.
    */
-  kCrossDeviceCopy
+  kCrossDeviceCopy,
+  /*!
+   * A subgraph execution should happen in the main thread, instead of
+   * in the execution engine.
+   */
+  kSubgraphExec,
 };
 
 /*! \brief the dispatch mode of the operator */
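
Operators opt into this execution mode through the FExecType attribute. A
hedged registration sketch; _my_subgraph_op is a hypothetical name, and the
real registration for _default_subgraph_op appears later in this commit:

    NNVM_REGISTER_OP(_my_subgraph_op)
    .set_attr<mxnet::FExecType>("FExecType",
        [](const nnvm::NodeAttrs& attrs) {
          // Run on the main thread; the subgraph executor handles its own
          // dependencies internally.
          return mxnet::ExecType::kSubgraphExec;
        });
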
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index e5e9b52..2f8c4f5 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -31,6 +31,7 @@
 #include "./c_api_common.h"
 #include "../operator/operator_common.h"
 #include "../executor/exec_pass.h"
+#include "../operator/subgraph/default_subgraph_op.h"
 
 namespace mxnet {
 namespace op {
@@ -625,3 +626,27 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
   *ret_qsym_handle = s;
   API_END_HANDLE_ERROR(delete s);
 }
+
+int MXPartitionGraph(SymbolHandle sym_handle,
+                     const mx_uint num_ops,
+                     const char** op_names,
+                     SymbolHandle* ret_sym_handle) {
+  nnvm::Symbol* s = new nnvm::Symbol();
+  API_BEGIN();
+  std::unordered_set<std::string> op_name_set;
+  for (size_t i = 0; i < num_ops; ++i) {
+    op_name_set.emplace(op_names[i]);
+  }
+  nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
+  *s = sym->Copy();
+  nnvm::Graph g = Symbol2Graph(*s);
+  if (!op_name_set.empty()) {
+    mxnet::op::SubgraphPropertyPtr property
+        = std::make_shared<mxnet::op::DefaultSubgraphProperty>(op_name_set);
+    g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
+  }
+  g = ApplyPass(std::move(g), "PartitionGraph");
+  s->outputs = g.outputs;
+  *ret_sym_handle = s;
+  API_END_HANDLE_ERROR(delete s);
+}
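
The same pipeline can be driven from C++ without going through the C API:
copy the symbol into a graph, attach a subgraph_property attribute, and
apply the PartitionGraph pass. A condensed sketch, assuming the same
includes as this file:

    nnvm::Graph PartitionSketch(const nnvm::Symbol& sym,
                                const std::unordered_set<std::string>& op_names) {
      nnvm::Graph g = Symbol2Graph(sym);
      mxnet::op::SubgraphPropertyPtr property =
          std::make_shared<mxnet::op::DefaultSubgraphProperty>(op_names);
      g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
      return ApplyPass(std::move(g), "PartitionGraph");
    }
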
diff --git a/src/engine/engine_impl.h b/src/engine/engine_impl.h
index b3ec34d..9219b91 100644
--- a/src/engine/engine_impl.h
+++ b/src/engine/engine_impl.h
@@ -33,8 +33,12 @@
 namespace mxnet {
 namespace engine {
 
+#if 0
 /*! \brief base class of engine variables, used for type checking */
 struct Var {
+  virtual uint32_t version() {
+    return version_;
+  }
 #if ENGINE_DEBUG
   virtual ~Var() = default;
 #endif  // ENGINE_DEBUG
@@ -45,7 +49,13 @@ struct Var {
    */
   template <typename T>
   inline T* Cast();
+  /*!
+   * \brief version number of the var. Every time the object it is associated with
+   * is modified, the version number is incremented by 1.
+   */
+  uint32_t version_{0};
 };  // struct Var
+#endif
 
 /*! \brief base class of engine operators, used for type checking */
 struct Opr {
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 8196af2..e0a47fa 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -28,10 +28,24 @@
 #include "./engine_impl.h"
 #include "../profiler/profiler.h"
 #include "./openmp.h"
+#include "../common/object_pool.h"
 
 namespace mxnet {
 namespace engine {
 
+/*!
+ * \brief var used in Naive Engine for tracking the version
+ * of the objects it is associated with.
+ */
+class NaiveVar final
+    : public Var, public common::ObjectPoolAllocatable<NaiveVar> {
+ public:
+  inline static NaiveVar* CastFromBase(Var* ptr) {
+    return ptr->Cast<NaiveVar>();
+  }
+};  // class NaiveVar
+
+
 // implement naive engine
 class NaiveEngine final : public Engine {
  public:
@@ -71,8 +85,11 @@ class NaiveEngine final : public Engine {
 
   // new variables
   VarHandle NewVariable() override {
+    return NaiveVar::New();
+#if 0
     size_t v = ++counter_;
     return reinterpret_cast<VarHandle>(v);
+#endif
   }
 
   OprHandle NewOperator(AsyncFn fn,
@@ -165,14 +182,26 @@ class NaiveEngine final : public Engine {
     }
     CHECK(this->req_completed_)
         << "NaiveEngine only support synchronize Push so far";
+    // increment var version
+    for (auto var : mutable_vars) {
+      ++var->version_;
+    }
     if (profiling) {
       opr->opr_profile->stop();
     }
   }
 
   void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override {
+    NaiveVar* naive_var = NaiveVar::CastFromBase(var);
+    this->PushAsync([delete_fn, naive_var](RunContext ctx, CallbackOnComplete on_complete) mutable {
+        delete_fn(ctx);
+        NaiveVar::Delete(naive_var);
+        on_complete();
+      }, exec_ctx, {}, {var}, FnProperty::kDeleteVar, 0, "DeleteVariable");
+#if 0
     this->PushSync(delete_fn, exec_ctx, {}, {var},
                    FnProperty::kNormal, 0, "DeleteVariable");
+#endif
   }
 
   void WaitForVar(VarHandle var) override {
@@ -192,8 +221,6 @@ class NaiveEngine final : public Engine {
   }
   // whether action is completed
   bool req_completed_;
-  // counter
-  std::atomic<size_t> counter_{0};
   /*! \brief whether it is during shutdown phase*/
   std::atomic<bool> shutdown_phase_{false};
   // CPU stream
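
NaiveVar gets pooled allocation from ObjectPoolAllocatable, so variables
are recycled rather than heap-allocated and freed on every use. A toy
sketch of that pattern (ToyVar is hypothetical):

    #include <mxnet/engine.h>
    #include "../common/object_pool.h"  // same header this file includes

    // Hypothetical pooled type: New() draws from a shared free list and
    // Delete() returns the object to the pool instead of freeing it.
    class ToyVar final : public mxnet::engine::Var,
                         public mxnet::common::ObjectPoolAllocatable<ToyVar> {};

    void PoolSketch() {
      ToyVar* v = ToyVar::New();  // pooled allocation, like NaiveVar::New()
      ++v->version_;              // usable like any other engine Var
      ToyVar::Delete(v);          // recycled for a later New()
    }
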
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index e70cc19..bd11697 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -130,6 +130,9 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
     assert(pending_write_ != nullptr);
     CHECK_EQ(num_pending_reads_, kWriteTriggered);
 
+    // increment version number
+    ++version_;
+
     // really delete
     if (to_delete_) {
       VersionedVarBlock *head = pending_write_->next;
@@ -164,7 +167,7 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
   }
   // This is outside of lock scope
  // Be very careful, pending_write_ and num_pending_reads_
-  // can change now, do not reply ont the two variables.
+  // can change now, do not rely on these two variables.
   // The linked list \in [old_pending_write, end_of_read_chain)
   // is already detached from this Var.
   // So it is safe to modify these
@@ -196,6 +199,11 @@ inline bool ThreadedVar::ready_to_read() {
   return this->is_ready_to_read();
 }
 
+inline uint32_t ThreadedVar::version() {
+  std::lock_guard<std::mutex> lock{mutex_};
+  return this->version_;
+}
+
 // implementation of threaded engine
 ThreadedVar* ThreadedEngine::NewVariable() {
   return ThreadedVar::New(VersionedVarBlock::New());
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 428f0d8..7730c06 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -162,6 +162,7 @@ class ThreadedVar final
   inline void SetToDelete();
   /*! \return whether this variable is ready to read. */
   inline bool ready_to_read();
+  inline uint32_t version() override;
   /*!
    * \brief Cast a Var pointer to ThreadedVar pointer
    * \param ptr pointer from base.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 831b5f9..ae05fe4 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1614,6 +1614,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
       CHECK_EQ(opnode.exec->in_array.size(), 1U);
       CHECK_EQ(opnode.exec->out_array.size(), 1U);
       CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
+    } else if (opnode.exec->exec_type() == ExecType::kSubgraphExec) {
+      // If the node contains a subgraph, we can't execute it in the engine.
+      opnode.exec->Run(opnode.exec->op_ctx.run_ctx, false);
     } else if (opnode.cached_opr != nullptr) {
       bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
       Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 726531d..08ea05a 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -434,7 +434,8 @@ inline void PushFComputeEx(const FComputeEx& fn,
       }
     };
 
-  if (exec_type == ExecType::kCrossDeviceCopy) {
+  if (exec_type == ExecType::kCrossDeviceCopy
+      || exec_type == ExecType::kSubgraphExec) {
     run(RunContext{ctx, nullptr});
   } else {
     CHECK(exec_type == ExecType::kSync);
@@ -475,12 +476,18 @@ inline void PushOperator(const OpStatePtr& state,
       InvalidateOutputs(outputs, req);
 #endif
       fcompute_ex(state, opctx, inputs, req, outputs);
-      if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) {
+      if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync
+          && rctx.get_stream<gpu>()) {
         rctx.get_stream<gpu>()->Wait();
       }
     };
 
-    if (exec_type == ExecType::kSync) {
+    // For operators with subgraphs, we need to invoke them in the main thread
+    // instead of the threaded engine.
+    if (exec_type == ExecType::kSubgraphExec) {
+      RunContext rctx{ctx, nullptr};
+      run(rctx, engine::CallbackOnComplete());
+    } else if (exec_type == ExecType::kSync) {
       Engine::Get()->PushSync(
           [=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); },
           ctx, read_vars, write_vars, FnProperty::kNormal, 0,
@@ -519,12 +526,16 @@ inline void PushOperator(const OpStatePtr& state,
         fcompute(state, opctx, input_blobs, tmp_req, output_blobs);
         // post-fcompute fallback, cast to original storage type, if necessary
         CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
-        if (is_gpu && exec_type == ExecType::kSync) {
+        if (is_gpu && exec_type == ExecType::kSync
+            && rctx.get_stream<gpu>()) {
           rctx.get_stream<gpu>()->Wait();
         }
       };
 
-    if (exec_type == ExecType::kSync) {
+    if (exec_type == ExecType::kSubgraphExec) {
+      RunContext rctx{ctx, nullptr};
+      run(rctx, engine::CallbackOnComplete());
+    } else if (exec_type == ExecType::kSync) {
       Engine::Get()->PushSync(
           [=](RunContext rctx) {
             run(rctx, engine::CallbackOnComplete());
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 94d3d90..583e2bf 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -39,6 +39,7 @@
 #include "../operator/tensor/matrix_op-inl.h"
 #include "../operator/tensor/init_op.h"
 #include "../operator/nn/mkldnn/mkldnn_base-inl.h"
+#include "../engine/engine_impl.h"
 
 #if MXNET_USE_OPENCV
 #include <opencv2/opencv.hpp>
@@ -2041,6 +2042,7 @@ void NDArray::SyncCheckFormat(const bool full_check) const {
   CHECK_EQ(err, kNormalErr) << "Check the validity of this sparse NDArray";
 }
 
+
 #if MXNET_PREDICT_ONLY == 0
 // register API function
 // those with underscore will be registered at NDArray
diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h
new file mode 100644
index 0000000..472312d
--- /dev/null
+++ b/src/operator/subgraph/common.h
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_COMMON_H_
+#define MXNET_OPERATOR_SUBGRAPH_COMMON_H_
+
+#include <string>
+#include <set>
+#include <vector>
+#include "../elemwise_op_common.h"
+#include "../../executor/exec_pass.h"
+
+namespace mxnet {
+namespace op {
+namespace sg {
+
+struct SimpleNode;
+using SimpleNodePtr = std::shared_ptr<SimpleNode>;
+
+/*!
+ * \brief Node of the undirected graph that replicates the structure
+ * of the computational graph. It is used to ease graph traversal when
+ * searching for subgraphs.
+ */
+struct SimpleNode {
+  static SimpleNodePtr Create() {
+    return std::make_shared<SimpleNode>();
+  }
+  SimpleNode() : label(-1), node(nullptr) {}
+  /*! subgraph label */
+  int label;
+  /*! the original node in the computational graph it references*/
+  nnvm::Node* node;
+  /*!
+   * \brief output nodes of the current node
+   * key is node ptr and value is an array of indices standing for the entry indices
+   * in key->inputs whose source is the current node.
+   */
+  std::unordered_map<nnvm::Node*, std::vector<size_t>> outputs;
+};  // struct SimpleNode
+}  // namespace sg
+
+inline uint32_t DefaultSubgraphOpNumInputs(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  return sym.ListInputNames(nnvm::Symbol::kAll).size();
+}
+
+inline uint32_t DefaultSubgraphOpNumOutputs(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  return sym.ListOutputNames().size();
+}
+
+inline std::vector<std::string> DefaultSubgraphOpListInputs(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  return sym.ListInputNames(nnvm::Symbol::kAll);
+}
+
+inline std::vector<std::string> DefaultSubgraphOpListOutputs(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  return sym.ListOutputNames();
+}
+
+inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
+                                   std::vector<TShape> *in_shapes,
+                                   std::vector<TShape> *out_shapes) {
+  using namespace exec;
+  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  nnvm::Graph g;
+  g.outputs = subgraph_sym.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size());
+  CHECK_EQ(idx_g.outputs().size(), out_shapes->size());
+
+  // Put the input and output shapes to the shape vector.
+  nnvm::ShapeVector shapes(idx_g.num_node_entries());
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_shapes->size());
+  for (size_t i = 0; i < in_shapes->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    shapes[eid] = in_shapes->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_shapes->size());
+  for (size_t i = 0; i < out_shapes->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    shapes[eid] = out_shapes->at(i);
+  }
+
+  // Infer shape of the graph.
+  g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
+  g = exec::InferShape(std::move(g));
+
+  // Copy the inferred shape back to the input shapes and the output shapes.
+  shapes = g.GetAttr<nnvm::ShapeVector>("shape");
+  // assign to in_shapes
+  for (size_t i = 0; i < in_shapes->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    SHAPE_ASSIGN_CHECK(*in_shapes, i, shapes[eid]);
+  }
+  // assign to out_shapes
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]);
+  }
+  // Check if we have inferred the shapes correctly.
+  return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
+}
+
+inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
+                                  std::vector<int> *in_types,
+                                  std::vector<int> *out_types) {
+  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  nnvm::Graph g;
+  g.outputs = subgraph_sym.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_types->size());
+  CHECK_EQ(idx_g.outputs().size(), out_types->size());
+
+  // Put the input and output data types to the dtype vector.
+  nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_types->size());
+  for (size_t i = 0; i < in_types->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    types[eid] = in_types->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_types->size());
+  for (size_t i = 0; i < out_types->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    types[eid] = out_types->at(i);
+  }
+
+  // Infer data type of the graph.
+  g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
+  g = exec::InferType(std::move(g));
+
+  types = g.GetAttr<nnvm::DTypeVector>("dtype");
+  // assign to in_types
+  for (size_t i = 0; i < in_types->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    TYPE_ASSIGN_CHECK(*in_types, i, types[eid]);
+  }
+  // assign to out_types
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    TYPE_ASSIGN_CHECK(*out_types, i, types[eid]);
+  }
+  // Check if we have inferred the dtypes correctly.
+  return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
+}
+
+inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int>* in_stypes,
+                                         std::vector<int>* out_stypes) {
+  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  nnvm::Graph g;
+  g.outputs = subgraph_sym.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_stypes->size());
+  CHECK_EQ(idx_g.outputs().size(), out_stypes->size());
+  exec::DevMaskVector dev_masks(idx_g.num_node_entries(), dev_mask);
+
+  // Put the input and output storages to the storage vector.
+  StorageTypeVector stypes(idx_g.num_node_entries(), kUndefinedStorage);
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_stypes->size());
+  for (size_t i = 0; i < in_stypes->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    stypes[eid] = in_stypes->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_stypes->size());
+  for (size_t i = 0; i < out_stypes->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    stypes[eid] = out_stypes->at(i);
+  }
+
+  // Infer storage type of the graph.
+  bool dev_match = g.attrs.count("dev_mask") &&
+                   g.GetAttr<exec::DevMaskVector>("dev_mask") == dev_masks;
+  if (!dev_match) {
+    g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(dev_masks));
+  }
+  g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(stypes));
+  g = exec::InferStorageType(std::move(g));
+
+  stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  // assign to in_types
+  for (size_t i = 0; i < in_stypes->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, stypes[eid]);
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  // assign to out_types
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, stypes[eid]);
+  }
+  // Check if we have inferred the storages correctly.
+  return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
+}
+
+inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) {
+  return ExecType::kSubgraphExec;
+}
+
+inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const std::vector<std::string> input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
+  const std::vector<std::string> immutable_input_names =
+    subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+  const std::vector<std::string> mutable_input_names =
+    subgraph_sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+  CHECK_EQ(immutable_input_names.size() + mutable_input_names.size(), input_names.size());
+  std::vector<uint32_t> ret;
+  size_t i1 = 0, i2 = 0;
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    if (i1 < immutable_input_names.size() && input_names[i] == immutable_input_names[i1]) {
+      ++i1;
+    } else {
+      CHECK(i2 < mutable_input_names.size());
+      CHECK_EQ(input_names[i], mutable_input_names[i2]);
+      ++i2;
+      ret.push_back(i);
+    }
+  }
+  return ret;
+}
+
+inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
+  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
+  std::set<ResourceRequest::Type> resource_types;
+  DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) {
+    if (!node->is_variable() && fresource.count(node->op())) {
+      for (ResourceRequest& r : fresource[node->op()](node->attrs)) {
+        resource_types.insert(r.type);
+      }
+    }
+  });
+  return std::vector<ResourceRequest>(resource_types.begin(), resource_types.end());
+}
+
+#if 0
+// TODO(junwu): add this attribute for visible outputs
+inline uint32_t DefaultSubgraphOpNumVisibleOutputs(const nnvm::NodeAttrs& attrs) {
+}
+#endif
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_SUBGRAPH_COMMON_H_
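
All three inference helpers above follow the same recipe: seed a
whole-graph attribute vector with the known input/output values, run the
corresponding nnvm pass, and copy the inferred values back out. Any new
subgraph operator can reuse them wholesale; a hedged sketch with a
hypothetical _my_subgraph_op (assuming common.h is included):

    NNVM_REGISTER_OP(_my_subgraph_op)
    .set_num_inputs(mxnet::op::DefaultSubgraphOpNumInputs)
    .set_num_outputs(mxnet::op::DefaultSubgraphOpNumOutputs)
    .set_attr<nnvm::FInferShape>("FInferShape", mxnet::op::DefaultSubgraphOpShape)
    .set_attr<nnvm::FInferType>("FInferType", mxnet::op::DefaultSubgraphOpType)
    .set_attr<mxnet::FInferStorageType>("FInferStorageType",
                                        mxnet::op::DefaultSubgraphOpStorageType)
    .set_attr<mxnet::FExecType>("FExecType", mxnet::op::DefaultSubgraphOpExecType);
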
diff --git a/src/operator/subgraph/default_subgraph_op.cc b/src/operator/subgraph/default_subgraph_op.cc
new file mode 100644
index 0000000..8372ae9
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.cc
@@ -0,0 +1,113 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+#include <mxnet/ndarray.h>
+#include "./default_subgraph_op.h"
+#include "../../imperative/imperative_utils.h"
+#include "../../imperative/cached_op.h"
+
+namespace mxnet {
+namespace op {
+
+#define SUBGRAPH_DEBUG 1
+
+class DefaultSubgraphOperator {
+ public:
+  explicit DefaultSubgraphOperator(const Symbol& sym) : subgraph_sym_(sym) {
+    subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"}}));
+  }
+
+  void Forward(const OpContext& ctx,
+               const std::vector<NDArray>& inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray>& outputs);
+  void Backward(const OpContext& ctx,
+                const std::vector<NDArray>& inputs,
+                const std::vector<OpReqType>& req,
+                const std::vector<NDArray>& outputs) {
+    LOG(FATAL) << "Not implemented";
+  }
+
+ private:
+  nnvm::Symbol subgraph_sym_;
+  CachedOpPtr subgraph_exec_;
+};
+
+void DefaultSubgraphOperator::Forward(const OpContext& ctx,
+                                      const std::vector<NDArray>& inputs,
+                                      const std::vector<OpReqType>& req,
+                                      const std::vector<NDArray>& outputs) {
+  std::vector<NDArray> tmp_inputs = inputs;
+  std::vector<NDArray*> input_ptrs;
+  input_ptrs.reserve(inputs.size());
+  for (auto& nd : tmp_inputs) {
+    input_ptrs.push_back(&nd);
+  }
+  std::vector<NDArray> tmp_outputs = outputs;
+  std::vector<NDArray*> output_ptrs;
+  for (auto& nd : tmp_outputs) {
+    output_ptrs.push_back(&nd);
+  }
+#if SUBGRAPH_DEBUG
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    LOG(INFO) << "inputs[" << i << "].version = " << inputs[i].version();
+  }
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    LOG(INFO) << "outputs[" << i << "].version = " << outputs[i].version();
+  }
+#endif
+  subgraph_exec_->Forward(subgraph_exec_, input_ptrs, output_ptrs);
+}
+
+OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& attrs,
+                                        Context ctx,
+                                        const std::vector<TShape>& in_shapes,
+                                        const std::vector<int>& in_types) {
+  const Symbol& subgraph_sym = nnvm::get<Symbol>(attrs.parsed);
+  return OpStatePtr::Create<DefaultSubgraphOperator>(subgraph_sym);
+}
+
+void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+  DefaultSubgraphOperator& op = state_ptr.get_state<DefaultSubgraphOperator>();
+  op.Forward(ctx, inputs, req, outputs);
+}
+
+NNVM_REGISTER_OP(_default_subgraph_op)
+.describe(R"code(_default_subgraph_op)code" ADD_FILELINE)
+.set_num_inputs(DefaultSubgraphOpNumInputs)
+.set_num_outputs(DefaultSubgraphOpNumOutputs)
+.set_attr<nnvm::FListInputNames>("FListInputNames", DefaultSubgraphOpListInputs)
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", DefaultSubgraphOpListOutputs)
+.set_attr<FCreateOpState>("FCreateOpState", CreateDefaultSubgraphOpState)
+.set_attr<nnvm::FInferShape>("FInferShape", DefaultSubgraphOpShape)
+.set_attr<nnvm::FInferType>("FInferType", DefaultSubgraphOpType)
+.set_attr<FInferStorageType>("FInferStorageType", DefaultSubgraphOpStorageType)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", DefaultSubgraphOpForward)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs", DefaultSubgraphOpMutableInputs)
+.set_attr<FResourceRequest>("FResourceRequest", DefaultSubgraphOpResourceRequest)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<FExecType>("FExecType", DefaultSubgraphOpExecType)
+.add_argument("data", "NDArray-or-Symbol[]", "input data list");
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_op.cu b/src/operator/subgraph/default_subgraph_op.cu
new file mode 100644
index 0000000..15a76e3
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.cu
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file default_subgraph_op.cu
+ * \brief GPU Implementation of subgraph operations
+ */
+
+#include "./default_subgraph_op.h"
+
+namespace mxnet {
+namespace op {
+
+void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs);
+
+NNVM_REGISTER_OP(_default_subgraph_op)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", DefaultSubgraphOpForward);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_op.h b/src/operator/subgraph/default_subgraph_op.h
new file mode 100644
index 0000000..7d6624e
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.h
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
+#define MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
+
+#include <vector>
+#include <string>
+#include "./common.h"
+
+namespace mxnet {
+namespace op {
+
+/*
+ * This defines the criteria for selecting nodes in a subgraph.
+ * As nodes are passed to this object, it may update its internal state to
+ * refine the selection criteria. It also decides which input and output
+ * links to follow when traversing a node's neighbors.
+ */
+class SubgraphSelector {
+ public:
+  virtual ~SubgraphSelector() {
+  }
+  // Determine if the node should be selected for a subgraph.
+  virtual bool Select(const nnvm::Node &n) = 0;
+  // Determine if the input node should be selected for a subgraph.
+  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
+  // Determine if the output node should be selected for a subgraph.
+  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
+  // Post-processes the pre-selected subgraph nodes. Returns the list of
+  // nodes to keep in the subgraph(s).
+  virtual std::vector<nnvm::Node*> Filter(nnvm::Graph* g,
+                                          const std::vector<nnvm::Node*>& candidates) {
+    return candidates;
+  }
+};
+
+using SubgraphSelectorPtr = std::shared_ptr<SubgraphSelector>;
+
+/*!
+ * \brief This provides a set of properties for partitioning a graph into subgraphs,
+ * reconstructing a new graph from the subgraphs and creating a subgraph
+ * operator to execute the subgraph.
+ */
+class SubgraphProperty {
+ public:
+  // Create the selector that defines the criteria for choosing subgraph nodes.
+  virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0;
+  // Create an nnvm node for a given subgraph. Here users can customize how
+  // the operators in the subgraph are executed.
+  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s,
+                                           const int subgraph_id = 0) const = 0;
+};
+
+using SubgraphPropertyPtr = std::shared_ptr<SubgraphProperty>;
+
+void RegisterSubgraphProperty(SubgraphPropertyPtr property);
+
+/*
+ * This selects nodes for a subgraph that contains only operators
+ * from a given set; it visits nodes via both input and output links.
+ */
+class ContainOpSelector: public SubgraphSelector {
+  std::shared_ptr<const std::unordered_set<std::string>> op_names;
+
+ public:
+  explicit ContainOpSelector(std::shared_ptr<const std::unordered_set<std::string>> op_names) {
+    this->op_names = op_names;
+  }
+
+  virtual bool Select(const nnvm::Node &n) {
+    return !n.is_variable() && op_names->count(n.op()->name);
+  }
+
+  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) {
+    return !new_node.is_variable() && op_names->count(new_node.op()->name);
+  }
+
+  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) {
+    return !new_node.is_variable() && op_names->count(new_node.op()->name);
+  }
+};
+
+/*
+ * This subgraph property finds subgraphs whose nodes contain only operators
+ * from a given set. The operators in each subgraph are executed by _default_subgraph_op.
+ */
+class DefaultSubgraphProperty: public SubgraphProperty {
+ public:
+  explicit DefaultSubgraphProperty(const std::unordered_set<std::string> &op_names) :
+    op_names_(std::make_shared<std::unordered_set<std::string>>(op_names)) {}
+  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                           const int subgraph_id = 0) const {
+    nnvm::NodePtr n = nnvm::Node::Create();
+    n->attrs.op = Op::Get("_default_subgraph_op");
+    n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id);
+    n->attrs.parsed = sym;
+    return n;
+  }
+  virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
+    return std::make_shared<ContainOpSelector>(op_names_);
+  }
+
+ private:
+  std::shared_ptr<const std::unordered_set<std::string>> op_names_;
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
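
These two interfaces are the extension point for custom partitioning rules.
As a hedged sketch (not part of this commit), a property that selects only
Convolution nodes and wraps each resulting subgraph in the stock
_default_subgraph_op could look like this:

    class ConvOnlySelector : public mxnet::op::SubgraphSelector {
     public:
      bool Select(const nnvm::Node &n) override {
        return !n.is_variable() && n.op()->name == "Convolution";
      }
      bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
        return Select(new_node);
      }
      bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
        return Select(new_node);
      }
    };

    class ConvOnlyProperty : public mxnet::op::SubgraphProperty {
     public:
      mxnet::op::SubgraphSelectorPtr CreateSubgraphSelector() const override {
        return std::make_shared<ConvOnlySelector>();
      }
      nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
                                       const int subgraph_id = 0) const override {
        nnvm::NodePtr n = nnvm::Node::Create();
        n->attrs.op = nnvm::Op::Get("_default_subgraph_op");
        n->attrs.name = "conv_subgraph_op" + std::to_string(subgraph_id);
        n->attrs.parsed = sym;
        return n;
      }
    };
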
diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc
new file mode 100644
index 0000000..11af49a
--- /dev/null
+++ b/src/operator/subgraph/partition_graph.cc
@@ -0,0 +1,687 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file partition_graph.cc
+ * \brief Graph partitioning pass that groups selected operators into subgraph nodes.
+ */
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+#include <mxnet/op_attr_types.h>
+#include <unordered_set>
+#include <stack>
+#include <queue>
+
+#include "./default_subgraph_op.h"
+#include "./common.h"
+
+namespace nnvm {
+NodePtr CreateVariableNode(const std::string& name);
+}
+
+namespace mxnet {
+
+namespace op {
+
+using nnvm::Symbol;
+using nnvm::Node;
+using nnvm::NodePtr;
+using nnvm::NodeEntry;
+using nnvm::Graph;
+
+// TODO(junwu): Change this to 0
+#define SUBGRAPH_DEBUG 1
+
+namespace sg {  // sg stands for subgraph
+
+#if SUBGRAPH_DEBUG
+void PrintSubgraph(const std::vector<SimpleNode*>& simple_nodes) {
+  std::string op_names = "";
+  for (size_t i = 0; i < simple_nodes.size(); ++i) {
+    op_names += simple_nodes[i]->node->attrs.name + ' ';
+  }
+  LOG(INFO) << "Subgraph node names: " << op_names;
+}
+
+void PrintNodeEntry(const nnvm::NodeEntry& entry) {
+  std::string ret = "NodeEntry: node_name=" + entry.node->attrs.name
+    + ", index=" + std::to_string(entry.index) + ", version=" + std::to_string(entry.version);
+  LOG(INFO) << ret;
+}
+
+void PrintNodeEntries(const std::vector<nnvm::NodeEntry*>& entries) {
+  for (size_t i = 0; i < entries.size(); ++i) {
+    PrintNodeEntry(*entries[i]);
+  }
+}
+#endif
+
+/*!
+ * \brief Given an MXNet computational graph, create an undirected graph from it.
+ * \param g the MXNet computational graph
+ * \param simple_nodes the nodes of the undirected graph in topologically sorted order
+ */
+void CreateSimpleGraph(const Graph& g,
+                       std::vector<SimpleNodePtr>* simple_nodes) {
+  const auto& indexed_graph = g.indexed_graph();
+  simple_nodes->reserve(indexed_graph.num_nodes());
+  DFSVisit(g.outputs, [&](const NodePtr& node) {
+    SimpleNodePtr sn = SimpleNode::Create();
+    sn->node = node.get();
+    for (size_t i = 0; i < sn->node->inputs.size(); ++i) {
+      const auto& e = sn->node->inputs[i];
+      const auto input_nid = indexed_graph.node_id(e.node.get());
+      CHECK_LT(input_nid, simple_nodes->size());
+      auto& input_node_outputs = (*simple_nodes)[input_nid]->outputs;
+      auto it = input_node_outputs.find(sn->node);
+      if (it == input_node_outputs.end()) {
+        input_node_outputs.emplace(sn->node, std::vector<size_t>{i});
+      } else {
+        it->second.push_back(i);
+      }
+    }
+    simple_nodes->emplace_back(std::move(sn));
+  });
+}
+
+/*!
+ * \brief Reset labels of the subgraph nodes to the original state
+ * and clear the vector of subgraph nodes.
+ */
+void ResetNodeLabels(const nnvm::Graph& g,
+                     const std::vector<SimpleNodePtr>& simple_nodes,
+                     std::vector<nnvm::Node*>* subgraph_nodes) {
+  for (auto n : *subgraph_nodes) {
+    const auto nid = g.indexed_graph().node_id(n);
+    simple_nodes[nid]->label = -1;
+  }
+  subgraph_nodes->clear();
+}
+
+/*!
+ * \brief This function traverses the nodes of a computational graph from a starting
+ * node, following both input and output edges, and marks all nodes that
+ * can be reached from the starting node. Before returning, it checks
+ * whether there is a loop between the potential subgraph
+ * and the outside nodes. If so, it adds the node that should break the loop
+ * to excluded_nodes and returns false. Otherwise, it returns true.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether a visited node should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in topologically sorted order
+ * \param subgraph_nodes all the nodes belonging to the same subgraph as the seed node
+ * \param excluded_nodes set of nodes that should be excluded from the current subgraph
+ */
+bool LabelSubgraph(const Graph& g,
+                   SubgraphSelectorPtr subgraph_selector,
+                   const int label,
+                   const size_t snid,  // simple node id, this is a seed
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<nnvm::Node*>* subgraph_nodes,
+                   std::unordered_set<const nnvm::Node*>* excluded_nodes = nullptr) {
+  const auto& indexed_graph = g.indexed_graph();
+  std::queue<SimpleNode*> node_queue;
+  if (!excluded_nodes || !excluded_nodes->count(simple_nodes[snid]->node)) {
+    CHECK_EQ(simple_nodes[snid]->label, -1);
+    simple_nodes[snid]->label = label;
+    node_queue.push(simple_nodes[snid].get());
+  }
+  // key: nodes that serve as input/output nodes to the subgraph
+  // value: pair of vectors of nodes in the subgraph. The first vector contains the
+  // output nodes of the key in the subgraph, and the second vector contains the
+  // input nodes of the key in the subgraph. If both vectors are non-empty,
+  // it means there is a loop between the subgraph and the key node.
+  // When breaking the loop, we want to start removing the node with the largest node id.
+  std::unordered_map<const nnvm::Node*,
+    std::pair<std::vector<const nnvm::Node*>,
+              std::vector<const nnvm::Node*>>> non_subgraph_node_map;
+  while (!node_queue.empty()) {
+    SimpleNode* cur_node = node_queue.front();
+    node_queue.pop();
+    subgraph_nodes->push_back(cur_node->node);
+    // get qualified adjacent input nodes
+    for (auto& e : cur_node->node->inputs) {
+      const bool select_input = (!excluded_nodes || !excluded_nodes->count(e.node.get()))
+        && subgraph_selector->SelectInput(*cur_node->node, *e.node);
+      if (select_input) {
+        // e.node is a subgraph node
+        const auto nid = indexed_graph.node_id(e.node.get());
+        CHECK_LT(nid, simple_nodes.size());
+        // this node has not been visited yet
+        if (simple_nodes[nid]->label == -1) {
+          simple_nodes[nid]->label = label;
+          node_queue.push(simple_nodes[nid].get());
+        }
+      } else {
+        // e.node is an input node of the subgraph
+        non_subgraph_node_map[e.node.get()].first.push_back(cur_node->node);
+      }
+    }
+    // get qualified output nodes
+    for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
+      const bool select_output = (!excluded_nodes || !excluded_nodes->count(it->first))
+          && subgraph_selector->SelectOutput(*cur_node->node, *it->first);
+      if (select_output) {
+        // it->first is a subgraph node
+        const auto nid = indexed_graph.node_id(it->first);
+        CHECK_LT(nid, simple_nodes.size());
+        // this node has not been visited yet
+        if (simple_nodes[nid]->label == -1) {
+          simple_nodes[nid]->label = label;
+          node_queue.push(simple_nodes[nid].get());
+        }
+      } else {
+        // it->first is an output node of the subgraph
+        non_subgraph_node_map[it->first].second.push_back(cur_node->node);
+      }
+    }
+  }
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  // check whether there is a loop between the subgraph and its input/output nodes
+  int excluded_node_id = -1;
+  for (auto& kv : non_subgraph_node_map) {
+    auto& output_nodes = kv.second.first;
+    auto& input_nodes = kv.second.second;
+    if (!output_nodes.empty() && !input_nodes.empty()) {
+      // there is a loop between kv->first and the subgraph
+      std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
+      std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
+      const auto node_id = std::max(indexed_graph.node_id(output_nodes.back()),
+                                    indexed_graph.node_id(input_nodes.back()));
+      excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+    }
+  }
+  if (excluded_node_id != -1) {
+    CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
+    CHECK_NE(excluded_node_id, static_cast<int>(snid))
+      << "A cycle is found in the computational graph between nodes "
+      << simple_nodes[excluded_node_id]->node->attrs.name << " and "
+      << simple_nodes[snid]->node->attrs.name;
+    excluded_nodes->insert(simple_nodes[excluded_node_id]->node);
+    ResetNodeLabels(g, simple_nodes, subgraph_nodes);
+    return false;
+  }
+  std::sort(subgraph_nodes->begin(), subgraph_nodes->end(), node_cmp);
+  return true;
+}
+
+/*!
+ * \brief Finds all the nodes belonging to the same subgraph given a seed node.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether a visited node should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in topologically sorted order
+ * \param subgraph_nodes output: all the nodes belonging to the same subgraph as the
+ * seed node, sorted in topological order
+ */
+void PreSelectSubgraphNodes(const Graph& g,
+                            SubgraphSelectorPtr subgraph_selector,
+                            const int label,
+                            const size_t snid,
+                            const std::vector<SimpleNodePtr>& simple_nodes,
+                            std::vector<nnvm::Node*>* subgraph_nodes) {
+  std::unordered_set<const nnvm::Node*> excluded_nodes;
+  const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
+  size_t count = 0;
+  bool success = false;
+  while (!success && count < max_num_retry) {
+    success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes,
+                            subgraph_nodes, &excluded_nodes);
+    if (!success) {
+      CHECK(!excluded_nodes.empty());
+      std::string excluded_node_names;
+      for (auto node : excluded_nodes) {
+        excluded_node_names += node->attrs.name + ", ";
+      }
+      LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name
+                << ". Excluding nodes " << excluded_node_names << "and retrying";
+    }
+    ++count;
+  }
+  if (!success) {
+    LOG(INFO) << "Tried " << count << " times of finding subgraphs starting from node "
+              << simple_nodes[snid]->node->attrs.name << " without success because a loop "
+                  "is always found between the subgraph and some other nodes. Will treat "
+                  "seed node " << simple_nodes[snid]->node->attrs.name
+              << "as a subgraph with one node";
+    CHECK(subgraph_nodes->empty());
+    simple_nodes[snid]->label = label;
+    subgraph_nodes->push_back(simple_nodes[snid]->node);
+  }
+}
+
+/*!
+ * \brief Given a vector of nodes, group them into individual subgraphs
+ * based upon their connectivity.
+ */
+void PostProcessNodeCandidates(const nnvm::Graph& g,
+                               const std::vector<nnvm::Node*>& nodes,
+                               const std::vector<SimpleNodePtr>& simple_nodes,
+                               std::vector<std::vector<SimpleNode*>>* subgraphs,
+                               size_t* subgraph_id) {
+  const auto& indexed_graph = g.indexed_graph();
+  std::unordered_set<nnvm::Node*> node_set(nodes.begin(), nodes.end());
+  auto simple_node_cmp = [&] (const SimpleNode* node1, const SimpleNode* node2) {
+    return indexed_graph.node_id(node1->node) < indexed_graph.node_id(node2->node);
+  };
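+  // BFS over the candidate nodes, following both input and output edges, so
+  // that each connected component of candidates becomes one subgraph.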
+  for (auto node : nodes) {
+    if (!node_set.count(node)) {
+      // The node has been included in a subgraph
+      continue;
+    }
+    std::queue<nnvm::Node*> q;
+    q.push(node);
+    CHECK_EQ(node_set.erase(node), 1U);
+    subgraphs->emplace_back();
+    const auto nid = indexed_graph.node_id(node);
+    simple_nodes[nid]->label = *subgraph_id;
+    subgraphs->back().push_back(simple_nodes[nid].get());
+    while (!q.empty()) {
+      nnvm::Node* cur_node = q.front();
+      q.pop();
+      for (auto& e : cur_node->inputs) {
+        auto in_it = node_set.find(e.node.get());
+        if (in_it != node_set.end()) {
+          q.push(*in_it);
+          const auto in_nid = indexed_graph.node_id(*in_it);
+          simple_nodes[in_nid]->label = *subgraph_id;
+          subgraphs->back().push_back(simple_nodes[in_nid].get());
+          node_set.erase(in_it);
+        }
+      }
+      const auto cur_nid = indexed_graph.node_id(cur_node);
+      const SimpleNode* cur_snode = simple_nodes[cur_nid].get();
+      for (const auto& kv : cur_snode->outputs) {
+        const auto out_it = node_set.find(kv.first);
+        if (out_it != node_set.end()) {
+          q.push(*out_it);
+          const auto out_nid = indexed_graph.node_id(*out_it);
+          simple_nodes[out_nid]->label = *subgraph_id;
+          subgraphs->back().push_back(simple_nodes[out_nid].get());
+          node_set.erase(out_it);
+        }
+      }
+    }
+    ++(*subgraph_id);
+    std::sort(subgraphs->back().begin(), subgraphs->back().end(), simple_node_cmp);
+  }
+  CHECK(node_set.empty());
+}
+
+/*!
+ * \brief Finds subgraphs with all nodes that meet certain criteria.
+ * All nodes in a subgraph are marked with the same label.
+ */
+void FindSubgraphs(Graph* g,
+                   const SubgraphProperty &subg_prop,
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<std::vector<SimpleNode*>>* subgraph_nodes) {
+  const auto& indexed_graph = g->indexed_graph();
+  CHECK_EQ(indexed_graph.num_nodes(), simple_nodes.size());
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
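+  // Scan the simple nodes in topological order; every unlabeled node accepted
+  // by the selector seeds a new subgraph.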
+  size_t subgraph_id = 0;
+  for (size_t i = 0; i < simple_nodes.size(); ++i) {
+    nnvm::Node* node = simple_nodes[i]->node;
+    auto subgraph_selector = subg_prop.CreateSubgraphSelector();
+    if (subgraph_selector->Select(*node) && simple_nodes[i]->label == -1) {
+      // pre-select nodes that can be grouped in a subgraph
+      std::vector<nnvm::Node*> preselected_nodes;
+      PreSelectSubgraphNodes(*g, subgraph_selector, subgraph_id, i, simple_nodes,
+                             &preselected_nodes);
+
+      // filter out unqualified pre-selected nodes
+      std::vector<nnvm::Node*> filtered_nodes = subgraph_selector->Filter(g, preselected_nodes);
+
+      // make sure filtered_nodes is a subset of preselected_nodes
+      for (const auto n : filtered_nodes) {
+        const auto nit = std::find(preselected_nodes.begin(), preselected_nodes.end(), n);
+        CHECK(nit != preselected_nodes.end())
+          << "Node " << n->attrs.name << " is not found in the pre-selected subgraph nodes."
+             " Please make sure that no new nodes were added in your subgraph"
+             " selector's Filter function";
+      }
+
+      // make sure nodes are sorted
+      std::sort(filtered_nodes.begin(), filtered_nodes.end(), node_cmp);
+
+      // reset node labels that are not in filtered nodes
+      for (const auto n : preselected_nodes) {
+        const auto nit = std::find(filtered_nodes.begin(), filtered_nodes.end(), n);
+        if (nit == filtered_nodes.end()) {
+          simple_nodes[indexed_graph.node_id(n)]->label = -1;
+        }
+      }
+      // find out subgraphs from the filtered nodes
+      std::vector<std::vector<SimpleNode*>> subgraphs;
+      PostProcessNodeCandidates(*g, filtered_nodes, simple_nodes, &subgraphs, &subgraph_id);
+      if (!subgraphs.empty()) {
+        subgraph_nodes->insert(subgraph_nodes->end(), subgraphs.begin(), subgraphs.end());
+      }
+    }
+  }
+}
+
+/*!
+ * \brief Sorts entries according to their topological order.
+ * Note that entry ids cannot be used to sort entries.
+ * \param entry_top_order_map mapping from entry pointer to its topological position in the graph
+ * \param entries Node entries to be sorted
+ */
+void SortEntries(const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
+                 std::vector<nnvm::NodeEntry*>* entries) {
+  auto entry_cmp = [&](const nnvm::NodeEntry* e1, const nnvm::NodeEntry* e2) {
+    const auto it1 = entry_top_order_map.find(e1);
+    CHECK(it1 != entry_top_order_map.end());
+    const auto it2 = entry_top_order_map.find(e2);
+    CHECK(it2 != entry_top_order_map.end());
+    return it1->second < it2->second;
+  };
+  std::sort(entries->begin(), entries->end(), entry_cmp);
+}
+
+/*!
+ * \brief Given a subgraph, find the input entries of the subgraph.
+ * \param g the whole graph
+ * \param simple_nodes vector of simple nodes in topological order
+ * \param subgraph_nodes vector of pointers to the simple nodes of a subgraph
+ * \param entry_top_order_map mapping from entry pointer to its topological position
+ * \param input_entries output vector of the subgraph's input entries
+ */
+void FindInputEntries(const Graph& g,
+                      const std::vector<SimpleNodePtr>& simple_nodes,
+                      const std::vector<SimpleNode*>& subgraph_nodes,
+                      const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
+                      std::vector<nnvm::NodeEntry*>* input_entries) {
+  const auto& indexed_graph = g.indexed_graph();
+  int label = -1;
+  for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
+    if (label == -1) {
+      label = subgraph_nodes[i]->label;
+    } else {
+      CHECK_EQ(subgraph_nodes[i]->label, label);
+    }
+    auto& inputs = subgraph_nodes[i]->node->inputs;
+    for (size_t j = 0; j < inputs.size(); ++j) {
+      auto& e = inputs[j];
+      if (indexed_graph.exist(e.node.get())) {
+        // e's source node still exists in the original graph,
+        // i.e. it has not been replaced by a subgraph node
+        const auto nid = indexed_graph.node_id(e.node.get());
+        // e is an input entry only if its source node does not
+        // belong to the current subgraph
+        if (simple_nodes[nid]->label != label) {
+          input_entries->push_back(&e);
+        }
+      } else {
+        // e's source node has been replaced by a subgraph node.
+        // In this case, two subgraphs are adjacent.
+        input_entries->push_back(&e);
+      }
+    }
+  }
+  SortEntries(entry_top_order_map, input_entries);
+}
+
+/*!
+ * \brief Given a subgraph, find the output entries of the subgraph.
+ * \param g pointer to the whole graph
+ * \param simple_nodes vector of simple nodes in topological order
+ * \param subgraph_nodes vector of pointers to the simple nodes of a subgraph
+ * \param entry_top_order_map mapping from entry pointer to its topological position
+ * \param output_entries output vector of the subgraph's output entries
+ */
+void FindOutputEntries(Graph* g,
+                       const std::vector<SimpleNodePtr>& simple_nodes,
+                       const std::vector<SimpleNode*>& subgraph_nodes,
+                       const std::unordered_map<const nnvm::NodeEntry*, size_t>&
+                         entry_top_order_map,
+                       std::vector<nnvm::NodeEntry*>* output_entries) {
+  if (subgraph_nodes.empty()) return;
+  const auto& indexed_graph = g->indexed_graph();
+  int label = -1;
+  for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
+    if (label == -1) {
+      label = subgraph_nodes[i]->label;
+    } else {
+      CHECK_EQ(subgraph_nodes[i]->label, label);
+    }
+    for (auto it = subgraph_nodes[i]->outputs.begin();
+         it != subgraph_nodes[i]->outputs.end(); ++it) {
+      if (indexed_graph.exist(it->first)) {
+        // if the output node is a normal graph node (not a subgraph node)
+        const auto nid = indexed_graph.node_id(it->first);
+        // this is a node not belonging to the current subgraph
+        if (simple_nodes[nid]->label != label) {
+          // TODO(zhengda) I need to test this.
+          for (auto idx : it->second) {
+            auto& e = simple_nodes[nid]->node->inputs[idx];
+            output_entries->push_back(&e);
+          }
+        }
+      } else {
+        // if the output node is a subgraph node
+        // two graphs are adjacent
+        for (auto idx : it->second) {
+          output_entries->push_back(&(it->first->inputs[idx]));
+        }
+      }
+    }
+  }
+  // Check if current subgraph contains a node which is the last node
+  // of the whole graph. If so, save its corresponding entry as well.
+  for (size_t i = 0; i < g->outputs.size(); ++i) {
+    auto& entry = g->outputs[i];
+    // The entry might have already been redirected to the output of
+    // a subgraph node. In that case, its source node no longer exists in
+    // the indexed graph, so there is no need to check it for the current
+    // subgraph. Otherwise, do the following.
+    if (indexed_graph.exist(entry.node.get())) {
+      const auto nid = indexed_graph.node_id(entry.node.get());
+      if (simple_nodes[nid]->label == label) {
+        output_entries->push_back(&entry);
+      }
+    }
+  }
+  SortEntries(entry_top_order_map, output_entries);
+}
+
+/*!
+ * \brief Given a computation graph and a set of input node entries, this function cuts
+ * the node entries and creates new variable nodes as the input nodes of the
+ * subgraph. The original entries that connected to the subgraph directly are
+ * returned through orig_entries; each new variable node takes its name from
+ * the corresponding output name of the cut entry.
+ */
+void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
+                    std::vector<nnvm::NodeEntry> *orig_entries,
+                    const bool skip_var = false) {
+  orig_entries->resize(input_entries.size());
+  for (size_t i = 0; i < input_entries.size(); ++i) {
+    nnvm::NodeEntry *e = input_entries[i];
+    // If the node is a variable itself, we may want to skip the node.
+    if (e->node->is_variable() && skip_var) {
+      continue;
+    }
+
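+    // Save the original entry and replace it with a new variable node named
+    // after the entry's output, so that the subgraph can be bound on its own.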
+    orig_entries->at(i) = *e;
+    nnvm::Symbol sym;
+    sym.outputs.push_back(*e);
+    const auto output_names = sym.ListOutputNames();
+    CHECK_EQ(output_names.size(), 1U);
+    nnvm::NodePtr n = nnvm::CreateVariableNode(output_names[0]);
+    *e = nnvm::NodeEntry{n, 0, 0};
+  }
+}
+
+/*!
+ * \brief Replace a set of nodes belonging to the same subgraph with a subgraph node
+ * and keep the subgraph in the subgraph node. The input entries and output entries
+ * of the subgraph node are kept in the same order as the subgraph's.
+ */
+void CreateSubgraphNode(Graph* g,
+                        const std::vector<SimpleNodePtr>& simple_nodes,
+                        const std::vector<SimpleNode*>& subgraph_nodes,
+                        const size_t subgraph_id,
+                        std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
+#if SUBGRAPH_DEBUG
+  LOG(INFO) << "Searching for input entries...";
+#endif
+  std::vector<nnvm::NodeEntry*> input_entries;
+  FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries);
+  std::vector<nnvm::NodeEntry> orig_input_entries;
+  // TODO(junwu): Confirm what value to pass to skip_var
+  CutGraphInputs(input_entries, &orig_input_entries, false);
+#if SUBGRAPH_DEBUG
+  PrintNodeEntries(input_entries);
+  LOG(INFO) << "Searching for output entries...";
+#endif
+  std::vector<nnvm::NodeEntry*> output_entries;
+  FindOutputEntries(g, simple_nodes, subgraph_nodes, *entry_top_order_map, &output_entries);
+
+  // Create a subgraph for the subgraph node
+  nnvm::Symbol sym;
+  sym.outputs.resize(output_entries.size());
+  for (size_t i = 0; i < output_entries.size(); ++i) {
+    sym.outputs[i] = *output_entries[i];
+  }
+  const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
+  nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym, subgraph_id);
+
+  // Connect the external nodes to the subgraph node.
+  for (size_t i = 0; i < output_entries.size(); ++i) {
+    *output_entries[i] = nnvm::NodeEntry{n, static_cast<uint32_t>(i), 0};
+  }
+  n->inputs = orig_input_entries;
+  const auto& indexed_graph = g->indexed_graph();
+  for (size_t i = 0; i < n->inputs.size(); ++i) {
+    auto& e = n->inputs[i];
+    // update entry_top_order_map with newly created orig_input_entries
+    auto it = entry_top_order_map->find(input_entries[i]);
+    CHECK(it != entry_top_order_map->end());
+    CHECK_EQ(entry_top_order_map->count(&e), 0U);
+    entry_top_order_map->emplace(&e, it->second);
+    // update input entries' source simple nodes' outputs map
+    nnvm::Node* node = e.node.get();
+    if (indexed_graph.exist(node)) {
+      const auto nid = indexed_graph.node_id(node);
+      SimpleNode* sn = simple_nodes[nid].get();
+      for (SimpleNode* dest_node : subgraph_nodes) {
+        sn->outputs.erase(dest_node->node);
+      }
+      sn->outputs[n.get()].push_back(i);
+    }
+  }
+#if SUBGRAPH_DEBUG
+  PrintNodeEntries(output_entries);
+#endif
+}
+
+}  // namespace sg
+
+/*!
+ * \brief Sort entries of all the nodes' inputs vectors in the topological order.
+ * This is going to be used to sort input/output entries of subgraphs to keep
+ * the topological order unchanged.
+ */
+void TopSortEntries(const Graph& g,
+                    std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
+  CHECK(entry_top_order_map != nullptr);
+  std::unordered_set<const nnvm::Node*> visited;
+  // tuple: (graph node, index into the node's inputs, node entry whose source is the graph node)
+  std::stack<std::tuple<nnvm::Node*, size_t, const nnvm::NodeEntry*>> s;
+  auto in_degree = [] (const nnvm::Node* node)->size_t {
+    if (!node) {
+      return 0;
+    }
+    CHECK_EQ(node->control_deps.size(), 0U);
+    return node->inputs.size();
+  };
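+  // Iterative DFS from each graph output: an entry is assigned its order once
+  // all inputs of its source node have been processed (post-order), which
+  // yields a valid topological numbering of all entries.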
+  for (auto& e : g.outputs) {
+    nnvm::Node* node = e.node.get();
+    if (visited.count(node) == 0U) {
+      s.emplace(node, 0U, &e);
+      visited.insert(node);
+    }
+    while (!s.empty()) {
+      auto& top = s.top();
+      if (std::get<1>(top) == in_degree(std::get<0>(top))) {
+        // All of the node's inputs have been visited.
+        entry_top_order_map->emplace(std::get<2>(top), entry_top_order_map->size());
+        s.pop();
+      } else {
+        // The node still has input entries not visited.
+        CHECK_LT(std::get<1>(top), std::get<0>(top)->inputs.size());
+        auto& entry = std::get<0>(top)->inputs[std::get<1>(top)++];
+        nnvm::Node* input_node = entry.node.get();
+        if (visited.count(input_node) == 0U) {
+          // The entry's source node has not been visited.
+          // Push the entry to the stack for marking order later.
+          s.emplace(input_node, 0U, &entry);
+          visited.insert(input_node);
+        } else {
+          // The entry's source node has been visited before.
+          // Marking order for it.
+          entry_top_order_map->emplace(&entry, entry_top_order_map->size());
+        }
+      }
+    }
+  }
+}
+
+Graph PartitionGraph(Graph&& g) {
+  if (!g.HasAttr("subgraph_property")) {  // treat the whole graph as a subgraph
+    LOG(INFO) << "The graph has no attribute of subgraph_property attached. "
+                 "The original graph is returned.";
+    return g;
+  }
+  using namespace sg;
+  const SubgraphPropertyPtr& subg_prop = g.GetAttr<SubgraphPropertyPtr>("subgraph_property");
+  // top sort NodeEntry of all the nodes' inputs
+  std::unordered_map<const nnvm::NodeEntry*, size_t> entry_top_order_map;
+  TopSortEntries(g, &entry_top_order_map);
+
+  // Create undirected graph for ease of finding subgraphs
+  std::vector<SimpleNodePtr> simple_nodes;
+  CreateSimpleGraph(g, &simple_nodes);
+  std::vector<std::vector<SimpleNode*>> subgraph_nodes;
+  FindSubgraphs(&g, *subg_prop, simple_nodes, &subgraph_nodes);
+  for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
+#if SUBGRAPH_DEBUG
+    std::set<SimpleNode*> simple_node_set(subgraph_nodes[i].begin(), subgraph_nodes[i].end());
+    CHECK_EQ(simple_node_set.size(), subgraph_nodes[i].size());
+    PrintSubgraph(subgraph_nodes[i]);
+#endif
+    CreateSubgraphNode(&g, simple_nodes, subgraph_nodes[i], i, &entry_top_order_map);
+  }
+  return g;
+}
+
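+// A minimal usage sketch (illustrative only; `prop` stands for any concrete
+// SubgraphPropertyPtr and is not defined in this commit):
+//   nnvm::Graph g = ...;
+//   g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(prop);
+//   g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
+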
+NNVM_REGISTER_PASS(PartitionGraph)
+.describe("Partition a graph according to the user defined rules "
+          "in a derived class of SubgraphProperty")
+.set_body(PartitionGraph)
+.set_change_graph(true);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index ed4aaa4..9f20627 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -44,6 +44,7 @@ from test_gluon_rnn import *
 from test_sparse_ndarray import *
 from test_sparse_operator import *
 from test_ndarray import *
+from test_subgraph_op import *
 
 set_default_context(mx.gpu(0))
 del test_support_vector_machine_l1_svm
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 1c59cea..f8833f8 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -175,6 +175,7 @@ def test_trainer_save_load():
     # check if parameter dict is correctly associated with optimizer after load_state
     assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
 
+@unittest.skip("temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/11353")
 @with_seed()
 def test_trainer_reset_kv():
     def check_trainer_reset_kv(kv):
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
new file mode 100644
index 0000000..f08c42c
--- /dev/null
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import ctypes
+import mxnet as mx
+from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array
+from mxnet.symbol import Symbol
+import numpy as np
+
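+# These tests exercise the PartitionGraph pass through the MXPartitionGraph
+# C API: operators whose names appear in op_names are grouped into subgraph
+# nodes, and the partitioned symbol must preserve the original symbol's
+# inputs, shapes, types, and computed outputs.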
+
+def test_subgraph():
+    def get_graph():
+        data1 = mx.sym.Variable('data1', shape=(3, 3, 10, 10), dtype=np.float32)
+        data2 = mx.sym.Variable('data2', shape=(1, 0, 2, 2))
+        data3 = mx.sym.sin(data2)
+        conv = mx.sym.Convolution(data=data1, weight=data3, kernel=(2, 2), num_filter=1)
+        rets = []
+        rets.append((conv, []))
+        rets.append((conv, [mx.sym.sin.__name__]))
+        rets.append((conv, [mx.sym.Convolution.__name__]))
+        rets.append((conv, [mx.sym.sin.__name__, mx.sym.Convolution.__name__]))
+        return rets
+
+    for regular_sym, op_names in get_graph():
+        input_names = regular_sym.list_inputs()
+        shapes = regular_sym.infer_shape()
+        types = regular_sym.infer_type()
+        out = SymbolHandle()
+
+        check_call(_LIB.MXPartitionGraph(regular_sym.handle, mx_uint(len(op_names)),
+            c_str_array(op_names), ctypes.byref(out)))
+        subgraph_sym = Symbol(out)
+        assert input_names == subgraph_sym.list_inputs()
+
+        assert shapes == subgraph_sym.infer_shape()
+        assert types == subgraph_sym.infer_type()
+
+        regular_exec = regular_sym.simple_bind(ctx=mx.cpu(), grad_req='null')
+        subgraph_exec = subgraph_sym.simple_bind(ctx=mx.cpu(), grad_req='null')
+
+        for name in input_names:
+            regular_exec.arg_dict[name][:] = mx.nd.random.normal(
+                    shape=regular_exec.arg_dict[name].shape)
+            subgraph_exec.arg_dict[name][:] = regular_exec.arg_dict[name]
+
+        subgraph_exec.forward()
+        regular_exec.forward()
+        mx.nd.waitall()
+        assert (subgraph_exec.outputs[0] - regular_exec.outputs[0]).abs().sum().asscalar() == 0.0
+
+
+def test_input_name_order():
+    def check_input_order(sym, op_names):
+        out = SymbolHandle()
+        check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)),
+                                         c_str_array(op_names), ctypes.byref(out)))
+
+        new_sym = Symbol(out)
+        assert new_sym.list_inputs() == sym.list_inputs()
+        assert new_sym.list_arguments() == sym.list_arguments()
+        assert new_sym.list_auxiliary_states() == sym.list_auxiliary_states()
+
+    def test_network_structure_1():
+        data1 = mx.sym.var('data1')
+        data2 = mx.sym.var('data2')
+        conv1 = mx.sym.Convolution(data=data1, weight=data2, no_bias=True, kernel=(2, 2), num_filter=1)
+        conv2 = mx.sym.Convolution(data=data2, weight=data1, no_bias=True, kernel=(2, 2), num_filter=1)
+        out = mx.sym.Group([conv1, conv2])
+        check_input_order(out, ['Convolution'])
+
+    def test_network_structure_2():
+        data1 = mx.sym.var('data1')
+        data2 = mx.sym.var('data2')
+        conv1 = mx.sym.Convolution(data=data1, weight=data2, no_bias=True, kernel=(2, 2), num_filter=1)
+        conv2 = mx.sym.Convolution(data=data2, weight=data1, no_bias=True, kernel=(2, 2), num_filter=1)
+        out = conv1 + conv2
+        check_input_order(out, ['Convolution'])
+        check_input_order(out, ['Convolution', '_Plus', 'elemwise_add', '_plus'])
+
+    def test_network_structure_3():
+        # this tests whether the partitioning algorithm can deal with cycles
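+        # exp feeds both cos and sin, whose results are merged again by the
+        # add; selecting exp together with only one of the two branches would
+        # create a loop through the unselected branch, which the partitioner
+        # must break by excluding nodes.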
+        data = mx.sym.var('data')
+        ret = mx.sym.exp(data)
+        ret1 = mx.sym.cos(ret)
+        ret2 = mx.sym.sin(ret)
+        ret = ret1 + ret2
+        check_input_order(ret, ['exp', 'sin', '_Plus', 'elemwise_add', '_plus'])
+        check_input_order(ret, ['exp', 'cos', '_Plus', 'elemwise_add', '_plus'])
+
+    def test_network_structure_4():
+        # this tests whether the partitioned sym can distinguish in_args and aux_states
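+        # BatchNorm carries auxiliary states (moving mean/variance), so the
+        # partitioned symbol must keep aux_states separate from in_args.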
+        data = mx.sym.var('data')
+        ret = mx.sym.exp(data)
+        ret1 = mx.sym.cos(ret)
+        ret2 = mx.sym.sin(ret)
+        ret = ret1 + ret2
+        ret = mx.sym.BatchNorm(ret)
+        ret = mx.sym.BatchNorm(ret)
+        check_input_order(ret, ['exp', 'sin', '_Plus', 'elemwise_add', '_plus'])
+        check_input_order(ret, ['exp', 'cos', '_Plus', 'elemwise_add', '_plus'])
+        check_input_order(ret, ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm'])
+        check_input_order(ret, ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm'])
+        check_input_order(ret, ['exp', 'BatchNorm'])
+        check_input_order(ret, ['BatchNorm'])
+
+    test_network_structure_1()
+    test_network_structure_2()
+    test_network_structure_3()
+    test_network_structure_4()
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()