You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/09/23 21:45:56 UTC

[incubator-mxnet] branch master updated: [MXNET-876] make CachedOp a normal operator (#11641)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3caf2ca  [MXNET-876] make CachedOp a normal operator (#11641)
3caf2ca is described below

commit 3caf2cacfaae17595713d060061822655d7c5fcd
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Sun Sep 23 14:45:39 2018 -0700

    [MXNET-876] make CachedOp a normal operator (#11641)
    
    * extend _CachedOp to be a regular operator.
    
    * use the default subgraph inference helpers.
    
    * fix test.
    
    * fix compilation error.
    
    * use default subgraph stuff.
    
    * add comments.
    
    * fix.
    
    * use a more general storage-type inference (FInferStorageType).
    
    * use _CachedOp as the default subgraph operator.
    
    * remove the _default_subgraph_op operator.
    
    * fix.
    
    * fix.
    
    * rename.
    
    * add comment.
    
    * retrigger CI
    
    * add comments.
---
 src/imperative/cached_op.cc                        | 259 ++++++++++++++++++---
 src/imperative/cached_op.h                         |  28 ++-
 src/operator/operator_common.h                     |  12 +-
 src/operator/subgraph/common.h                     |  62 +++--
 src/operator/subgraph/default_subgraph_op.cc       | 112 ---------
 src/operator/subgraph/default_subgraph_op.cu       |  44 ----
 src/operator/subgraph/default_subgraph_property.cc |  11 +-
 tests/python/unittest/test_subgraph.py             | 149 ++++++++++++
 8 files changed, 453 insertions(+), 224 deletions(-)

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 0c4c1e6..1f115cd 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -23,6 +23,7 @@
 #include "../executor/exec_pass.h"
 #include "../profiler/profiler.h"
 #include "../operator/operator_common.h"
+#include "../operator/subgraph/common.h"
 
 
 namespace mxnet {
@@ -874,7 +875,6 @@ OpStatePtr CachedOp::Forward(
   return op_state;
 }
 
-
 void CachedOp::DynamicBackward(
     const bool retain_graph,
     const OpStatePtr& op_state,
@@ -1067,34 +1067,153 @@ void CachedOp::Backward(
   Engine::Get()->set_bulk_size(prev_bulk_size);
 }
 
-bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
-                                  const int dev_mask,
-                                  DispatchMode* dispatch_mode,
-                                  std::vector<int> *in_attrs,
-                                  std::vector<int> *out_attrs) {
-  using namespace imperative;
-  nnvm::Graph g(fwd_graph_);
-  const auto& idx = g.indexed_graph();
-  const auto &outputs = idx.outputs();
+/*
+ * This is the operator state of CachedOp when CachedOp is used in the symbol
+ * executor. This is different from the OpState returned by CachedOp::Forward.
+ * The main reason why we need this OpState is that CachedOp and the symbol executor
+ * maintain OpState differently. The symbol executor generates OpState in advance
+ * while CachedOp generates OpState after Forward is called. We need this data
+ * structure to keep the OpState generated by CachedOp::Forward and pass it to
+ * Backward.
+ */
+struct CachedOpActualState {
+  std::shared_ptr<CachedOp> op;
+  OpStatePtr forward_state;
 
-  // Prepare stypes and contexts based on inputs
-  StorageTypeVector storage_type_inputs;
-  storage_type_inputs.reserve(in_attrs->size());
-  for (size_t i = 0; i < in_attrs->size(); ++i) {
-    storage_type_inputs.emplace_back(in_attrs->at(i));
+  explicit CachedOpActualState(std::shared_ptr<CachedOp> op) {
+    this->op = op;
   }
-  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+};
 
-  // Forward graph storage type inference
-  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
-  // Retrieve result and set outputs
-  const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
-  for (size_t i = 0; i < out_attrs->size(); i++) {
-    const auto eid = idx.entry_id(outputs[i]);
-    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
+/*
+ * This is the forward computation when CachedOp is used as an operator in
+ * a symbol executor.
+ */
+void CachedOpForward(const OpStatePtr& state_ptr,
+                     const OpContext& ctx,
+                     const std::vector<NDArray>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<NDArray>& outputs) {
+  CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+  std::vector<NDArray> in_bufs = inputs;
+  std::vector<NDArray> out_bufs = outputs;
+  std::vector<NDArray *> in_ptrs(in_bufs.size());
+  std::vector<NDArray *> out_ptrs(out_bufs.size());
+  for (size_t i = 0; i < in_ptrs.size(); i++)
+    in_ptrs[i] = &in_bufs[i];
+  for (size_t i = 0; i < out_ptrs.size(); i++)
+    out_ptrs[i] = &out_bufs[i];
+
+  // Set is_recording correct for the imperative executor.
+  bool orig_is_record;
+  if (ctx.need_grad)
+    orig_is_record = Imperative::Get()->set_is_recording(true);
+  else
+    orig_is_record = Imperative::Get()->is_recording();
+  // Set is_training correct for the imperative executor.
+  bool orig_is_train;
+  if (ctx.is_train)
+    orig_is_train = Imperative::Get()->set_is_training(true);
+  else
+    orig_is_train = Imperative::Get()->is_training();
+  s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs);
+  Imperative::Get()->set_is_training(orig_is_train);
+  Imperative::Get()->set_is_recording(orig_is_record);
+  // The arrays in out_ptrs may be changed by CachedOp.
+  // If it is, we need to copy data back.
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    if (!out_bufs[i].IsSame(outputs[i]))
+      CopyFromTo(out_bufs[i], outputs[i]);
+}
+
+/*
+ * This is the backward computation when CachedOp is used as an operator in
+ * a symbol executor.
+ */
+void CachedOpBackward(const OpStatePtr& state_ptr,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  using namespace nnvm;
+  using namespace imperative;
+  CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+  std::vector<NDArray> in_bufs = inputs;
+  std::vector<NDArray> out_bufs = outputs;
+  std::vector<NDArray *> in_ptrs;
+  std::vector<NDArray *> out_ptrs;
+  CHECK_EQ(s.op->num_backward_inputs(), inputs.size());
+  in_ptrs.reserve(s.op->num_backward_inputs());
+  out_ptrs.reserve(s.op->num_inputs());
+
+  const std::vector<bool> &save_inputs = s.op->save_inputs();
+  const std::vector<bool> &save_outputs = s.op->save_outputs();
+  size_t bwd_in_dep = s.op->num_inputs();
+  size_t bwd_out_dep = s.op->num_outputs();
+  CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep);
+  size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep;
+
+  // Find inputs, outputs and ograds
+  auto ograds_begin = in_bufs.begin();
+  auto ograds_end = in_bufs.begin() + bwd_ograd_dep;
+  auto in_begin = ograds_end;
+  auto in_end = in_begin + bwd_in_dep;
+  auto out_begin = in_end;
+  auto out_end = in_bufs.end();
+
+  for (auto it = ograds_begin; it != ograds_end; it++)
+    in_ptrs.push_back(&(*it));
+
+  CHECK_EQ(save_inputs.size(), in_end - in_begin);
+  CHECK_EQ(s.op->num_outputs(), out_end - out_begin);
+  for (auto it = in_begin; it != in_end; it++) {
+    auto i = it - in_begin;
+    if (save_inputs[i])
+      in_ptrs.push_back(&(*it));
   }
-  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
-  return true;
+  for (auto it = out_begin; it != out_end; it++) {
+    auto i = it - out_begin;
+    if (save_outputs[i])
+      in_ptrs.push_back(&(*it));
+  }
+  CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs());
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    out_ptrs.push_back(&out_bufs[i]);
+  CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs());
+  // Set is_training correct for the imperative executor.
+  bool orig_is_train;
+  if (ctx.is_train)
+    orig_is_train = Imperative::Get()->set_is_training(true);
+  else
+    orig_is_train = Imperative::Get()->is_training();
+  // TODO(zhengda) CachedOp supports recording computation when running
+  // the backward path. This is necessary if we want to support the second-order
+  // differentiation. However, MXNet operator doesn't have an interface to
+  // pass a flag to determine whether to record computation inside an operator.
+  // Let's use false here for now and design a solution when the second-order
+  // differentiation is supported.
+  s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs);
+  Imperative::Get()->set_is_training(orig_is_train);
+
+  // Clean up what we recorded.
+  s.forward_state.reset();
+
+  // The arrays in out_ptrs may be changed by CachedOp.
+  // If it is, we need to copy data back.
+  // For example, when the inputs and outputs share the same NDArrays,
+  // the outputs will be replaced by inputs.
+  // https://github.com/apache/incubator-mxnet/blob/v1.2.0/src/imperative/cached_op.cc#L385
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    if (!out_bufs[i].IsSame(outputs[i]))
+      CopyFromTo(out_bufs[i], outputs[i]);
+}
+
+OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
+                               Context ctx,
+                               const std::vector<TShape>& in_shapes,
+                               const std::vector<int>& in_types) {
+  const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+  return OpStatePtr::Create<CachedOpActualState>(op);
 }
 
 bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
