You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/21 18:40:33 UTC

[GitHub] reminisce closed pull request #11251: Graph partitioner and subgraph op

reminisce closed pull request #11251: Graph partitioner and subgraph op
URL: https://github.com/apache/incubator-mxnet/pull/11251
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/example/subgraph_op/common b/example/subgraph_op/common
new file mode 120000
index 00000000000..cafb9140ab6
--- /dev/null
+++ b/example/subgraph_op/common
@@ -0,0 +1 @@
+../image-classification/common
\ No newline at end of file
diff --git a/example/subgraph_op/imagenet_inference.py b/example/subgraph_op/imagenet_inference.py
new file mode 100644
index 00000000000..8a38cffc919
--- /dev/null
+++ b/example/subgraph_op/imagenet_inference.py
@@ -0,0 +1,193 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Run ImageNet inference with a model zoo checkpoint whose graph has been
+passed through the MXPartitionGraph C API."""
+
+import argparse
+import ctypes
+import logging
+import os
+import time
+import mxnet as mx
+from common import modelzoo
+from mxnet.base import _LIB, c_str_array, check_call, mx_uint, SymbolHandle
+from mxnet.symbol import Symbol
+
+
+def download_dataset(dataset_url, dataset_dir, logger=None):
+    """Download the file at `dataset_url` to the local path `dataset_dir`.
+
+    Despite its name, `dataset_dir` is passed to `mx.test_utils.download`
+    as the destination file name (here: the .rec path given by --dataset).
+    """
+    if logger is not None:
+        logger.info('Downloading dataset for inference from %s to %s', dataset_url, dataset_dir)
+    mx.test_utils.download(dataset_url, dataset_dir)
+
+
+def download_model(model_name, logger=None):
+    """Download `model_name` from the model zoo into the `model` directory
+    next to this script.
+
+    Returns the result of `modelzoo.download_model`, which the caller
+    unpacks as a ``(prefix, epoch)`` checkpoint pair.
+    """
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    model_path = os.path.join(dir_path, 'model')
+    if logger is not None:
+        logger.info('Downloading model %s... into path %s', model_name, model_path)
+    # use the argument, not the script-level `args`, so this works when imported
+    return modelzoo.download_model(model_name, model_path)
+
+
+def advance_data_iter(data_iter, n):
+    """Skip the first `n` batches of `data_iter` and return the iterator.
+
+    If the iterator is exhausted before `n` batches have been skipped, the
+    (exhausted) iterator is still returned rather than None.
+    """
+    if n < 0:
+        raise ValueError('n must be non-negative, got %d' % n)
+    for _ in range(n):
+        try:
+            data_iter.next()
+        except StopIteration:
+            break
+    return data_iter
+
+
+def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples, logger=None):
+    """Run forward passes of `sym` over `data`; log top-1/top-5 accuracy and throughput.
+
+    Scoring stops once at least `max_num_examples` images have been processed,
+    or when `data` is exhausted if `max_num_examples` is None.
+    """
+    metrics = [mx.metric.create('acc'),
+               mx.metric.create('top_k_accuracy', top_k=5)]
+    mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name, ])
+    mod.bind(for_training=False,
+             data_shapes=data.provide_data,
+             label_shapes=data.provide_label)
+    mod.set_params(arg_params, aux_params)
+
+    tic = time.time()
+    num = 0
+    for batch in data:
+        mod.forward(batch, is_train=False)
+        for m in metrics:
+            mod.update_metric(m, batch.label)
+        # count images from the batch itself, not the script-level global `batch_size`
+        num += batch.data[0].shape[0]
+        if max_num_examples is not None and num >= max_num_examples:
+            break
+
+    speed = num / (time.time() - tic)
+
+    if logger is not None:
+        logger.info('Finished inference with %d images', num)
+        logger.info('Finished with %f images per second', speed)
+        for m in metrics:
+            logger.info(m.get())
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Score a model on a dataset')
+    parser.add_argument('--model', type=str, required=True,
+                        choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
+                        help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
+    parser.add_argument('--batch-size', type=int, default=32)
+    parser.add_argument('--label-name', type=str, default='softmax_label')
+    parser.add_argument('--dataset', type=str, required=True, help='dataset path')
+    parser.add_argument('--rgb-mean', type=str, default='0,0,0')
+    parser.add_argument('--image-shape', type=str, default='3,224,224')
+    parser.add_argument('--data-nthreads', type=int, default=60, help='number of threads for data decoding')
+    parser.add_argument('--num-skipped-batches', type=int, default=0, help='skip the number of batches for inference')
+    parser.add_argument('--num-inference-batches', type=int, required=True, help='number of batches used for inference')
+    # a store_true flag with default=True can never be turned off, so shuffling stays
+    # the default and --no-shuffle-dataset is provided to disable it
+    parser.add_argument('--shuffle-dataset', action='store_true', default=True,
+                        help='shuffle the inference dataset (default)')
+    parser.add_argument('--no-shuffle-dataset', dest='shuffle_dataset', action='store_false',
+                        help='do not shuffle the inference dataset')
+    parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304,
+                        help='shuffling chunk seed, see'
+                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
+                             ' for more details')
+    parser.add_argument('--shuffle-seed', type=int, default=48564309,
+                        help='shuffling seed, see'
+                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
+                             ' for more details')
+    parser.add_argument('--gpu-id', type=int, default=0, help='id of the GPU used for inference')
+
+    args = parser.parse_args()
+
+    logging.basicConfig()
+    logger = logging.getLogger('logger')
+    logger.setLevel(logging.INFO)
+    data_nthreads = args.data_nthreads
+    batch_size = args.batch_size
+    logger.info('batch size = %d for inference', batch_size)
+
+    rgb_mean = args.rgb_mean
+    logger.info('rgb_mean = %s', rgb_mean)
+    rgb_mean = [float(i) for i in rgb_mean.split(',')]
+    mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}
+
+    label_name = args.label_name
+    logger.info('label_name = %s', label_name)
+
+    image_shape = args.image_shape
+    data_shape = tuple(int(i) for i in image_shape.split(','))
+    logger.info('Input data shape = %s', str(data_shape))
+
+    dataset = args.dataset
+    download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset, logger=logger)
+    logger.info('Dataset for inference: %s', dataset)
+
+    # creating data iterator; shuffling options come from the command line
+    data = mx.io.ImageRecordIter(path_imgrec=dataset,
+                                 label_width=1,
+                                 preprocess_threads=data_nthreads,
+                                 batch_size=batch_size,
+                                 data_shape=data_shape,
+                                 label_name=label_name,
+                                 rand_crop=False,
+                                 rand_mirror=False,
+                                 shuffle=args.shuffle_dataset,
+                                 shuffle_chunk_seed=args.shuffle_chunk_seed,
+                                 seed=args.shuffle_seed,
+                                 **mean_args)
+
+    # download model
+    prefix, epoch = download_model(model_name=args.model, logger=logger)
+    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+    # pass the graph through MXPartitionGraph with these operator names
+    op_names = ['BatchNorm', 'Convolution', 'Pooling', 'Activation']
+    out = SymbolHandle()
+    check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)), c_str_array(op_names),
+                                     ctypes.byref(out)))
+    psym = Symbol(out)
+
+    # skip the requested number of leading batches before scoring
+    logger.info('Skipping the first %d batches', args.num_skipped_batches)
+    data = advance_data_iter(data, args.num_skipped_batches)
+
+    num_inference_images = args.num_inference_batches * batch_size
+    logger.info('Running model %s for inference', args.model)
+    score(psym, arg_params, aux_params, data, [mx.gpu(args.gpu_id)], label_name,
+          max_num_examples=num_inference_images, logger=logger)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 4dd858a51c4..8a714a96a91 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1457,6 +1457,20 @@ MXNET_DLL int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
                                                const float* high_quantiles,
                                                SymbolHandle* ret_sym_handle);
 
+/*!
+ * \brief Partition the graph of a symbol with the PartitionGraph pass.
+ * \param sym_handle the symbol to partition; it is copied, not modified
+ * \param num_ops number of entries in op_names
+ * \param op_names names of the operators used to build the default subgraph
+ *        property; when num_ops is 0 no subgraph property is attached
+ * \param ret_sym_handle returned handle of the partitioned symbol
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXPartitionGraph(SymbolHandle sym_handle,
+                               const mx_uint num_ops,
+                               const char** op_names,
+                               SymbolHandle* ret_sym_handle);
+
 //--------------------------------------------
 // Part 4: Executor interface
 //--------------------------------------------
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index fd1fe89bdba..2424a672055 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -41,8 +41,26 @@ class Engine;
 
 /*! \brief namespace of engine internal types. */
 namespace engine {
-/*! \brief Internal representation of variable. */
-struct Var;
+/*! \brief base class of engine variables; also tracks a version number of the object it guards. */
+struct Var {
+  virtual uint32_t version() {  // current value of version_; engine-specific vars may override this
+    return version_;
+  }
+  virtual ~Var() = default;  // virtual so derived vars are destroyed correctly through a Var*
+  /*!
+   * \brief cast variable to derived type T
+   * \tparam T the type we want to cast into.
+   * \return A casted variable.
+   */
+  template <typename T>
+  inline T* Cast();
+  /*!
+   * \brief version number of the var. Every time the object it is associated with
+   * is modified, the version number is incremented by 1.
+   */
+  uint32_t version_{0};
+};  // struct Var
+
 /*! \brief Internal representation of operator.  */
 struct Opr;
 /*! \brief Variable pointer type, usually hold by user used to specify dependencies. */
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index faffe1bdea9..f73a6edd34c 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -338,6 +338,10 @@ class NDArray {
   inline size_t byte_offset() const {
     return byte_offset_;
   }
+  /*! \brief return the version number of this NDArray's engine variable (var()->version()) */
+  inline uint32_t version() const {
+    return var()->version();
+  }
   /*!
    * \brief save the content into binary stream
    * \param strm the output stream
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index f4694efad29..ebe82491eb6 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -98,7 +98,12 @@ enum class ExecType {
    *  In current implementation, copy operator is specially handled by executor.
    *  This flag is used for special case treatment and future extension of different copy ops.
    */
-  kCrossDeviceCopy
+  kCrossDeviceCopy,
+  /*!
+   * A subgraph execution should happen in the main thread, instead of
+   * in the execution engine.
+   */
+  kSubgraphExec,
 };
 
 /*! \brief the dispatch mode of the operator */
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index e5e9b522890..2f8c4f5dd6e 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -31,6 +31,7 @@
 #include "./c_api_common.h"
 #include "../operator/operator_common.h"
 #include "../executor/exec_pass.h"
+#include "../operator/subgraph/default_subgraph_op.h"
 
 namespace mxnet {
 namespace op {
@@ -625,3 +626,27 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle,
   *ret_qsym_handle = s;
   API_END_HANDLE_ERROR(delete s);
 }
+
+int MXPartitionGraph(SymbolHandle sym_handle,
+                     const mx_uint num_ops,
+                     const char** op_names,
+                     SymbolHandle* ret_sym_handle) {
+  nnvm::Symbol* s = new nnvm::Symbol();  // deleted by API_END_HANDLE_ERROR if anything below fails
+  API_BEGIN();
+  std::unordered_set<std::string> op_name_set;  // de-duplicated operator names from the caller
+  for (size_t i = 0; i < num_ops; ++i) {
+    op_name_set.emplace(op_names[i]);
+  }
+  nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
+  *s = sym->Copy();  // partition a copy so the caller's symbol is left untouched
+  nnvm::Graph g = Symbol2Graph(*s);
+  if (!op_name_set.empty()) {  // without op names, no subgraph property is attached
+    mxnet::op::SubgraphPropertyPtr property
+        = std::make_shared<mxnet::op::DefaultSubgraphProperty>(op_name_set);
+    g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
+  }
+  g = ApplyPass(std::move(g), "PartitionGraph");
+  s->outputs = g.outputs;  // the returned symbol takes the partitioned graph's outputs
+  *ret_sym_handle = s;
+  API_END_HANDLE_ERROR(delete s);
+}
diff --git a/src/engine/engine_impl.h b/src/engine/engine_impl.h
index b3ec34dc857..9219b91ae2e 100644
--- a/src/engine/engine_impl.h
+++ b/src/engine/engine_impl.h
@@ -33,8 +33,12 @@
 namespace mxnet {
 namespace engine {
 
+#if 0
 /*! \brief base class of engine variables, used for type checking */
 struct Var {
+  virtual uint32_t version() {
+    return version_;
+  }
 #if ENGINE_DEBUG
   virtual ~Var() = default;
 #endif  // ENGINE_DEBUG
@@ -45,7 +49,13 @@ struct Var {
    */
   template <typename T>
   inline T* Cast();
+  /*!
+   * \brief version number of the var. Every time the object it is associated with
+   * is modified, the version number is incremented by 1.
+   */
+  uint32_t version_{0};
 };  // struct Var
+#endif
 
 /*! \brief base class of engine operators, used for type checking */
 struct Opr {
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 8196af2de2f..e0a47fa9951 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -28,10 +28,24 @@
 #include "./engine_impl.h"
 #include "../profiler/profiler.h"
 #include "./openmp.h"
+#include "../common/object_pool.h"
 
 namespace mxnet {
 namespace engine {
 
+/*!
+ * \brief var used by the Naive Engine to track the version of the object it is
+ * associated with; instances come from an object pool (NaiveVar::New/Delete).
+ */
+class NaiveVar final
+    : public Var, public common::ObjectPoolAllocatable<NaiveVar> {
+ public:
+  inline static NaiveVar* CastFromBase(Var* ptr) {  // downcast a base Var* via Var::Cast<NaiveVar>()
+    return ptr->Cast<NaiveVar>();
+  }
+};  // class NaiveVar
+
+
 // implement naive engine
 class NaiveEngine final : public Engine {
  public:
@@ -71,8 +85,8 @@ class NaiveEngine final : public Engine {
 
   // new variables
   VarHandle NewVariable() override {
-    size_t v = ++counter_;
-    return reinterpret_cast<VarHandle>(v);
+    // allocate a pooled NaiveVar so that write versions can be tracked per variable
+    return NaiveVar::New();
   }
 
   OprHandle NewOperator(AsyncFn fn,
@@ -165,14 +182,23 @@ class NaiveEngine final : public Engine {
     }
     CHECK(this->req_completed_)
         << "NaiveEngine only support synchronize Push so far";
+    // increment var version
+    for (auto var : mutable_vars) {
+      ++var->version_;
+    }
     if (profiling) {
       opr->opr_profile->stop();
     }
   }
 
   void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override {
-    this->PushSync(delete_fn, exec_ctx, {}, {var},
-                   FnProperty::kNormal, 0, "DeleteVariable");
+    // run delete_fn, then return the NaiveVar to its object pool
+    NaiveVar* naive_var = NaiveVar::CastFromBase(var);
+    this->PushAsync([delete_fn, naive_var](RunContext ctx, CallbackOnComplete on_complete) mutable {
+        delete_fn(ctx);
+        NaiveVar::Delete(naive_var);
+        on_complete();
+      }, exec_ctx, {}, {var}, FnProperty::kDeleteVar, 0, "DeleteVariable");
   }
 
   void WaitForVar(VarHandle var) override {
@@ -192,8 +221,6 @@ class NaiveEngine final : public Engine {
   }
   // whether action is completed
   bool req_completed_;
-  // counter
-  std::atomic<size_t> counter_{0};
   /*! \brief whether it is during shutdown phase*/
   std::atomic<bool> shutdown_phase_{false};
   // CPU stream
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index e70cc197c0c..bd1169768eb 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -130,6 +130,9 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
     assert(pending_write_ != nullptr);
     CHECK_EQ(num_pending_reads_, kWriteTriggered);
 
+    // increment version number
+    ++version_;
+
     // really delete
     if (to_delete_) {
       VersionedVarBlock *head = pending_write_->next;
@@ -164,7 +167,7 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
   }
   // This is outside of lock scope
   // Be very carful, pending_write_ and num_pending_reads_
-  // can change now, do not reply ont the two variables.
+  // can change now, do not rely on these two variables.
   // The linked list \in [old_pending_write, end_of_read_chain)
   // is already detached from this Var.
   // So it is safe to modify these
@@ -196,6 +199,11 @@ inline bool ThreadedVar::ready_to_read() {
   return this->is_ready_to_read();
 }
 
+inline uint32_t ThreadedVar::version() {
+  std::lock_guard<std::mutex> lock{mutex_};  // version_ is incremented inside CompleteWriteDependency's locked section
+  return this->version_;
+}
+
 // implementation of threaded engine
 ThreadedVar* ThreadedEngine::NewVariable() {
   return ThreadedVar::New(VersionedVarBlock::New());
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 428f0d8c554..7730c064b2b 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -162,6 +162,7 @@ class ThreadedVar final
   inline void SetToDelete();
   /*! \return whether this variable is ready to read. */
   inline bool ready_to_read();
+  inline uint32_t version() override;
   /*!
    * \brief Cast a Var pointer to ThreadedVar pointer
    * \param ptr pointer from base.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 831b5f90023..ae05fe478e6 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1614,6 +1614,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
       CHECK_EQ(opnode.exec->in_array.size(), 1U);
       CHECK_EQ(opnode.exec->out_array.size(), 1U);
       CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
+    } else if (opnode.exec->exec_type() == ExecType::kSubgraphExec) {
+      // If the node contains a subgraph, we can't execute it in the engine.
+      opnode.exec->Run(opnode.exec->op_ctx.run_ctx, false);
     } else if (opnode.cached_opr != nullptr) {
       bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
       Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 726531d0299..08ea05af0c4 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -434,7 +434,8 @@ inline void PushFComputeEx(const FComputeEx& fn,
       }
     };
 
-  if (exec_type == ExecType::kCrossDeviceCopy) {
+  if (exec_type == ExecType::kCrossDeviceCopy
+      || exec_type == ExecType::kSubgraphExec) {
     run(RunContext{ctx, nullptr});
   } else {
     CHECK(exec_type == ExecType::kSync);
@@ -475,12 +476,18 @@ inline void PushOperator(const OpStatePtr& state,
       InvalidateOutputs(outputs, req);
 #endif
       fcompute_ex(state, opctx, inputs, req, outputs);
-      if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) {
+      if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync
+          && rctx.get_stream<gpu>()) {
         rctx.get_stream<gpu>()->Wait();
       }
     };
 
-    if (exec_type == ExecType::kSync) {
+    // For operators with subgraphs, we need to invoke them in the main thread
+    // instead of the threaded engine.
+    if (exec_type == ExecType::kSubgraphExec) {
+      RunContext rctx{ctx, nullptr};
+      run(rctx, engine::CallbackOnComplete());
+    } else if (exec_type == ExecType::kSync) {
       Engine::Get()->PushSync(
           [=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); },
           ctx, read_vars, write_vars, FnProperty::kNormal, 0,
@@ -519,12 +526,16 @@ inline void PushOperator(const OpStatePtr& state,
         fcompute(state, opctx, input_blobs, tmp_req, output_blobs);
         // post-fcompute fallback, cast to original storage type, if necessary
         CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
-        if (is_gpu && exec_type == ExecType::kSync) {
+        if (is_gpu && exec_type == ExecType::kSync
+            && rctx.get_stream<gpu>()) {
           rctx.get_stream<gpu>()->Wait();
         }
       };
 
-    if (exec_type == ExecType::kSync) {
+    if (exec_type == ExecType::kSubgraphExec) {
+      RunContext rctx{ctx, nullptr};
+      run(rctx, engine::CallbackOnComplete());
+    } else if (exec_type == ExecType::kSync) {
       Engine::Get()->PushSync(
           [=](RunContext rctx) {
             run(rctx, engine::CallbackOnComplete());
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 94d3d90413a..583e2bfcfc9 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -39,6 +39,7 @@
 #include "../operator/tensor/matrix_op-inl.h"
 #include "../operator/tensor/init_op.h"
 #include "../operator/nn/mkldnn/mkldnn_base-inl.h"
+#include "../engine/engine_impl.h"
 
 #if MXNET_USE_OPENCV
 #include <opencv2/opencv.hpp>
@@ -2041,6 +2042,7 @@ void NDArray::SyncCheckFormat(const bool full_check) const {
   CHECK_EQ(err, kNormalErr) << "Check the validity of this sparse NDArray";
 }
 
+
 #if MXNET_PREDICT_ONLY == 0
 // register API function
 // those with underscore will be registered at NDArray
diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h
new file mode 100644
index 00000000000..472312d0a46
--- /dev/null
+++ b/src/operator/subgraph/common.h
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_COMMON_H_
+#define MXNET_OPERATOR_SUBGRAPH_COMMON_H_
+
+#include <string>
+#include <set>
+#include <vector>
+#include "../elemwise_op_common.h"
+#include "../../executor/exec_pass.h"
+
+namespace mxnet {
+namespace op {
+namespace sg {
+
+struct SimpleNode;  // forward declaration so SimpleNodePtr can be defined before the struct
+using SimpleNodePtr = std::shared_ptr<SimpleNode>;
+
+/*!
+ * \brief Node of the undirected graph which replicates the network structures
+ * of the computational graph. It is used to ease the graph traversal for finding
+ * subgraphs.
+ */
+struct SimpleNode {
+  static SimpleNodePtr Create() {  // allocate a default-constructed node owned by a shared_ptr
+    return std::make_shared<SimpleNode>();
+  }
+  SimpleNode() : label(-1), node(nullptr) {}  // label -1 presumably means "not assigned to a subgraph yet"
+  /*! \brief subgraph label */
+  int label;
+  /*! \brief the original node in the computational graph it references */
+  nnvm::Node* node;
+  /*!
+   * \brief output nodes of the current node
+   * key is node ptr and value is an array of indices standing for the entry indices
+   * in key->inputs whose source is the current node.
+   */
+  std::unordered_map<nnvm::Node*, std::vector<size_t>> outputs;
+};  // struct SimpleNode
+}  // namespace sg
+
+inline uint32_t DefaultSubgraphOpNumInputs(const nnvm::NodeAttrs& attrs) {  // count of all subgraph inputs (kAll: args and aux states)
+  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  return sym.ListInputNames(nnvm::Symbol::kAll).size();
+}
+
+inline uint32_t DefaultSubgraphOpNumOutputs(const nnvm::NodeAttrs& attrs) {  // number of outputs of the subgraph symbol in attrs.parsed
+  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  return sym.ListOutputNames().size();
+}
+
+inline std::vector<std::string> DefaultSubgraphOpListInputs(const nnvm::NodeAttrs& attrs) {  // names of all subgraph inputs, in kAll order
+  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  return sym.ListInputNames(nnvm::Symbol::kAll);
+}
+
+inline std::vector<std::string> DefaultSubgraphOpListOutputs(const nnvm::NodeAttrs& attrs) {  // output names of the subgraph symbol
+  const nnvm::Symbol& sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  return sym.ListOutputNames();
+}
+
+// Shape inference for the default subgraph op: seeds the subgraph's input and output
+// entries with the known shapes, runs exec::InferShape on it and copies the results back.
+// Returns true only when no shape in the subgraph is left unknown.
+inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
+                                   std::vector<TShape> *in_shapes,
+                                   std::vector<TShape> *out_shapes) {
+  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  nnvm::Graph g;
+  g.outputs = subgraph_sym.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size());
+  CHECK_EQ(idx_g.outputs().size(), out_shapes->size());
+
+  // Put the input and output shapes to the shape vector.
+  nnvm::ShapeVector shapes(idx_g.num_node_entries());
+  const auto &input_nids = idx_g.input_nodes();
+  for (size_t i = 0; i < in_shapes->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    shapes[eid] = in_shapes->at(i);
+  }
+  for (size_t i = 0; i < out_shapes->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    shapes[eid] = out_shapes->at(i);
+  }
+
+  // Infer shape of the graph.
+  g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
+  g = exec::InferShape(std::move(g));
+
+  // Copy the inferred shape back to the input shapes and the output shapes.
+  shapes = g.GetAttr<nnvm::ShapeVector>("shape");
+  // assign to in_shapes
+  for (size_t i = 0; i < in_shapes->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    SHAPE_ASSIGN_CHECK(*in_shapes, i, shapes[eid]);
+  }
+  // assign to out_shapes
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]);
+  }
+  // Check if we have inferred the shapes correctly.
+  return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
+}
+
+// Type inference for the default subgraph op: runs exec::InferType on the subgraph seeded
+// with the known dtypes; returns true only when no dtype in the subgraph is left unknown.
+inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
+                                  std::vector<int> *in_types,
+                                  std::vector<int> *out_types) {
+  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  nnvm::Graph g;
+  g.outputs = subgraph_sym.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_types->size());
+  CHECK_EQ(idx_g.outputs().size(), out_types->size());
+
+  // Put the input and output data types to the dtype vector.
+  nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
+  const auto &input_nids = idx_g.input_nodes();
+  for (size_t i = 0; i < in_types->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    types[eid] = in_types->at(i);
+  }
+  for (size_t i = 0; i < out_types->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    types[eid] = out_types->at(i);
+  }
+
+  // Infer data type of the graph.
+  g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
+  g = exec::InferType(std::move(g));
+
+  types = g.GetAttr<nnvm::DTypeVector>("dtype");
+  // assign to in_types
+  for (size_t i = 0; i < in_types->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    TYPE_ASSIGN_CHECK(*in_types, i, types[eid]);
+  }
+  // assign to out_types
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    TYPE_ASSIGN_CHECK(*out_types, i, types[eid]);
+  }
+  // Check if we have inferred the dtypes correctly.
+  return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
+}
+
+// Storage-type inference for the default subgraph op: runs exec::InferStorageType on the
+// subgraph seeded with the known storage types, with every entry's device mask set to
+// dev_mask. Always dispatches to FComputeEx; returns true only when no storage type in
+// the subgraph is left unknown.
+inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int>* in_stypes,
+                                         std::vector<int>* out_stypes) {
+  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  nnvm::Graph g;
+  g.outputs = subgraph_sym.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_stypes->size());
+  CHECK_EQ(idx_g.outputs().size(), out_stypes->size());
+  exec::DevMaskVector dev_masks(idx_g.num_node_entries(), dev_mask);
+
+  // Put the input and output storages to the storage vector.
+  StorageTypeVector stypes(idx_g.num_node_entries(), kUndefinedStorage);
+  const auto &input_nids = idx_g.input_nodes();
+  for (size_t i = 0; i < in_stypes->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    stypes[eid] = in_stypes->at(i);
+  }
+  for (size_t i = 0; i < out_stypes->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    stypes[eid] = out_stypes->at(i);
+  }
+
+  // Infer storage type of the graph.
+  // `g` was built from scratch above and carries no "dev_mask" attribute yet,
+  // so the device masks are attached unconditionally.
+  g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(dev_masks));
+  g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(stypes));
+  g = exec::InferStorageType(std::move(g));
+
+  stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  // assign to in_types
+  for (size_t i = 0; i < in_stypes->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, stypes[eid]);
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  // assign to out_types
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, stypes[eid]);
+  }
+  // Check if we have inferred the storages correctly.
+  return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
+}
+
+inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) {  // kSubgraphExec: run in the calling thread, not the engine
+  return ExecType::kSubgraphExec;
+}
+
+inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {  // positions of aux-state inputs in the kAll input list
+  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  const std::vector<std::string> input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
+  const std::vector<std::string> immutable_input_names =
+    subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+  const std::vector<std::string> mutable_input_names =
+    subgraph_sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+  CHECK_EQ(immutable_input_names.size() + mutable_input_names.size(), input_names.size());
+  std::vector<uint32_t> ret;
+  size_t i1 = 0, i2 = 0;  // cursors into immutable_input_names / mutable_input_names
+  for (size_t i = 0; i < input_names.size(); ++i) {  // each input must be the next read-only or the next aux name
+    if (i1 < immutable_input_names.size() && input_names[i] == immutable_input_names[i1]) {
+      ++i1;
+    } else {
+      CHECK(i2 < mutable_input_names.size());
+      CHECK_EQ(input_names[i], mutable_input_names[i2]);
+      ++i2;
+      ret.push_back(i);  // input i is an auxiliary state, hence mutable
+    }
+  }
+  return ret;
+}
+
+inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {  // union of resource types requested inside the subgraph
+  const nnvm::Symbol& subgraph_sym = nnvm::get<nnvm::Symbol>(attrs.parsed);
+  static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
+  std::set<ResourceRequest::Type> resource_types;  // std::set de-duplicates types across nodes
+  DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) {
+    if (!node->is_variable() && fresource.count(node->op())) {
+      for (ResourceRequest& r : fresource[node->op()](node->attrs)){
+        resource_types.insert(r.type);
+      }
+    }
+  });
+  return std::vector<ResourceRequest>(resource_types.begin(), resource_types.end());  // one request per distinct type
+}
+
+#if 0
+// TODO(junwu): add this attribute for visible outputs; kept disabled because the stub has no body yet
+inline uint32_t DefaultSubgraphOpNumVisibleOutputs(const nnvm::NodeAttrs& attrs) {
+}
+#endif
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_SUBGRAPH_COMMON_H_
diff --git a/src/operator/subgraph/default_subgraph_op.cc b/src/operator/subgraph/default_subgraph_op.cc
new file mode 100644
index 00000000000..8372ae9326d
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.cc
@@ -0,0 +1,113 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+#include <mxnet/ndarray.h>
+#include "./default_subgraph_op.h"
+#include "../../imperative/imperative_utils.h"
+#include "../../imperative/cached_op.h"
+
+namespace mxnet {
+namespace op {
+
+// Set to 1 (e.g. -DSUBGRAPH_DEBUG=1) to log the version of every input and
+// output NDArray on each DefaultSubgraphOperator::Forward call. Off by default
+// because that logging would otherwise run on every forward pass.
+#ifndef SUBGRAPH_DEBUG
+#define SUBGRAPH_DEBUG 0
+#endif
+
+/*!
+ * \brief Stateful operator that runs a subgraph symbol through a CachedOp
+ * created with static_alloc enabled. Only the forward pass is supported.
+ */
+class DefaultSubgraphOperator {
+ public:
+  explicit DefaultSubgraphOperator(const Symbol& sym)
+    : subgraph_sym_(sym),
+      subgraph_exec_(new CachedOp(sym, {{"static_alloc", "true"}})) {}
+
+  // Executes the subgraph on `inputs`, writing into `outputs` (defined below).
+  void Forward(const OpContext& ctx,
+               const std::vector<NDArray>& inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray>& outputs);
+
+  // Backward is unsupported: always aborts through LOG(FATAL).
+  void Backward(const OpContext& ctx,
+                const std::vector<NDArray>& inputs,
+                const std::vector<OpReqType>& req,
+                const std::vector<NDArray>& outputs) {
+    LOG(FATAL) << "Not implemented";
+  }
+
+ private:
+  // The subgraph executed by this operator.
+  nnvm::Symbol subgraph_sym_;
+  // CachedOp built from subgraph_sym_ in the constructor.
+  CachedOpPtr subgraph_exec_;
+};
+
+/*!
+ * \brief Executes the subgraph through the CachedOp.
+ * CachedOp::Forward takes NDArray pointers, so the const input and output
+ * vectors are copied into local vectors whose elements are passed by address.
+ * NOTE(review): `ctx` and `req` are not consulted here; outputs are handed to
+ * CachedOp as-is, so write requests other than a plain write are presumably
+ * not honored -- confirm against callers.
+ */
+void DefaultSubgraphOperator::Forward(const OpContext& ctx,
+                                      const std::vector<NDArray>& inputs,
+                                      const std::vector<OpReqType>& req,
+                                      const std::vector<NDArray>& outputs) {
+  std::vector<NDArray> tmp_inputs = inputs;
+  std::vector<NDArray*> input_ptrs;
+  input_ptrs.reserve(tmp_inputs.size());
+  for (auto& nd : tmp_inputs) {
+    input_ptrs.push_back(&nd);
+  }
+  std::vector<NDArray> tmp_outputs = outputs;
+  std::vector<NDArray*> output_ptrs;
+  // Reserve like input_ptrs to avoid reallocations while collecting pointers.
+  output_ptrs.reserve(tmp_outputs.size());
+  for (auto& nd : tmp_outputs) {
+    output_ptrs.push_back(&nd);
+  }
+#if SUBGRAPH_DEBUG
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    LOG(INFO) << "inputs[" << i << "].version = " << inputs[i].version();
+  }
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    LOG(INFO) << "outputs[" << i << "].version = " << outputs[i].version();
+  }
+#endif
+  subgraph_exec_->Forward(subgraph_exec_, input_ptrs, output_ptrs);
+}
+
+/*!
+ * \brief FCreateOpState for _default_subgraph_op: wraps the subgraph symbol
+ * stored in attrs.parsed in a new DefaultSubgraphOperator. The context, shapes
+ * and types are not used.
+ */
+OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& attrs,
+                                        Context ctx,
+                                        const std::vector<TShape>& in_shapes,
+                                        const std::vector<int>& in_types) {
+  return OpStatePtr::Create<DefaultSubgraphOperator>(nnvm::get<Symbol>(attrs.parsed));
+}
+
+/*!
+ * \brief FStatefulComputeEx entry point: forwards the call to the
+ * DefaultSubgraphOperator held in the op state. Registered for both CPU
+ * (below) and GPU (default_subgraph_op.cu).
+ */
+void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+  state_ptr.get_state<DefaultSubgraphOperator>().Forward(ctx, inputs, req, outputs);
+}
+
+// Operator that executes a partitioned subgraph, stored as an nnvm::Symbol in
+// attrs.parsed, through a stateful DefaultSubgraphOperator. Only the CPU
+// compute function is registered here; the GPU one is added in
+// default_subgraph_op.cu.
+NNVM_REGISTER_OP(_default_subgraph_op)
+.describe(R"code(_default_subgraph_op)code" ADD_FILELINE)
+.set_num_inputs(DefaultSubgraphOpNumInputs)
+.set_num_outputs(DefaultSubgraphOpNumOutputs)
+.set_attr<nnvm::FListInputNames>("FListInputNames", DefaultSubgraphOpListInputs)
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", DefaultSubgraphOpListOutputs)
+.set_attr<FCreateOpState>("FCreateOpState", CreateDefaultSubgraphOpState)
+.set_attr<nnvm::FInferShape>("FInferShape", DefaultSubgraphOpShape)
+.set_attr<nnvm::FInferType>("FInferType", DefaultSubgraphOpType)
+.set_attr<FInferStorageType>("FInferStorageType", DefaultSubgraphOpStorageType)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", DefaultSubgraphOpForward)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs", DefaultSubgraphOpMutableInputs)
+.set_attr<FResourceRequest>("FResourceRequest", DefaultSubgraphOpResourceRequest)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<FExecType>("FExecType", DefaultSubgraphOpExecType)
+.add_argument("data", "NDArray-or-Symbol[]", "input data list");
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_op.cu b/src/operator/subgraph/default_subgraph_op.cu
new file mode 100644
index 00000000000..15a76e3bbb0
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.cu
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file default_subgraph_op.cu
+ * \brief GPU Implementation of subgraph operations
+ */
+
+#include "./default_subgraph_op.h"
+
+namespace mxnet {
+namespace op {
+
+// Defined in default_subgraph_op.cc. Declared here so that the same stateful
+// forward function is also registered as the GPU implementation.
+void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs);
+
+NNVM_REGISTER_OP(_default_subgraph_op)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", DefaultSubgraphOpForward);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_op.h b/src/operator/subgraph/default_subgraph_op.h
new file mode 100644
index 00000000000..7d6624ef14d
--- /dev/null
+++ b/src/operator/subgraph/default_subgraph_op.h
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
+#define MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
+
+#include <vector>
+#include <string>
+#include "./common.h"
+
+namespace mxnet {
+namespace op {
+
+/*
+ * Interface that decides which nodes join a subgraph. A selector may keep
+ * state, so its answers can change as nodes are passed to it, and it decides
+ * separately whether to follow input links and output links while growing a
+ * subgraph from an already selected node.
+ */
+class SubgraphSelector {
+ public:
+  virtual ~SubgraphSelector() = default;
+  // Returns true if `n` should be selected (used to seed a subgraph).
+  virtual bool Select(const nnvm::Node &n) = 0;
+  // Returns true if `new_node`, an input of the selected node `n`, should be
+  // selected too.
+  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
+  // Returns true if `new_node`, an output of the selected node `n`, should be
+  // selected too.
+  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) = 0;
+  // Post-processes the pre-selected subgraph nodes and returns the ones to keep
+  // in subgraph(s); the result must be a subset of `candidates`. By default
+  // every candidate is kept.
+  virtual std::vector<nnvm::Node*> Filter(nnvm::Graph* g,
+                                          const std::vector<nnvm::Node*>& candidates) {
+    return candidates;
+  }
+};
+
+using SubgraphSelectorPtr = std::shared_ptr<SubgraphSelector>;
+
+/*!
+ * \brief This provides a set of properties for partitioning a graph into subgraphs,
+ * reconstructing a new graph from the subgraphs and creating a subgraph
+ * operator to execute the subgraph.
+ */
+class SubgraphProperty {
+ public:
+  // Virtual so that implementations are destroyed correctly through a
+  // SubgraphProperty pointer (e.g. a SubgraphPropertyPtr built from a
+  // base-class pointer); without it such a deletion is undefined behavior.
+  virtual ~SubgraphProperty() = default;
+  // the criteria of selecting the subgraph nodes.
+  virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0;
+  // create an nnvm node for a given subgraph. Here users can customize how to
+  // execute the operators in the subgraph.
+  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s,
+                                           const int subgraph_id = 0) const = 0;
+};
+
+using SubgraphPropertyPtr = std::shared_ptr<SubgraphProperty>;
+
+void RegisterSubgraphProperty(SubgraphPropertyPtr property);
+
+/*
+ * Selects nodes whose operator name belongs to a given set. Both input and
+ * output links are followed when growing a subgraph; variable nodes are never
+ * selected.
+ */
+class ContainOpSelector: public SubgraphSelector {
+ public:
+  explicit ContainOpSelector(std::shared_ptr<const std::unordered_set<std::string>> op_names)
+    : op_names_(op_names) {}
+
+  virtual bool Select(const nnvm::Node &n) {
+    return IsListedOp(n);
+  }
+
+  virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) {
+    return IsListedOp(new_node);
+  }
+
+  virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) {
+    return IsListedOp(new_node);
+  }
+
+ private:
+  // True when `node` is an operator node whose op name is in op_names_.
+  bool IsListedOp(const nnvm::Node &node) const {
+    if (node.is_variable()) return false;
+    return op_names_->count(node.op()->name) > 0;
+  }
+
+  std::shared_ptr<const std::unordered_set<std::string>> op_names_;
+};
+
+/*
+ * Subgraph property whose subgraphs contain only operators from a given set
+ * (chosen by ContainOpSelector). Each subgraph is turned into a
+ * _default_subgraph_op node that keeps the subgraph symbol in attrs.parsed.
+ */
+class DefaultSubgraphProperty: public SubgraphProperty {
+ public:
+  explicit DefaultSubgraphProperty(const std::unordered_set<std::string> &op_names)
+    : op_names_(std::make_shared<std::unordered_set<std::string>>(op_names)) {}
+
+  // Returns a fresh selector sharing this property's op-name set.
+  virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
+    return std::make_shared<ContainOpSelector>(op_names_);
+  }
+
+  // Wraps `sym` in a node named "_default_subgraph_op<subgraph_id>".
+  virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                           const int subgraph_id = 0) const {
+    nnvm::NodePtr node = nnvm::Node::Create();
+    node->attrs.op = Op::Get("_default_subgraph_op");
+    const std::string suffix = std::to_string(subgraph_id);
+    node->attrs.name = "_default_subgraph_op" + suffix;
+    node->attrs.parsed = sym;
+    return node;
+  }
+
+ private:
+  std::shared_ptr<const std::unordered_set<std::string>> op_names_;
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc
new file mode 100644
index 00000000000..11af49ac663
--- /dev/null
+++ b/src/operator/subgraph/partition_graph.cc
@@ -0,0 +1,687 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file partition_graph.cc
+ * \brief Utilities for partitioning a graph into subgraphs whose nodes are chosen by a SubgraphSelector.
+ */
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+#include <mxnet/op_attr_types.h>
+#include <unordered_set>
+#include <stack>
+#include <queue>
+
+#include "./default_subgraph_op.h"
+#include "./common.h"
+
+namespace nnvm {
+NodePtr CreateVariableNode(const std::string& name);
+}
+
+namespace mxnet {
+
+namespace op {
+
+using nnvm::Symbol;
+using nnvm::Node;
+using nnvm::NodePtr;
+using nnvm::NodeEntry;
+using nnvm::Graph;
+
+// Set to 1 (e.g. -DSUBGRAPH_DEBUG=1) to compile in the Print* debugging
+// helpers of the sg namespace. Off by default.
+#ifndef SUBGRAPH_DEBUG
+#define SUBGRAPH_DEBUG 0
+#endif
+
+namespace sg {  // sg stands for subgraph
+
+#if SUBGRAPH_DEBUG
+// Logs the names of all nodes of a subgraph on one line, space separated.
+void PrintSubgraph(const std::vector<SimpleNode*>& simple_nodes) {
+  std::string op_names;
+  for (const SimpleNode* sn : simple_nodes) {
+    op_names += sn->node->attrs.name + ' ';
+  }
+  LOG(INFO) << "Subgraph node names: " << op_names;
+}
+
+// Logs the source node name, output index and version of one node entry.
+void PrintNodeEntry(const nnvm::NodeEntry& entry) {
+  const std::string index_str = std::to_string(entry.index);
+  const std::string version_str = std::to_string(entry.version);
+  std::string ret = "NodeEntry: node_name=" + entry.node->attrs.name;
+  ret += ", index=" + index_str + ", version=" + version_str;
+  LOG(INFO) << ret;
+}
+
+// Logs every entry in `entries`, in order.
+void PrintNodeEntries(const std::vector<nnvm::NodeEntry*>& entries) {
+  for (const nnvm::NodeEntry* e : entries) PrintNodeEntry(*e);
+}
+#endif
+
+/*!
+ * \brief Given a MXNet computational graph, create an undirected graph from it.
+ * One SimpleNode is appended per graph node in DFSVisit order, where a node's
+ * inputs are visited before the node itself (the CHECK_LT below relies on it).
+ * Each SimpleNode's `outputs` maps a consumer node to the input slots through
+ * which that consumer reads it.
+ * \param g the MXNet computational graph
+ * \param simple_nodes output; the nodes of undirected graph in top sorted order
+ */
+void CreateSimpleGraph(const Graph& g,
+                       std::vector<SimpleNodePtr>* simple_nodes) {
+  const auto& idx = g.indexed_graph();
+  simple_nodes->reserve(idx.num_nodes());
+  DFSVisit(g.outputs, [&](const NodePtr& node) {
+    SimpleNodePtr sn = SimpleNode::Create();
+    sn->node = node.get();
+    const auto& node_inputs = sn->node->inputs;
+    for (size_t slot = 0; slot < node_inputs.size(); ++slot) {
+      const auto producer_id = idx.node_id(node_inputs[slot].node.get());
+      CHECK_LT(producer_id, simple_nodes->size());
+      // Record that sn->node consumes the producer through input `slot`.
+      (*simple_nodes)[producer_id]->outputs[sn->node].push_back(slot);
+    }
+    simple_nodes->emplace_back(std::move(sn));
+  });
+}
+
+/*!
+ * \brief Puts every node listed in *subgraph_nodes back to label -1 (not yet
+ * assigned to a subgraph) and then empties *subgraph_nodes.
+ */
+void ResetNodeLabels(const nnvm::Graph& g,
+                     const std::vector<SimpleNodePtr>& simple_nodes,
+                     std::vector<nnvm::Node*>* subgraph_nodes) {
+  const auto& idx = g.indexed_graph();
+  for (const nnvm::Node* member : *subgraph_nodes) {
+    simple_nodes[idx.node_id(member)]->label = -1;
+  }
+  subgraph_nodes->clear();
+}
+
+/*!
+ * \brief This function traverses the nodes in a computation graph from a starting
+ * node following the input edges and output edges accepted by the selector, and
+ * marks all reached nodes with `label`. Before the function returns, it checks
+ * whether there is a loop between the potential subgraph and the outside nodes,
+ * i.e. an outside node that is both an input and an output of the subgraph.
+ * If so, it adds the subgraph node with the largest node id involved in such a
+ * loop to excluded_nodes, resets the labels and returns false. Otherwise, it
+ * sorts subgraph_nodes by node id and returns true.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether the visited node should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in the top sorted order
+ * \param subgraph_nodes output; all the nodes belonging to the same subgraph of seed node
+ * \param excluded_nodes set of nodes that should be excluded from the current subgraph.
+ * It must be non-null for a loop to be broken: if it is null and a loop is
+ * found, a CHECK failure is raised.
+ * \return true if the labeled subgraph forms no loop with outside nodes
+ */
+bool LabelSubgraph(const Graph& g,
+                   SubgraphSelectorPtr subgraph_selector,
+                   const int label,
+                   const size_t snid,  // simple node id, this is a seed
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<nnvm::Node*>* subgraph_nodes,
+                   std::unordered_set<const nnvm::Node*>* excluded_nodes = nullptr) {
+  const auto& indexed_graph = g.indexed_graph();
+  // True if `node` must be kept out of the current subgraph.
+  auto is_excluded = [&](const nnvm::Node* node) {
+    return excluded_nodes != nullptr && excluded_nodes->count(node) > 0;
+  };
+  std::queue<SimpleNode*> node_queue;
+  if (!is_excluded(simple_nodes[snid]->node)) {
+    CHECK_EQ(simple_nodes[snid]->label, -1);
+    simple_nodes[snid]->label = label;
+    node_queue.push(simple_nodes[snid].get());
+  }
+  // key: nodes that serve as input/output nodes to the subgraph
+  // value: pair of vectors of nodes in the subgraph. The first vector contains the
+  // output nodes of the key in the subgraph, and the second vector contains the
+  // input nodes of the key in the subgraph. If both vectors are non-empty,
+  // it means there is a loop between the subgraph and the key node.
+  // When breaking the loop, we want to start removing the node with the largest node id.
+  std::unordered_map<const nnvm::Node*,
+    std::pair<std::vector<const nnvm::Node*>,
+              std::vector<const nnvm::Node*>>> non_subgraph_node_map;
+  while (!node_queue.empty()) {
+    SimpleNode* cur_node = node_queue.front();
+    node_queue.pop();
+    subgraph_nodes->push_back(cur_node->node);
+    // get qualified adjacent input nodes
+    for (auto& e : cur_node->node->inputs) {
+      // The selector is only consulted for nodes that are not excluded.
+      const bool select_input = !is_excluded(e.node.get())
+        && subgraph_selector->SelectInput(*cur_node->node, *e.node);
+      if (select_input) {
+        // e.node is a subgraph node
+        const auto nid = indexed_graph.node_id(e.node.get());
+        CHECK_LT(nid, simple_nodes.size());
+        // this node has not been visited yet
+        if (simple_nodes[nid]->label == -1) {
+          simple_nodes[nid]->label = label;
+          node_queue.push(simple_nodes[nid].get());
+        }
+      } else {
+        // e.node is an input node of the subgraph
+        non_subgraph_node_map[e.node.get()].first.push_back(cur_node->node);
+      }
+    }
+    // get qualified output nodes
+    for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
+      const bool select_output = !is_excluded(it->first)
+          && subgraph_selector->SelectOutput(*cur_node->node, *it->first);
+      if (select_output) {
+        // it->first is a subgraph node
+        const auto nid = indexed_graph.node_id(it->first);
+        CHECK_LT(nid, simple_nodes.size());
+        // this node has not been visited yet
+        if (simple_nodes[nid]->label == -1) {
+          simple_nodes[nid]->label = label;
+          node_queue.push(simple_nodes[nid].get());
+        }
+      } else {
+        // it->first is an output node of the subgraph
+        non_subgraph_node_map[it->first].second.push_back(cur_node->node);
+      }
+    }
+  }
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  // check whether there is a loop between the subgraph and its input/output nodes
+  int excluded_node_id = -1;
+  for (auto& kv : non_subgraph_node_map) {
+    auto& output_nodes = kv.second.first;
+    auto& input_nodes = kv.second.second;
+    if (!output_nodes.empty() && !input_nodes.empty()) {
+      // there is a loop between kv->first and the subgraph
+      std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
+      std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
+      const auto node_id = std::max(indexed_graph.node_id(output_nodes.back()),
+                                    indexed_graph.node_id(input_nodes.back()));
+      excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+    }
+  }
+  if (excluded_node_id != -1) {
+    CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
+    CHECK_NE(excluded_node_id, static_cast<int>(snid))
+      << "A cycle is found in the computational graph between nodes "
+      << simple_nodes[excluded_node_id]->node->attrs.name << " and "
+      << simple_nodes[snid]->node->attrs.name;
+    // The loop can only be broken by recording a node to exclude; without a
+    // set to record it in, fail loudly instead of dereferencing a null pointer.
+    CHECK(excluded_nodes != nullptr)
+      << "A loop was found between the subgraph seeded at node "
+      << simple_nodes[snid]->node->attrs.name
+      << " and its neighbors, but no excluded_nodes set was provided to break it";
+    excluded_nodes->insert(simple_nodes[excluded_node_id]->node);
+    ResetNodeLabels(g, simple_nodes, subgraph_nodes);
+    return false;
+  }
+  std::sort(subgraph_nodes->begin(), subgraph_nodes->end(), node_cmp);
+  return true;
+}
+
+/*!
+ * \brief Finds all the nodes belonging to the same subgraph given a seed node.
+ * LabelSubgraph is retried, with one more node excluded after each attempt that
+ * finds a loop, up to simple_nodes.size() squared times. If no attempt
+ * succeeds, the seed node alone is used as the subgraph.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether the visited node should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in the top sorted order
+ * \param subgraph_nodes output; subgraph node candidates sorted in the topological order
+ */
+void PreSelectSubgraphNodes(const Graph& g,
+                            SubgraphSelectorPtr subgraph_selector,
+                            const int label,
+                            const size_t snid,
+                            const std::vector<SimpleNodePtr>& simple_nodes,
+                            std::vector<nnvm::Node*>* subgraph_nodes) {
+  std::unordered_set<const nnvm::Node*> excluded_nodes;
+  const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
+  size_t count = 0;
+  bool success = false;
+  while (!success && count < max_num_retry) {
+    success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes,
+                            subgraph_nodes, &excluded_nodes);
+    if (!success) {
+      // LabelSubgraph has reset the labels and added one node to excluded_nodes.
+      CHECK(!excluded_nodes.empty());
+      std::string excluded_node_names;
+      for (auto node : excluded_nodes) {
+        excluded_node_names += node->attrs.name + ", ";
+      }
+      LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name
+                << ". Excluding nodes " << excluded_node_names << "and retrying";
+    }
+    ++count;
+  }
+  if (!success) {
+    LOG(INFO) << "Tried " << count << " times of finding subgraphs starting from node "
+              << simple_nodes[snid]->node->attrs.name << " without success because a loop "
+                  "is always found between the subgraph and some other nodes. Will treat "
+                  "seed node " << simple_nodes[snid]->node->attrs.name
+              << " as a subgraph with one node";
+    // A failed LabelSubgraph leaves subgraph_nodes cleared.
+    CHECK(subgraph_nodes->empty());
+    simple_nodes[snid]->label = label;
+    subgraph_nodes->push_back(simple_nodes[snid]->node);
+  }
+}
+
+/*!
+ * \brief Given a vector of nodes, group them into individual subgraphs
+ * based upon their connectivity. Candidates linked by an input or output edge,
+ * directly or through other candidates, end up in the same subgraph. Each new
+ * subgraph labels its simple nodes with the current *subgraph_id (which is then
+ * incremented), is sorted by node id, and is appended to *subgraphs.
+ */
+void PostProcessNodeCandidates(const nnvm::Graph& g,
+                               const std::vector<nnvm::Node*>& nodes,
+                               const std::vector<SimpleNodePtr>& simple_nodes,
+                               std::vector<std::vector<SimpleNode*>>* subgraphs,
+                               size_t* subgraph_id) {
+  const auto& idx = g.indexed_graph();
+  // Candidates that have not been assigned to a subgraph yet.
+  std::unordered_set<nnvm::Node*> unassigned(nodes.begin(), nodes.end());
+  auto by_node_id = [&](const SimpleNode* lhs, const SimpleNode* rhs) {
+    return idx.node_id(lhs->node) < idx.node_id(rhs->node);
+  };
+  for (nnvm::Node* seed : nodes) {
+    // Skip candidates already pulled into an earlier subgraph.
+    if (!unassigned.count(seed)) continue;
+    std::queue<nnvm::Node*> frontier;
+    frontier.push(seed);
+    CHECK_EQ(unassigned.erase(seed), 1U);
+    subgraphs->emplace_back();
+    std::vector<SimpleNode*>& members = subgraphs->back();
+    // Labels `node` with the current subgraph id and records it as a member.
+    auto adopt = [&](const nnvm::Node* node) {
+      SimpleNode* sn = simple_nodes[idx.node_id(node)].get();
+      sn->label = *subgraph_id;
+      members.push_back(sn);
+    };
+    adopt(seed);
+    while (!frontier.empty()) {
+      nnvm::Node* cur = frontier.front();
+      frontier.pop();
+      // Pull in unassigned candidates feeding `cur`...
+      for (auto& entry : cur->inputs) {
+        auto hit = unassigned.find(entry.node.get());
+        if (hit == unassigned.end()) continue;
+        frontier.push(*hit);
+        adopt(*hit);
+        unassigned.erase(hit);
+      }
+      // ...and unassigned candidates consuming its outputs.
+      const SimpleNode* cur_sn = simple_nodes[idx.node_id(cur)].get();
+      for (const auto& kv : cur_sn->outputs) {
+        const auto hit = unassigned.find(kv.first);
+        if (hit == unassigned.end()) continue;
+        frontier.push(*hit);
+        adopt(*hit);
+        unassigned.erase(hit);
+      }
+    }
+    ++(*subgraph_id);
+    std::sort(members.begin(), members.end(), by_node_id);
+  }
+  CHECK(unassigned.empty());
+}
+
+/*!
+ * \brief Finds subgraphs with all nodes that meet certain criteria.
+ * All nodes in a subgraph are marked with the same label.
+ * Every unlabeled node accepted by a fresh selector's Select() seeds a search.
+ * The pre-selected nodes are passed through that selector's Filter(), and the
+ * surviving nodes are split into connected subgraphs appended to subgraph_nodes.
+ */
+void FindSubgraphs(Graph* g,
+                   const SubgraphProperty &subg_prop,
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<std::vector<SimpleNode*>>* subgraph_nodes) {
+  const auto& indexed_graph = g->indexed_graph();
+  CHECK_EQ(indexed_graph.num_nodes(), simple_nodes.size());
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  size_t subgraph_id = 0;
+  for (size_t i = 0; i < simple_nodes.size(); ++i) {
+    nnvm::Node* node = simple_nodes[i]->node;
+    auto subgraph_selector = subg_prop.CreateSubgraphSelector();
+    if (subgraph_selector->Select(*node) && simple_nodes[i]->label == -1) {
+      // pre-select nodes that can be grouped in a subgraph
+      std::vector<nnvm::Node*> preselected_nodes;
+      PreSelectSubgraphNodes(*g, subgraph_selector, subgraph_id, i, simple_nodes,
+                             &preselected_nodes);
+
+      // filter out unqualified pre-selected nodes
+      std::vector<nnvm::Node*> filtered_nodes = subgraph_selector->Filter(g, preselected_nodes);
+
+      // Hash sets keep the membership checks below O(1) per node instead of a
+      // linear std::find over the other vector (quadratic overall).
+      const std::unordered_set<const nnvm::Node*> preselected_set(preselected_nodes.begin(),
+                                                                  preselected_nodes.end());
+      // make sure filtered_nodes is a subset of preselected_nodes
+      for (const auto n : filtered_nodes) {
+        CHECK(preselected_set.count(n) > 0)
+          << "Node " << n->attrs.name << " is not found in the pre-selected subgraph nodes."
+             " Please make sure that no new nodes were added in your subgraph"
+             " selector's Filter function";
+      }
+
+      // make sure nodes are sorted
+      std::sort(filtered_nodes.begin(), filtered_nodes.end(), node_cmp);
+
+      // reset node labels that are not in filtered nodes
+      const std::unordered_set<const nnvm::Node*> filtered_set(filtered_nodes.begin(),
+                                                               filtered_nodes.end());
+      for (const auto n : preselected_nodes) {
+        if (filtered_set.count(n) == 0) {
+          simple_nodes[indexed_graph.node_id(n)]->label = -1;
+        }
+      }
+      // find out subgraphs from the filtered nodes
+      std::vector<std::vector<SimpleNode*>> subgraphs;
+      PostProcessNodeCandidates(*g, filtered_nodes, simple_nodes, &subgraphs, &subgraph_id);
+      if (!subgraphs.empty()) {
+        subgraph_nodes->insert(subgraph_nodes->end(), subgraphs.begin(), subgraphs.end());
+      }
+    }
+  }
+}
+
+/*!
+ * \brief Sorts entries by their topological position in the graph.
+ * Entry ids cannot be used for this ordering. Every entry must be present in
+ * entry_top_order_map, otherwise a CHECK fails.
+ * \param entry_top_order_map mapping from entry pointer to its topological position in the graph
+ * \param entries Node entries to be sorted
+ */
+void SortEntries(const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
+                 std::vector<nnvm::NodeEntry*>* entries) {
+  // Looks up an entry's topological position, failing if it is unknown.
+  auto position_of = [&entry_top_order_map](const nnvm::NodeEntry* e) {
+    const auto found = entry_top_order_map.find(e);
+    CHECK(found != entry_top_order_map.end());
+    return found->second;
+  };
+  std::sort(entries->begin(), entries->end(),
+            [&position_of](const nnvm::NodeEntry* lhs, const nnvm::NodeEntry* rhs) {
+              const size_t lhs_pos = position_of(lhs);
+              const size_t rhs_pos = position_of(rhs);
+              return lhs_pos < rhs_pos;
+            });
+}
+
+/*!
+ * \brief Given a subgraph, find the input entries of a subgraph, i.e. the
+ * inputs of subgraph nodes whose source node lies outside the subgraph.
+ * Entries are collected in subgraph-node order and then sorted by their
+ * topological position. The stored pointers refer to elements of the subgraph
+ * nodes' `inputs` vectors.
+ * \param g the whole graph
+ * \param simple_nodes vector of simple nodes in top sorted order
+ * \param subgraph_nodes vector of pointers to the simple nodes of a subgraph;
+ * all of them must carry the same label (CHECKed).
+ * \param entry_top_order_map mapping entry pointer to its top sorted position
+ * \param input_entries output; input entries of the subgraph
+ */
+
+void FindInputEntries(const Graph& g,
+                      const std::vector<SimpleNodePtr>& simple_nodes,
+                      const std::vector<SimpleNode*>& subgraph_nodes,
+                      const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
+                      std::vector<nnvm::NodeEntry*>* input_entries) {
+  const auto& indexed_graph = g.indexed_graph();
+  int label = -1;
+  for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
+    // The first node fixes the label; all others must match it.
+    if (label == -1) {
+      label = subgraph_nodes[i]->label;
+    } else {
+      CHECK_EQ(subgraph_nodes[i]->label, label);
+    }
+    auto& inputs = subgraph_nodes[i]->node->inputs;
+    for (size_t j = 0; j < inputs.size(); ++j) {
+      auto& e = inputs[j];
+      if (indexed_graph.exist(e.node.get())) {
+        // e's source node is not a subgraph node
+        const auto nid = indexed_graph.node_id(e.node.get());
+        // this is a node not belonging to the subgraph
+        if (simple_nodes[nid]->label != label) {
+          input_entries->push_back(&e);
+        }
+      } else {
+        // e's source node is a subgraph node.
+        // In this case, two subgraphs are adjacent.
+        input_entries->push_back(&e);
+      }
+    }
+  }
+  SortEntries(entry_top_order_map, input_entries);
+}
+
+/*!
+ * \brief Given a subgraph, find the output entries of a subgraph: the input
+ * entries of outside nodes that read from a subgraph node, plus any graph
+ * output entry whose source node belongs to the subgraph. The result is sorted
+ * by topological position. The stored pointers refer to elements of the
+ * consumers' `inputs` vectors or of g->outputs.
+ * \param g pointer to the whole graph
+ * \param simple_nodes vector of simple nodes in top sorted order
+ * \param subgraph_nodes vector of pointers to the simple nodes of a subgraph;
+ * all of them must carry the same label (CHECKed). Returns immediately if empty.
+ * \param entry_top_order_map mapping entry pointer to its top sorted position
+ * \param output_entries output; output entries of the subgraph
+ */
+void FindOutputEntries(Graph* g,
+                       const std::vector<SimpleNodePtr>& simple_nodes,
+                       const std::vector<SimpleNode*>& subgraph_nodes,
+                       const std::unordered_map<const nnvm::NodeEntry*, size_t>&
+                         entry_top_order_map,
+                       std::vector<nnvm::NodeEntry*>* output_entries) {
+  if (subgraph_nodes.empty()) return;
+  const auto& indexed_graph = g->indexed_graph();
+  int label = -1;
+  for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
+    // The first node fixes the label; all others must match it.
+    if (label == -1) {
+      label = subgraph_nodes[i]->label;
+    } else {
+      CHECK_EQ(subgraph_nodes[i]->label, label);
+    }
+    for (auto it = subgraph_nodes[i]->outputs.begin();
+         it != subgraph_nodes[i]->outputs.end(); ++it) {
+      if (indexed_graph.exist(it->first)) {
+        // if the output node is a normal graph node (not a subgraph node)
+        const auto nid = indexed_graph.node_id(it->first);
+        // this is a node not belonging to the current subgraph
+        if (simple_nodes[nid]->label != label) {
+          // TODO(zhengda) I need to test this.
+          // it->second lists the consumer's input slots fed by this node.
+          for (auto idx : it->second) {
+            auto& e = simple_nodes[nid]->node->inputs[idx];
+            output_entries->push_back(&e);
+          }
+        }
+      } else {
+        // if the output node is a subgraph node
+        // two graphs are adjacent
+        for (auto idx : it->second) {
+          output_entries->push_back(&(it->first->inputs[idx]));
+        }
+      }
+    }
+  }
+  // Check if current subgraph contains a node which is the last node
+  // of the whole graph. If so, save its corresponding entry as well.
+  for (size_t i = 0; i < g->outputs.size(); ++i) {
+    auto& entry = g->outputs[i];
+    // The entry might have been updated as an output of
+    // a subgraph node. In this case, no need
+    // to check its source for the current subgraph. Otherwise,
+    // do the following.
+    if (indexed_graph.exist(entry.node.get())) {
+      const auto nid = indexed_graph.node_id(entry.node.get());
+      if (simple_nodes[nid]->label == label) {
+        output_entries->push_back(&entry);
+      }
+    }
+  }
+  SortEntries(entry_top_order_map, output_entries);
+}
+
+/*!
+ * \brief Detach the given input entries of a subgraph from their source nodes.
+ * Each entry pointed to by input_entries is overwritten in place with output 0
+ * of a freshly created variable node. The variable is named after the entry's
+ * output name, as reported by nnvm::Symbol::ListOutputNames. The entry's
+ * previous value is stored at the same index of orig_entries so that the caller
+ * can reconnect the original producers later. The function returns nothing.
+ * \param input_entries pointers to the subgraph's input entries; written in place
+ * \param orig_entries receives the replaced entries; resized to input_entries.size()
+ * \param skip_var when true, entries produced by a variable node are left
+ *        untouched, and their slot in orig_entries is not written
+ * NOTE(review): entries sharing the same output name each get their own
+ * variable node with an identical name; confirm this is acceptable downstream.
+ */
+void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
+                    std::vector<nnvm::NodeEntry> *orig_entries,
+                    const bool skip_var = false) {
+  const size_t num_entries = input_entries.size();
+  orig_entries->resize(num_entries);
+  for (size_t idx = 0; idx < num_entries; ++idx) {
+    nnvm::NodeEntry* entry = input_entries[idx];
+    const bool keep_as_is = entry->node->is_variable() && skip_var;
+    if (keep_as_is) continue;
+
+    // Remember what used to feed the subgraph before overwriting it.
+    (*orig_entries)[idx] = *entry;
+
+    // Name the replacement variable after the entry's output name.
+    nnvm::Symbol entry_sym;
+    entry_sym.outputs.push_back(*entry);
+    const auto names = entry_sym.ListOutputNames();
+    CHECK_EQ(names.size(), 1U);
+    *entry = nnvm::NodeEntry{nnvm::CreateVariableNode(names[0]), 0, 0};
+  }
+}
+
+/*!
+ * \brief Replace a set of nodes belonging to the same subgraph with a subgraph node
+ * and keep the subgraph in the subgraph node. The input entries and output entries
+ * of the subgraph node are kept in the same order as the subgraph's.
+ *
+ * Outline:
+ *  1. Find the subgraph's input entries and cut them (CutGraphInputs): each is
+ *     replaced in place by an entry of a new variable node, and the replaced
+ *     entries are kept in orig_input_entries.
+ *  2. Find the subgraph's output entries.
+ *  3. Build a symbol from the output entries and let the graph's
+ *     "subgraph_property" attribute create the subgraph node from it.
+ *  4. Point every output entry at the matching output of the new node and feed
+ *     the original input entries into it.
+ *  5. Record the new node's input entries in entry_top_order_map and update the
+ *     outputs bookkeeping of their source simple nodes.
+ * \param g the whole graph; its entries are rewired in place
+ * \param simple_nodes simple nodes, indexed by node id of g's indexed graph
+ * \param subgraph_nodes the simple nodes forming one subgraph
+ * \param subgraph_id id forwarded to the subgraph property's CreateSubgraphNode
+ * \param entry_top_order_map entry address -> topological position; extended here
+ *        with the addresses of the new node's input entries
+ */
+void CreateSubgraphNode(Graph* g,
+                        const std::vector<SimpleNodePtr>& simple_nodes,
+                        const std::vector<SimpleNode*>& subgraph_nodes,
+                        const size_t subgraph_id,
+                        std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
+#if SUBGRAPH_DEBUG
+  LOG(INFO) << "Searching for input entries...";
+#endif
+  std::vector<nnvm::NodeEntry*> input_entries;
+  FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries);
+  std::vector<nnvm::NodeEntry> orig_input_entries;
+  // TODO(junwu): Confirm what value to pass to skip_var
+  // After this call each *input_entries[i] refers to a new variable node and
+  // orig_input_entries[i] holds the entry it replaced.
+  CutGraphInputs(input_entries, &orig_input_entries, false);
+#if SUBGRAPH_DEBUG
+  PrintNodeEntries(input_entries);
+  LOG(INFO) << "Searching for output entries...";
+#endif
+  std::vector<nnvm::NodeEntry*> output_entries;
+  FindOutputEntries(g, simple_nodes, subgraph_nodes, *entry_top_order_map, &output_entries);
+
+  // Create a subgraph for the subgraph node
+  // Its outputs are copies of the output entries, taken before they are rewired below.
+  nnvm::Symbol sym;
+  sym.outputs.resize(output_entries.size());
+  for (size_t i = 0; i < output_entries.size(); ++i) {
+    sym.outputs[i] = *output_entries[i];
+  }
+  const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
+  nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym, subgraph_id);
+
+  // Connect the external nodes to the subgraph node.
+  // Output entry i becomes output i of the subgraph node, matching sym.outputs[i].
+  for (size_t i = 0; i < output_entries.size(); ++i) {
+    *output_entries[i] = nnvm::NodeEntry{n, static_cast<uint32_t>(i), 0};
+  }
+  // The subgraph node consumes the entries that originally fed the subgraph.
+  n->inputs = orig_input_entries;
+  const auto& indexed_graph = g->indexed_graph();
+  for (size_t i = 0; i < n->inputs.size(); ++i) {
+    auto& e = n->inputs[i];
+    // update entry_top_order_map with newly created orig_input_entries
+    // input_entries[i] still points to the same slot as before the cut, so the
+    // new entry inherits that slot's recorded position.
+    // NOTE(review): the map is keyed by &n->inputs[i]; n->inputs must not be
+    // reallocated afterwards or these keys would dangle.
+    auto it = entry_top_order_map->find(input_entries[i]);
+    CHECK(it != entry_top_order_map->end());
+    CHECK_EQ(entry_top_order_map->count(&e), 0U);
+    entry_top_order_map->emplace(&e, it->second);
+    // update input entries' source simple nodes' outputs map
+    // Sources absent from the indexed graph have no simple node to update.
+    nnvm::Node* node = e.node.get();
+    if (indexed_graph.exist(node)) {
+      const auto nid = indexed_graph.node_id(node);
+      SimpleNode* sn = simple_nodes[nid].get();
+      // The source no longer feeds the subgraph's member nodes directly;
+      // it now feeds input i of the subgraph node.
+      for (SimpleNode* dest_node : subgraph_nodes) {
+        sn->outputs.erase(dest_node->node);
+      }
+      sn->outputs[n.get()].push_back(i);
+    }
+  }
+#if SUBGRAPH_DEBUG
+  PrintNodeEntries(output_entries);
+#endif
+}
+
+}  // namespace sg
+
+/*!
+ * \brief Sort entries of all the nodes' inputs vectors in the topological order.
+ * This is going to be used to sort input/output entries of subgraphs to keep
+ * the topological order unchanged.
+ *
+ * The graph is walked with an iterative depth-first search starting from each
+ * entry of g.outputs. Entries are identified by address (a slot in some node's
+ * inputs vector, or in g.outputs). Each entry gets the next position, which is
+ * the map's size at insertion time:
+ *  - an entry through which a node is reached for the first time is numbered
+ *    after all of that node's inputs have been processed;
+ *  - an entry whose source node has already been visited is numbered at once.
+ *    This applies to node inputs and to graph outputs alike.
+ * \param g the graph; none of its nodes may have control dependencies (CHECKed)
+ * \param entry_top_order_map output map from entry address to its position
+ */
+void TopSortEntries(const Graph& g,
+                    std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
+  CHECK(entry_top_order_map != nullptr);
+  std::unordered_set<const nnvm::Node*> visited;
+  // tuple: (graph node, index of node's inputs, node entry as the output of the graph node)
+  std::stack<std::tuple<nnvm::Node*, size_t, const nnvm::NodeEntry*>> s;
+  auto in_degree = [] (const nnvm::Node* node)->size_t {
+    if (!node) {
+      return 0;
+    }
+    CHECK_EQ(node->control_deps.size(), 0U);
+    return node->inputs.size();
+  };
+  for (auto& e : g.outputs) {
+    nnvm::Node* node = e.node.get();
+    if (visited.count(node) == 0U) {
+      s.emplace(node, 0U, &e);
+      visited.insert(node);
+    } else {
+      // The source node was already reached from an earlier graph output, so
+      // the traversal below would never record this entry. Number it now:
+      // FindOutputEntries collects pointers into g.outputs and sorts them
+      // with this map.
+      entry_top_order_map->emplace(&e, entry_top_order_map->size());
+    }
+    while (!s.empty()) {
+      auto& top = s.top();
+      if (std::get<1>(top) == in_degree(std::get<0>(top))) {
+        // The node's inputs has been exhausted.
+        entry_top_order_map->emplace(std::get<2>(top), entry_top_order_map->size());
+        s.pop();
+      } else {
+        // The node still has input entries not visited.
+        CHECK_LT(std::get<1>(top), std::get<0>(top)->inputs.size());
+        auto& entry = std::get<0>(top)->inputs[std::get<1>(top)++];
+        nnvm::Node* input_node = entry.node.get();
+        if (visited.count(input_node) == 0U) {
+          // The entry's source node has not been visited.
+          // Push the entry to the stack for marking order later.
+          s.emplace(input_node, 0U, &entry);
+          visited.insert(input_node);
+        } else {
+          // The entry's source node has been visited before.
+          // Marking order for it.
+          entry_top_order_map->emplace(&entry, entry_top_order_map->size());
+        }
+      }
+    }
+  }
+}
+
+/*!
+ * \brief Graph pass that partitions g according to the SubgraphProperty stored
+ * in its "subgraph_property" attribute. Each group of nodes found by
+ * FindSubgraphs is handed to CreateSubgraphNode, using the group's index as its
+ * subgraph id. A graph without that attribute is returned unchanged.
+ */
+Graph PartitionGraph(Graph&& g) {
+  const bool has_property = g.HasAttr("subgraph_property");
+  if (!has_property) {
+    // Nothing to partition with: hand back the input graph as it is.
+    LOG(INFO) << "The graph has no attribute of subgraph_property attached. "
+                 "The original graph is returned.";
+    return g;
+  }
+  using namespace sg;
+  const SubgraphPropertyPtr& subg_prop = g.GetAttr<SubgraphPropertyPtr>("subgraph_property");
+
+  // Topological position of every NodeEntry (keyed by address), used to keep
+  // the inputs/outputs of each subgraph in a stable order.
+  std::unordered_map<const nnvm::NodeEntry*, size_t> entry_top_order_map;
+  TopSortEntries(g, &entry_top_order_map);
+
+  // Build the simplified graph and search it for subgraphs.
+  std::vector<SimpleNodePtr> simple_nodes;
+  CreateSimpleGraph(g, &simple_nodes);
+  std::vector<std::vector<SimpleNode*>> subgraph_nodes;
+  FindSubgraphs(&g, *subg_prop, simple_nodes, &subgraph_nodes);
+
+  const size_t num_subgraphs = subgraph_nodes.size();
+  for (size_t sg_id = 0; sg_id < num_subgraphs; ++sg_id) {
+    auto& nodes = subgraph_nodes[sg_id];
+#if SUBGRAPH_DEBUG
+    // Sanity check: a subgraph must not list the same simple node twice.
+    std::set<SimpleNode*> unique_nodes(nodes.begin(), nodes.end());
+    CHECK_EQ(unique_nodes.size(), nodes.size());
+    PrintSubgraph(nodes);
+#endif
+    CreateSubgraphNode(&g, simple_nodes, nodes, sg_id, &entry_top_order_map);
+  }
+  return g;
+}
+
+// Register PartitionGraph as an NNVM graph pass. set_change_graph(true) flags
+// it as a pass that rewrites the graph, which it does by inserting subgraph nodes.
+NNVM_REGISTER_PASS(PartitionGraph)
+.describe("Partition a graph according to the user defined rules "
+          "in a derived class of SubgraphProperty")
+.set_body(PartitionGraph)
+.set_change_graph(true);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index ed4aaa43782..9f20627dae6 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -44,6 +44,7 @@
 from test_sparse_ndarray import *
 from test_sparse_operator import *
 from test_ndarray import *
+from test_subgraph_op import *
 
 set_default_context(mx.gpu(0))
 del test_support_vector_machine_l1_svm
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 1c59ceaa093..f8833f88e6f 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -175,6 +175,7 @@ def test_trainer_save_load():
     # check if parameter dict is correctly associated with optimizer after load_state
     assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
 
+@unittest.skip("temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/11353")
 @with_seed()
 def test_trainer_reset_kv():
     def check_trainer_reset_kv(kv):
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
new file mode 100644
index 00000000000..f08c42c3063
--- /dev/null
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import ctypes
+import mxnet as mx
+from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array
+from mxnet.symbol import Symbol
+import numpy as np
+
+
+def test_subgraph():
+    def get_graph():
+        data1 = mx.sym.Variable('data1', shape=(3, 3, 10, 10), dtype=np.float32)
+        data2 = mx.sym.Variable('data2', shape=(1, 0, 2, 2))
+        data3 = mx.sym.sin(data2)
+        conv = mx.sym.Convolution(data=data1, weight=data3, kernel=(2, 2), num_filter=1)
+        rets = []
+        rets.append((conv, []))
+        rets.append((conv, [mx.sym.sin.__name__]))
+        rets.append((conv, [mx.sym.Convolution.__name__]))
+        rets.append((conv, [mx.sym.sin.__name__, mx.sym.Convolution.__name__]))
+        return rets
+
+    for regular_sym, op_names in get_graph():
+        input_names = regular_sym.list_inputs()
+        shapes = regular_sym.infer_shape()
+        types = regular_sym.infer_type()
+        out = SymbolHandle()
+
+        check_call(_LIB.MXPartitionGraph(regular_sym.handle, mx_uint(len(op_names)),
+            c_str_array(op_names), ctypes.byref(out)))
+        subgraph_sym = Symbol(out)
+        assert input_names == subgraph_sym.list_inputs()
+
+        print(subgraph_sym.list_outputs())
+        assert shapes == subgraph_sym.infer_shape()
+        assert types == subgraph_sym.infer_type()
+
+        regular_exec = regular_sym.simple_bind(ctx=mx.cpu(), grad_req='null')
+        subgraph_exec = subgraph_sym.simple_bind(ctx=mx.cpu(), grad_req='null')
+
+        for name in input_names:
+            regular_exec.arg_dict[name][:] = mx.nd.random.normal(
+                    shape=regular_exec.arg_dict[name].shape)
+            subgraph_exec.arg_dict[name][:] = regular_exec.arg_dict[name]
+
+        subgraph_exec.forward()
+        regular_exec.forward()
+        mx.nd.waitall()
+        assert (subgraph_exec.outputs[0] - regular_exec.outputs[0]).abs().sum().asscalar() == 0.0
+
+
+def test_input_name_order():
+    def check_input_order(sym, op_names):
+        out = SymbolHandle()
+        check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)),
+                                         c_str_array(op_names), ctypes.byref(out)))
+
+        new_sym = Symbol(out)
+        #print(sym.list_inputs())
+        #print(new_sym.list_inputs())
+        assert new_sym.list_inputs() == sym.list_inputs()
+        assert new_sym.list_arguments() == sym.list_arguments()
+        assert new_sym.list_auxiliary_states() == sym.list_auxiliary_states()
+        #print(new_sym.list_arguments())
+        #print(new_sym.list_auxiliary_states())
+        #print('original outputs: %s' % sym.list_outputs())
+        #print('new sym outputs: %s' % new_sym.list_outputs())
+
+    def test_network_structure_1():
+        data1 = mx.sym.var('data1')
+        data2 = mx.sym.var('data2')
+        conv1 = mx.sym.Convolution(data=data1, weight=data2, no_bias=True, kernel=(2, 2), num_filter=1)
+        conv2 = mx.sym.Convolution(data=data2, weight=data1, no_bias=True, kernel=(2, 2), num_filter=1)
+        out = mx.sym.Group([conv1, conv2])
+        check_input_order(out, ['Convolution'])
+
+    def test_network_structure_2():
+        data1 = mx.sym.var('data1')
+        data2 = mx.sym.var('data2')
+        conv1 = mx.sym.Convolution(data=data1, weight=data2, no_bias=True, kernel=(2, 2), num_filter=1)
+        conv2 = mx.sym.Convolution(data=data2, weight=data1, no_bias=True, kernel=(2, 2), num_filter=1)
+        out = conv1 + conv2
+        check_input_order(out, ['Convolution'])
+        check_input_order(out, ['Convolution', '_Plus', 'elemwise_add', '_plus'])
+
+    def test_network_structure_3():
+        # this tests whether the partitioning algorithm can deal with cycles
+        data = mx.sym.var('data')
+        ret = mx.sym.exp(data)
+        ret1 = mx.sym.cos(ret)
+        ret2 = mx.sym.sin(ret)
+        ret = ret1 + ret2
+        check_input_order(ret, ['exp', 'sin', '_Plus', 'elemwise_add', '_plus'])
+        check_input_order(ret, ['exp', 'cos', '_Plus', 'elemwise_add', '_plus'])
+
+    def test_network_structure_4():
+        # this tests whether the partitioned sym can distinguish in_args and aux_states
+        data = mx.sym.var('data')
+        ret = mx.sym.exp(data)
+        ret1 = mx.sym.cos(ret)
+        ret2 = mx.sym.sin(ret)
+        ret = ret1 + ret2
+        ret = mx.sym.BatchNorm(ret)
+        ret = mx.sym.BatchNorm(ret)
+        check_input_order(ret, ['exp', 'sin', '_Plus', 'elemwise_add', '_plus'])
+        check_input_order(ret, ['exp', 'cos', '_Plus', 'elemwise_add', '_plus'])
+        check_input_order(ret, ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm'])
+        check_input_order(ret, ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm'])
+        check_input_order(ret, ['exp', 'BatchNorm'])
+        check_input_order(ret, ['BatchNorm'])
+
+    test_network_structure_1()
+    test_network_structure_2()
+    test_network_structure_3()
+    test_network_structure_4()
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services