@@ -1143,6 +1262,32 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+void CachedOpParamParser(nnvm::NodeAttrs* attrs) {
+  CachedOpConfig param;
+  try {
+    param.Init(attrs->dict);
+  } catch (const dmlc::ParamError& e) {
+    std::ostringstream os;
+    os << e.what();
+    os << ", in operator " << attrs->op->name << "("
+       << "name=\"" << attrs->name << "\"";
+    for (const auto& k : attrs->dict) {
+      os << ", " << k.first << "=\"" << k.second << "\"";
+    }
+    os << ")";
+    throw dmlc::ParamError(os.str());
+  }
+  if (!param.subgraph.empty()) {
+    nnvm::Graph g = nnvm::pass::LoadJSON(param.subgraph);
+    CHECK(!g.outputs.empty());
+    nnvm::Symbol sym;
+    sym.outputs = g.outputs;
+    std::vector<std::pair<std::string, std::string> > flags;
+    for (auto it = attrs->dict.begin(); it != attrs->dict.end(); it++)
+      flags.emplace_back(it->first, it->second);
+    attrs->parsed = CachedOpPtr(new CachedOp(sym, flags));
+  }
+}
 
 NNVM_REGISTER_OP(_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs) {
@@ -1153,19 +1298,62 @@ NNVM_REGISTER_OP(_CachedOp)
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
     return op->num_outputs();
   })
-.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
-                                                     const int dev_mask,
-                                                     DispatchMode* dispatch_mode,
-                                                     std::vector<int> *in_attrs,
-                                                     std::vector<int> *out_attrs) {
-    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
-    return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
-  })
+.set_attr_parser(CachedOpParamParser)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
     return op->Gradient(n, ograds);
-  });
+  })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ListForwardInputNames();
+  })
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ListForwardOutputNames();
+  })
+.set_attr<FCreateOpState>("FCreateOpState", CreateCachedOpState)
+.set_attr<nnvm::FInferShape>("FInferShape",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<TShape> *in_shapes,
+     std::vector<TShape> *out_shapes) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes);
+  })
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<int> *in_types,
+     std::vector<int> *out_types) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types);
+  })
+.set_attr<FInferStorageType>("FInferStorageType",
+  [](const nnvm::NodeAttrs& attrs,
+     const int dev_mask,
+     DispatchMode* dispatch_mode,
+     std::vector<int>* in_stypes,
+     std::vector<int>* out_stypes) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(),
+                                                  dev_mask, dispatch_mode,
+                                                  in_stypes, out_stypes);
+  })
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpForward)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpForward)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym());
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym());
+  })
+.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
+.add_argument("data", "NDArray-or-Symbol[]", "input data list");
 
 NNVM_REGISTER_OP(_backward_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs){
@@ -1184,6 +1372,9 @@ NNVM_REGISTER_OP(_backward_CachedOp)
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
     return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
   })
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpBackward)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpBackward)
+.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<bool>("TIsBackward", true);
 
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 4f4dfdc..59a793e 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -37,6 +37,7 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
   bool static_shape;
   nnvm::Tuple<uint32_t> data_indices;
   nnvm::Tuple<uint32_t> param_indices;
+  std::string subgraph;
   DMLC_DECLARE_PARAMETER(CachedOpConfig) {
     DMLC_DECLARE_FIELD(static_alloc)
     .set_default(false)
@@ -62,6 +63,9 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
     DMLC_DECLARE_FIELD(param_indices)
     .set_default(nnvm::Tuple<uint32_t>())
     .describe("Position of parameters.");
+    DMLC_DECLARE_FIELD(subgraph)
+    .set_default(std::string(""))
+    .describe("JSON string of a subgraph.");
   }
 };
 
@@ -80,6 +84,10 @@ class CachedOp {
   uint32_t num_backward_inputs() const {
     return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
   }
+  uint32_t num_backward_outputs() const {
+    auto &idx = fwd_graph_.indexed_graph();
+    return idx.input_nodes().size() - idx.mutable_input_nodes().size();
+  }
   std::vector<bool>& save_inputs() {
     return save_inputs_;
   }
@@ -102,13 +110,6 @@ class CachedOp {
       const std::vector<NDArray*>& inputs,
       const std::vector<OpReqType>& reqs,
       const std::vector<NDArray*>& outputs);
-  // forward storage type inference
-  bool ForwardStorageType(
-      const nnvm::NodeAttrs& attrs,
-      const int dev_mask,
-      DispatchMode* dispatch_mode,
-      std::vector<int> *in_attrs,
-      std::vector<int> *out_attrs);
   // backward storage type inference
   bool BackwardStorageType(
       const nnvm::NodeAttrs& attrs,
@@ -116,6 +117,19 @@ class CachedOp {
       DispatchMode* dispatch_mode,
       std::vector<int> *in_attrs,
       std::vector<int> *out_attrs);
+  std::vector<std::string> ListForwardInputNames() const {
+    nnvm::Symbol sym = GetForwardSym();
+    return sym.ListInputNames(nnvm::Symbol::kAll);
+  }
+  std::vector<std::string> ListForwardOutputNames() const {
+    nnvm::Symbol sym = GetForwardSym();
+    return sym.ListOutputNames();
+  }
+  nnvm::Symbol GetForwardSym() const {
+    nnvm::Symbol sym;
+    sym.outputs = fwd_graph_.outputs;
+    return sym;
+  }
 
  private:
   struct GraphInfo;
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 2911293..6a4c3d0 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -221,7 +221,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
  */
 #define SHAPE_ASSIGN_CHECK(shape_array, index, shape)                       \
   {                                                                         \
-    if (!shape_assign(&(shape_array)[index], TShape(shape))) {              \
+    if (!::mxnet::op::shape_assign(&(shape_array)[index], TShape(shape))) { \
       std::ostringstream os;                                                \
       os << "Shape inconsistent, Provided = " << (shape_array)[index] << ','\
          << " inferred shape=" << shape;                                    \
@@ -238,11 +238,11 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
  */
 #define TYPE_ASSIGN_CHECK(type_array, index, type)                          \
   {                                                                         \
-    if (!type_assign(&(type_array)[index], type)) {                         \
+    if (!::mxnet::op::type_assign(&(type_array)[index], type)) {            \
       std::ostringstream os;                                                \
       os << "Type inconsistent, Provided = "                                \
-         << type_string((type_array)[index]) << ','                         \
-         << " inferred type = " << type_string(type);                       \
+         << ::mxnet::op::type_string((type_array)[index]) << ','            \
+         << " inferred type = " << ::mxnet::op::type_string(type);          \
       throw ::mxnet::op::InferTypeError(os.str(), index);                   \
     }                                                                       \
   }
@@ -291,8 +291,8 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
 #define UNIFORM_TYPE_CHECK(type, expected, arg)                         \
   {                                                                     \
     CHECK_EQ(type, expected) << "This layer requires uniform type. "    \
-                             << "Expected '" << type_string(expected)   \
-                             << "' v.s. given '" << type_string(type)   \
+                             << "Expected '" << ::mxnet::op::type_string(expected)   \
+                             << "' v.s. given '" << ::mxnet::op::type_string(type)   \
                              << "' at '" << arg << "'";                 \
   }
 
diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h
index 22058d5..4e1cd66 100644
--- a/src/operator/subgraph/common.h
+++ b/src/operator/subgraph/common.h
@@ -49,11 +49,10 @@ inline std::vector<std::string> DefaultSubgraphOpListOutputs(const nnvm::NodeAtt
   return sym.ListOutputNames();
 }
 
-inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
-                                   std::vector<TShape> *in_shapes,
-                                   std::vector<TShape> *out_shapes) {
+inline bool DefaultSubgraphOpShapeHelper(const nnvm::Symbol& subgraph_sym,
+                                         std::vector<TShape> *in_shapes,
+                                         std::vector<TShape> *out_shapes) {
   using namespace exec;
-  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -94,10 +93,15 @@ inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
   return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
 }
 
-inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
-                                  std::vector<int> *in_types,
-                                  std::vector<int> *out_types) {
-  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
+                                   std::vector<TShape> *in_shapes,
+                                   std::vector<TShape> *out_shapes) {
+  return DefaultSubgraphOpShapeHelper(*attrs.subgraphs[0], in_shapes, out_shapes);
+}
+
+inline bool DefaultSubgraphOpTypeHelper(const nnvm::Symbol& subgraph_sym,
+                                        std::vector<int> *in_types,
+                                        std::vector<int> *out_types) {
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -137,12 +141,17 @@ inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
   return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
 }
 
-inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
-                                         const int dev_mask,
-                                         DispatchMode* dispatch_mode,
-                                         std::vector<int>* in_stypes,
-                                         std::vector<int>* out_stypes) {
-  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
+                                  std::vector<int> *in_types,
+                                  std::vector<int> *out_types) {
+  return DefaultSubgraphOpTypeHelper(*attrs.subgraphs[0], in_types, out_types);
+}
+
+inline bool DefaultSubgraphOpStorageTypeHelper(const nnvm::Symbol& subgraph_sym,
+                                               const int dev_mask,
+                                               DispatchMode* dispatch_mode,
+                                               std::vector<int>* in_stypes,
+                                               std::vector<int>* out_stypes) {
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -190,12 +199,21 @@ inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
   return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
 }
 
+inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int>* in_stypes,
+                                         std::vector<int>* out_stypes) {
+  return DefaultSubgraphOpStorageTypeHelper(*attrs.subgraphs[0], dev_mask, dispatch_mode,
+                                            in_stypes, out_stypes);
+}
+
 inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) {
   return ExecType::kSubgraphExec;
 }
 
-inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+inline std::vector<uint32_t> DefaultSubgraphOpMutableInputsHelper(
+  const nnvm::Symbol& subgraph_sym) {
   const std::vector<std::string> input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
   const std::vector<std::string> immutable_input_names =
     subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
@@ -217,8 +235,12 @@ inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttr
   return ret;
 }
 
-inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
+  return DefaultSubgraphOpMutableInputsHelper(*attrs.subgraphs[0]);
+}
+
+inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequestHelper(
+    const nnvm::Symbol& subgraph_sym) {
   static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
   std::set<ResourceRequest::Type> resource_types;
   DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) {
@@ -231,6 +253,10 @@ inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm:
   return std::vector<ResourceRequest>(resource_types.begin(), resource_types.end());
 }
 
+inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
+  return DefaultSubgraphOpResourceRequestHelper(*attrs.subgraphs[0]);
+}
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/subgraph/default_subgraph_op.cc b/src/operator/subgraph/default_subgraph_op.cc
deleted file mode 100644
index d5fb7ee..0000000
--- a/src/operator/subgraph/default_subgraph_op.cc
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*/
-
-#include <mxnet/ndarray.h>
-#include "./common.h"
-#include "../../imperative/imperative_utils.h"
-#include "../../imperative/cached_op.h"
-
-namespace mxnet {
-namespace op {
-
-#define DEBUG_SUBGRAPH 0
-
-class DefaultSubgraphOperator {
- public:
-  explicit DefaultSubgraphOperator(const Symbol& sym) : subgraph_sym_(sym) {
-    subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"},
-                                            {"static_shape", "true"}}));
-  }
-
-  void Forward(const OpContext& ctx,
-               const std::vector<NDArray>& inputs,
-               const std::vector<OpReqType>& req,
-               const std::vector<NDArray>& outputs);
-  void Backward(const OpContext& ctx,
-                const std::vector<NDArray>& inputs,
-                const std::vector<OpReqType>& req,
-                const std::vector<NDArray>& outputs) {
-    LOG(FATAL) << "Not implemented";
-  }
-
- private:
-  nnvm::Symbol subgraph_sym_;
-  CachedOpPtr subgraph_exec_;
-};
-
-void DefaultSubgraphOperator::Forward(const OpContext& ctx,
-                                      const std::vector<NDArray>& inputs,
-                                      const std::vector<OpReqType>& req,
-                                      const std::vector<NDArray>& outputs) {
-  std::vector<NDArray> tmp_inputs = inputs;
-  std::vector<NDArray*> input_ptrs;
-  input_ptrs.reserve(inputs.size());
-  for (auto& nd : tmp_inputs) {
-    input_ptrs.push_back(&nd);
-  }
-  std::vector<NDArray> tmp_outputs = outputs;
-  std::vector<NDArray*> output_ptrs;
-  for (auto& nd : tmp_outputs) {
-    output_ptrs.push_back(&nd);
-  }
-#if DEBUG_SUBGRAPH
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    LOG(INFO) << "inputs[" << i << "].version = " << inputs[i].version();
-  }
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    LOG(INFO) << "outputs[" << i << "].version = " << outputs[i].version();
-  }
-#endif
-  subgraph_exec_->Forward(subgraph_exec_, input_ptrs, output_ptrs);
-}
-
-OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& attrs,
-                                        Context ctx,
-                                        const std::vector<TShape>& in_shapes,
-                                        const std::vector<int>& in_types) {
-  return OpStatePtr::Create<DefaultSubgraphOperator>(*attrs.subgraphs[0]);
-}
-
-void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& inputs,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs) {
-  DefaultSubgraphOperator& op = state_ptr.get_state<DefaultSubgraphOperator>();
-  op.Forward(ctx, inputs, req, outputs);
-}
-
-NNVM_REGISTER_OP(_default_subgraph_op)
-.describe(R"code(_default_subgraph_op)code" ADD_FILELINE)
-.set_num_inputs(DefaultSubgraphOpNumInputs)
-.set_num_outputs(DefaultSubgraphOpNumOutputs)
-.set_attr<nnvm::FListInputNames>("FListInputNames", DefaultSubgraphOpListInputs)
-.set_attr<nnvm::FListOutputNames>("FListOutputNames", DefaultSubgraphOpListOutputs)
-.set_attr<FCreateOpState>("FCreateOpState", CreateDefaultSubgraphOpState)
-.set_attr<nnvm::FInferShape>("FInferShape", DefaultSubgraphOpShape)
-.set_attr<nnvm::FInferType>("FInferType", DefaultSubgraphOpType)
-.set_attr<FInferStorageType>("FInferStorageType", DefaultSubgraphOpStorageType)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", DefaultSubgraphOpForward)
-.set_attr<nnvm::FMutateInputs>("FMutateInputs", DefaultSubgraphOpMutableInputs)
-.set_attr<std::string>("key_var_num_args", "num_args")
-.set_attr<FExecType>("FExecType", DefaultSubgraphOpExecType)
-.add_argument("data", "NDArray-or-Symbol[]", "input data list");
-
-}  // namespace op
-}  // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_op.cu b/src/operator/subgraph/default_subgraph_op.cu
deleted file mode 100644
index 008826b..0000000
--- a/src/operator/subgraph/default_subgraph_op.cu
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Copyright (c) 2018 by Contributors
- * \file default_subgraph_op.cu
- * \brief GPU Implementation of subgraph operations
- */
-
-#include <mxnet/ndarray.h>
-#include "./common.h"
-#include "../../imperative/imperative_utils.h"
-#include "../../imperative/cached_op.h"
-
-namespace mxnet {
-namespace op {
-
-void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& inputs,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs);
-
-NNVM_REGISTER_OP(_default_subgraph_op)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", DefaultSubgraphOpForward);
-
-}  // namespace op
-}  // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_property.cc b/src/operator/subgraph/default_subgraph_property.cc
index c8d3e9f..0152344 100644
--- a/src/operator/subgraph/default_subgraph_property.cc
+++ b/src/operator/subgraph/default_subgraph_property.cc
@@ -21,6 +21,7 @@
 #include <string>
 #include "./common.h"
 #include "./subgraph_property.h"
+#include "../../imperative/cached_op.h"
 
 namespace mxnet {
 namespace op {
@@ -51,7 +52,7 @@ class ContainOpSelector: public SubgraphSelector {
 
 /*
  * This subgraph property finds a subgraph whose nodes have only operators
- * within a set. The operators in the subgraph will be executed by _default_subgraph_op.
+ * within a set. The operators in the subgraph will be executed by _CachedOp.
  */
 class DefaultSubgraphProperty: public SubgraphProperty {
  public:
@@ -59,9 +60,16 @@ class DefaultSubgraphProperty: public SubgraphProperty {
+  // Creates one _CachedOp node (named "_CachedOp" + subgraph_id) that holds `sym`
+  // as its subgraph and a CachedOp built from `sym` as its parsed state.
   virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
                                            const int subgraph_id = 0) const {
     nnvm::NodePtr n = nnvm::Node::Create();
-    n->attrs.op = Op::Get("_default_subgraph_op");
-    n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id);
+    n->attrs.op = Op::Get("_CachedOp");
+    n->attrs.name = "_CachedOp" + std::to_string(subgraph_id);
     n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym));
+
+    // The CachedOp that runs this subgraph is created with "static_alloc" enabled.
+    std::vector<std::pair<std::string, std::string> > flags{{"static_alloc", "true"}};
+    n->attrs.parsed = std::make_shared<CachedOp>(sym, flags);
+
     return n;
   }
   virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
diff --git a/tests/python/unittest/test_subgraph.py b/tests/python/unittest/test_subgraph.py
new file mode 100644
index 0000000..b5577d4
--- /dev/null
+++ b/tests/python/unittest/test_subgraph.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+from __future__ import print_function
+import numpy as np
+import mxnet as mx
+import copy
+import math
+import ctypes
+import random
+import itertools
+from numpy.testing import assert_allclose, assert_array_equal
+from mxnet.test_utils import *
+from mxnet.base import py_str, MXNetError, _as_list, SymbolHandle, check_call, _LIB, c_handle_array, mx_uint
+from common import setup_module, with_seed, teardown
+import unittest
+from mxnet.gluon.model_zoo.vision import get_model
+
+def make_subgraph(subg, *args):
+    """Return a _CachedOp symbol that runs `subg` (serialized to JSON) on `args`."""
+    return mx.sym._internal._CachedOp(*args, subgraph=subg.tojson())
+
+@with_seed()
+def test_make_subgraph():
+    """Wrapping a graph in _CachedOp must not change its outputs or gradients."""
+    def make_subgraph1(stype):
+        a = mx.symbol.Variable(name='a', stype=stype)
+        b = mx.symbol.Variable(name='b', stype=stype)
+        c = a * b
+        d = c * 2
+
+        a1 = mx.symbol.Variable(name='a', stype=stype)
+        b1 = mx.symbol.Variable(name='b', stype=stype)
+        y = make_subgraph(c, a1, b1)
+        y = y * 2
+
+        s = (10, 10)
+        a_arr, b_arr = [mx.nd.array(np.random.normal(-0.1, 0.1, size=s),
+                                    ctx=default_context()).tostype(stype) for _ in range(2)]
+        return (d, y, {'a': a_arr, 'b': b_arr}, {})
+
+    def create_weights(shapes, names):
+        nd_dict = {}
+        sym_dict = {}
+        assert len(shapes) == len(names)
+        for shape, name in zip(shapes, names):
+            sym_dict[name] = mx.symbol.Variable(name)
+            nd_dict[name] = mx.nd.array(np.ones(shape), ctx=default_context())
+        return (nd_dict, sym_dict)
+
+    def make_subgraph_weight(orig, shape, stype):
+        # Assumes `data` is the first argument of `orig`; every other input is fed ones.
+        arg_shapes, out_shapes, aux_shapes = orig.infer_shape(data=shape)
+        weight_shapes = arg_shapes[1:]
+        weight_names = orig.list_arguments()[1:]
+        weight_dict, weight_sym_dict = create_weights(weight_shapes, weight_names)
+        aux_dict, aux_sym_dict = create_weights(aux_shapes, orig.list_auxiliary_states())
+
+        input_dict = copy.deepcopy(weight_sym_dict)
+        input_dict.update(aux_sym_dict)
+        input_dict['data'] = mx.symbol.Variable('data', stype=stype)
+        input_names = orig.list_inputs()
+        assert all(name in input_dict for name in input_names)
+        subg = make_subgraph(orig, *[input_dict[name] for name in input_names])
+
+        arr = mx.nd.random.uniform(-1, 1, shape=shape, ctx=default_context()).tostype(stype)
+        arg_dict = weight_dict
+        arg_dict['data'] = arr
+        return (orig, subg, arg_dict, aux_dict)
+
+    def make_subgraph2(stype, out_mean_var):
+        data = mx.symbol.Variable('data', stype=stype)
+        orig = mx.symbol.BatchNorm(data, fix_gamma=False,
+                output_mean_var=out_mean_var, name="batchnorm")
+        s = (10, 10)
+        return make_subgraph_weight(orig, s, stype)
+
+    def make_subgraph3(stype):
+        data = mx.symbol.Variable('data', stype=stype)
+        conv1 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True)
+        bn1 = mx.symbol.BatchNorm(conv1, fix_gamma=False, output_mean_var=False)
+        conv2 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True)
+        bn2 = mx.symbol.BatchNorm(conv2, fix_gamma=False, output_mean_var=False)
+        orig = bn1 + bn2
+        s = (1, 3, 32, 32)
+        return make_subgraph_weight(orig, s, stype)
+
+    def make_subgraph4(stype):
+        import os, shutil, tempfile
+        model = get_model('resnet18_v1')
+        model.hybridize()
+        model.initialize()
+        s = (1, 3, 32, 32)
+        model(mx.nd.random.normal(shape=s))  # forward once so export() has a graph to save
+        # Export into a temporary directory so no resnet18-* files are left in the cwd.
+        tmpdir = tempfile.mkdtemp()
+        try:
+            model.export(os.path.join(tmpdir, 'resnet18'))
+            orig = mx.sym.load(os.path.join(tmpdir, 'resnet18-symbol.json'))
+        finally:
+            shutil.rmtree(tmpdir)
+        return make_subgraph_weight(orig, s, stype)
+
+    make_subgraphs = [make_subgraph1,
+            lambda stype: make_subgraph2(stype, False), lambda stype: make_subgraph2(stype, True),
+            make_subgraph3, make_subgraph4]
+    stypes = ['default', 'row_sparse']
+    for make_subg in make_subgraphs:
+        for stype in stypes:
+            orig, subg, inputs, aux_states = make_subg(stype)
+            all_inputs = copy.deepcopy(inputs)
+            all_inputs.update(aux_states)
+            # Bind both graphs to the same inputs, each with its own fresh gradient buffers.
+            e1, e2 = [sym.bind(ctx=default_context(), args=all_inputs, aux_states=all_inputs,
+                               args_grad={k: mx.nd.empty(shape=v.shape) for k, v in all_inputs.items()})
+                      for sym in (orig, subg)]
+            e1.forward(is_train=True)
+            e2.forward(is_train=True)
+            assert len(e1.outputs) == len(e2.outputs)
+            for out1, out2 in zip(e1.outputs, e2.outputs):
+                assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=0.001, atol=0.0001)
+
+            out_grads = [mx.nd.random.uniform(-1, 1, shape=out.shape, ctx=default_context())
+                    for out in e1.outputs]
+            e1.backward(out_grads)
+            e2.backward(out_grads)
+            assert len(e1.grad_arrays) == len(e2.grad_arrays)
+            for grad1, grad2 in zip(e1.grad_arrays, e2.grad_arrays):
+                assert_almost_equal(grad1.asnumpy(), grad2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
+if __name__ == '__main__':
+    from nose import runmodule
+    runmodule()