You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/23 21:45:41 UTC

[GitHub] eric-haibin-lin closed pull request #11641: [MXNET-876] make CachedOp a normal operator

eric-haibin-lin closed pull request #11641: [MXNET-876] make CachedOp a normal operator
URL: https://github.com/apache/incubator-mxnet/pull/11641
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 0c4c1e60208..1f115cd64ad 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -23,6 +23,7 @@
 #include "../executor/exec_pass.h"
 #include "../profiler/profiler.h"
 #include "../operator/operator_common.h"
+#include "../operator/subgraph/common.h"
 
 
 namespace mxnet {
@@ -874,7 +875,6 @@ OpStatePtr CachedOp::Forward(
   return op_state;
 }
 
-
 void CachedOp::DynamicBackward(
     const bool retain_graph,
     const OpStatePtr& op_state,
@@ -1067,34 +1067,153 @@ void CachedOp::Backward(
   Engine::Get()->set_bulk_size(prev_bulk_size);
 }
 
-bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
-                                  const int dev_mask,
-                                  DispatchMode* dispatch_mode,
-                                  std::vector<int> *in_attrs,
-                                  std::vector<int> *out_attrs) {
-  using namespace imperative;
-  nnvm::Graph g(fwd_graph_);
-  const auto& idx = g.indexed_graph();
-  const auto &outputs = idx.outputs();
+/*
+ * This is the operator state of CachedOp when CachedOp is used in the symbol
+ * executor. This is different from the OpState returned by CachedOp::Forward.
+ * The main reason why we need this OpState is that CachedOp and the symbol executor
+ * maintain OpState differently. The symbol executor generates OpState in advance
+ * while CachedOp generates OpState after Forward is called. We need this data
+ * structure to keep the OpState generated by CachedOp::Forward and pass it to
+ * Backward.
+ */
+struct CachedOpActualState {
+  std::shared_ptr<CachedOp> op;
+  OpStatePtr forward_state;
 
-  // Prepare stypes and contexts based on inputs
-  StorageTypeVector storage_type_inputs;
-  storage_type_inputs.reserve(in_attrs->size());
-  for (size_t i = 0; i < in_attrs->size(); ++i) {
-    storage_type_inputs.emplace_back(in_attrs->at(i));
+  explicit CachedOpActualState(std::shared_ptr<CachedOp> op) {
+    this->op = op;
   }
-  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+};
 
-  // Forward graph storage type inference
-  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
-  // Retrieve result and set outputs
-  const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
-  for (size_t i = 0; i < out_attrs->size(); i++) {
-    const auto eid = idx.entry_id(outputs[i]);
-    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
+/*
+ * This is the forward computation when CachedOp is used as an operator in
+ * a symbol executor.
+ */
+void CachedOpForward(const OpStatePtr& state_ptr,
+                     const OpContext& ctx,
+                     const std::vector<NDArray>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<NDArray>& outputs) {
+  CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+  std::vector<NDArray> in_bufs = inputs;
+  std::vector<NDArray> out_bufs = outputs;
+  std::vector<NDArray *> in_ptrs(in_bufs.size());
+  std::vector<NDArray *> out_ptrs(out_bufs.size());
+  for (size_t i = 0; i < in_ptrs.size(); i++)
+    in_ptrs[i] = &in_bufs[i];
+  for (size_t i = 0; i < out_ptrs.size(); i++)
+    out_ptrs[i] = &out_bufs[i];
+
+  // Set is_recording correctly for the imperative executor.
+  bool orig_is_record;
+  if (ctx.need_grad)
+    orig_is_record = Imperative::Get()->set_is_recording(true);
+  else
+    orig_is_record = Imperative::Get()->is_recording();
+  // Set is_training correctly for the imperative executor.
+  bool orig_is_train;
+  if (ctx.is_train)
+    orig_is_train = Imperative::Get()->set_is_training(true);
+  else
+    orig_is_train = Imperative::Get()->is_training();
+  s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs);
+  Imperative::Get()->set_is_training(orig_is_train);
+  Imperative::Get()->set_is_recording(orig_is_record);
+  // The arrays in out_ptrs may be changed by CachedOp.
+  // If they are, we need to copy the data back.
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    if (!out_bufs[i].IsSame(outputs[i]))
+      CopyFromTo(out_bufs[i], outputs[i]);
+}
+
+/*
+ * This is the backward computation when CachedOp is used as an operator in
+ * a symbol executor.
+ */
+void CachedOpBackward(const OpStatePtr& state_ptr,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  using namespace nnvm;
+  using namespace imperative;
+  CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+  std::vector<NDArray> in_bufs = inputs;
+  std::vector<NDArray> out_bufs = outputs;
+  std::vector<NDArray *> in_ptrs;
+  std::vector<NDArray *> out_ptrs;
+  CHECK_EQ(s.op->num_backward_inputs(), inputs.size());
+  in_ptrs.reserve(s.op->num_backward_inputs());
+  out_ptrs.reserve(s.op->num_inputs());
+
+  const std::vector<bool> &save_inputs = s.op->save_inputs();
+  const std::vector<bool> &save_outputs = s.op->save_outputs();
+  size_t bwd_in_dep = s.op->num_inputs();
+  size_t bwd_out_dep = s.op->num_outputs();
+  CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep);
+  size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep;
+
+  // Find inputs, outputs and ograds
+  auto ograds_begin = in_bufs.begin();
+  auto ograds_end = in_bufs.begin() + bwd_ograd_dep;
+  auto in_begin = ograds_end;
+  auto in_end = in_begin + bwd_in_dep;
+  auto out_begin = in_end;
+  auto out_end = in_bufs.end();
+
+  for (auto it = ograds_begin; it != ograds_end; it++)
+    in_ptrs.push_back(&(*it));
+
+  CHECK_EQ(save_inputs.size(), in_end - in_begin);
+  CHECK_EQ(s.op->num_outputs(), out_end - out_begin);
+  for (auto it = in_begin; it != in_end; it++) {
+    auto i = it - in_begin;
+    if (save_inputs[i])
+      in_ptrs.push_back(&(*it));
   }
-  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
-  return true;
+  for (auto it = out_begin; it != out_end; it++) {
+    auto i = it - out_begin;
+    if (save_outputs[i])
+      in_ptrs.push_back(&(*it));
+  }
+  CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs());
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    out_ptrs.push_back(&out_bufs[i]);
+  CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs());
+  // Set is_training correctly for the imperative executor.
+  bool orig_is_train;
+  if (ctx.is_train)
+    orig_is_train = Imperative::Get()->set_is_training(true);
+  else
+    orig_is_train = Imperative::Get()->is_training();
+  // TODO(zhengda) CachedOp supports recording computation when running
+  // the backward path. This is necessary if we want to support second-order
+  // differentiation. However, MXNet operators don't have an interface to
+  // pass a flag to determine whether to record computation inside an operator.
+  // Let's use false here for now and design a solution when second-order
+  // differentiation is supported.
+  s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs);
+  Imperative::Get()->set_is_training(orig_is_train);
+
+  // Clean up what we recorded.
+  s.forward_state.reset();
+
+  // The arrays in out_ptrs may be changed by CachedOp.
+  // If they are, we need to copy the data back.
+  // For example, when the inputs and outputs share the same NDArrays,
+  // the outputs will be replaced by inputs.
+  // https://github.com/apache/incubator-mxnet/blob/v1.2.0/src/imperative/cached_op.cc#L385
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    if (!out_bufs[i].IsSame(outputs[i]))
+      CopyFromTo(out_bufs[i], outputs[i]);
+}
+
+OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
+                               Context ctx,
+                               const std::vector<TShape>& in_shapes,
+                               const std::vector<int>& in_types) {
+  const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+  return OpStatePtr::Create<CachedOpActualState>(op);
 }
 
 bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
@@ -1143,6 +1262,32 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+void CachedOpParamParser(nnvm::NodeAttrs* attrs) {
+  CachedOpConfig param;
+  try {
+    param.Init(attrs->dict);
+  } catch (const dmlc::ParamError& e) {
+    std::ostringstream os;
+    os << e.what();
+    os << ", in operator " << attrs->op->name << "("
+       << "name=\"" << attrs->name << "\"";
+    for (const auto& k : attrs->dict) {
+      os << ", " << k.first << "=\"" << k.second << "\"";
+    }
+    os << ")";
+    throw dmlc::ParamError(os.str());
+  }
+  if (!param.subgraph.empty()) {
+    nnvm::Graph g = nnvm::pass::LoadJSON(param.subgraph);
+    CHECK(!g.outputs.empty());
+    nnvm::Symbol sym;
+    sym.outputs = g.outputs;
+    std::vector<std::pair<std::string, std::string> > flags;
+    for (auto it = attrs->dict.begin(); it != attrs->dict.end(); it++)
+      flags.emplace_back(it->first, it->second);
+    attrs->parsed = CachedOpPtr(new CachedOp(sym, flags));
+  }
+}
 
 NNVM_REGISTER_OP(_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs) {
@@ -1153,19 +1298,62 @@ NNVM_REGISTER_OP(_CachedOp)
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
     return op->num_outputs();
   })
-.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
-                                                     const int dev_mask,
-                                                     DispatchMode* dispatch_mode,
-                                                     std::vector<int> *in_attrs,
-                                                     std::vector<int> *out_attrs) {
-    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
-    return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
-  })
+.set_attr_parser(CachedOpParamParser)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
     return op->Gradient(n, ograds);
-  });
+  })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ListForwardInputNames();
+  })
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ListForwardOutputNames();
+  })
+.set_attr<FCreateOpState>("FCreateOpState", CreateCachedOpState)
+.set_attr<nnvm::FInferShape>("FInferShape",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<TShape> *in_shapes,
+     std::vector<TShape> *out_shapes) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes);
+  })
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<int> *in_types,
+     std::vector<int> *out_types) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types);
+  })
+.set_attr<FInferStorageType>("FInferStorageType",
+  [](const nnvm::NodeAttrs& attrs,
+     const int dev_mask,
+     DispatchMode* dispatch_mode,
+     std::vector<int>* in_stypes,
+     std::vector<int>* out_stypes) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(),
+                                                  dev_mask, dispatch_mode,
+                                                  in_stypes, out_stypes);
+  })
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpForward)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpForward)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym());
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym());
+  })
+.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
+.add_argument("data", "NDArray-or-Symbol[]", "input data list");
 
 NNVM_REGISTER_OP(_backward_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs){
@@ -1184,6 +1372,9 @@ NNVM_REGISTER_OP(_backward_CachedOp)
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
     return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
   })
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpBackward)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpBackward)
+.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<bool>("TIsBackward", true);
 
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 4f4dfdcc14d..59a793ee1b6 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -37,6 +37,7 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
   bool static_shape;
   nnvm::Tuple<uint32_t> data_indices;
   nnvm::Tuple<uint32_t> param_indices;
+  std::string subgraph;
   DMLC_DECLARE_PARAMETER(CachedOpConfig) {
     DMLC_DECLARE_FIELD(static_alloc)
     .set_default(false)
@@ -62,6 +63,9 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
     DMLC_DECLARE_FIELD(param_indices)
     .set_default(nnvm::Tuple<uint32_t>())
     .describe("Position of parameters.");
+    DMLC_DECLARE_FIELD(subgraph)
+    .set_default(std::string(""))
+    .describe("JSON string of a subgraph.");
   }
 };
 
@@ -80,6 +84,10 @@ class CachedOp {
   uint32_t num_backward_inputs() const {
     return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
   }
+  uint32_t num_backward_outputs() const {
+    auto &idx = fwd_graph_.indexed_graph();
+    return idx.input_nodes().size() - idx.mutable_input_nodes().size();
+  }
   std::vector<bool>& save_inputs() {
     return save_inputs_;
   }
@@ -102,13 +110,6 @@ class CachedOp {
       const std::vector<NDArray*>& inputs,
       const std::vector<OpReqType>& reqs,
       const std::vector<NDArray*>& outputs);
-  // forward storage type inference
-  bool ForwardStorageType(
-      const nnvm::NodeAttrs& attrs,
-      const int dev_mask,
-      DispatchMode* dispatch_mode,
-      std::vector<int> *in_attrs,
-      std::vector<int> *out_attrs);
   // backward storage type inference
   bool BackwardStorageType(
       const nnvm::NodeAttrs& attrs,
@@ -116,6 +117,19 @@ class CachedOp {
       DispatchMode* dispatch_mode,
       std::vector<int> *in_attrs,
       std::vector<int> *out_attrs);
+  std::vector<std::string> ListForwardInputNames() const {
+    nnvm::Symbol sym = GetForwardSym();
+    return sym.ListInputNames(nnvm::Symbol::kAll);
+  }
+  std::vector<std::string> ListForwardOutputNames() const {
+    nnvm::Symbol sym = GetForwardSym();
+    return sym.ListOutputNames();
+  }
+  nnvm::Symbol GetForwardSym() const {
+    nnvm::Symbol sym;
+    sym.outputs = fwd_graph_.outputs;
+    return sym;
+  }
 
  private:
   struct GraphInfo;
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 29112939a22..6a4c3d02707 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -221,7 +221,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
  */
 #define SHAPE_ASSIGN_CHECK(shape_array, index, shape)                       \
   {                                                                         \
-    if (!shape_assign(&(shape_array)[index], TShape(shape))) {              \
+    if (!::mxnet::op::shape_assign(&(shape_array)[index], TShape(shape))) { \
       std::ostringstream os;                                                \
       os << "Shape inconsistent, Provided = " << (shape_array)[index] << ','\
          << " inferred shape=" << shape;                                    \
@@ -238,11 +238,11 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
  */
 #define TYPE_ASSIGN_CHECK(type_array, index, type)                          \
   {                                                                         \
-    if (!type_assign(&(type_array)[index], type)) {                         \
+    if (!::mxnet::op::type_assign(&(type_array)[index], type)) {            \
       std::ostringstream os;                                                \
       os << "Type inconsistent, Provided = "                                \
-         << type_string((type_array)[index]) << ','                         \
-         << " inferred type = " << type_string(type);                       \
+         << ::mxnet::op::type_string((type_array)[index]) << ','            \
+         << " inferred type = " << ::mxnet::op::type_string(type);          \
       throw ::mxnet::op::InferTypeError(os.str(), index);                   \
     }                                                                       \
   }
@@ -291,8 +291,8 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
 #define UNIFORM_TYPE_CHECK(type, expected, arg)                         \
   {                                                                     \
     CHECK_EQ(type, expected) << "This layer requires uniform type. "    \
-                             << "Expected '" << type_string(expected)   \
-                             << "' v.s. given '" << type_string(type)   \
+                             << "Expected '" << ::mxnet::op::type_string(expected)   \
+                             << "' v.s. given '" << ::mxnet::op::type_string(type)   \
                              << "' at '" << arg << "'";                 \
   }
 
diff --git a/src/operator/subgraph/common.h b/src/operator/subgraph/common.h
index 22058d556e0..4e1cd66b8b6 100644
--- a/src/operator/subgraph/common.h
+++ b/src/operator/subgraph/common.h
@@ -49,11 +49,10 @@ inline std::vector<std::string> DefaultSubgraphOpListOutputs(const nnvm::NodeAtt
   return sym.ListOutputNames();
 }
 
-inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
-                                   std::vector<TShape> *in_shapes,
-                                   std::vector<TShape> *out_shapes) {
+inline bool DefaultSubgraphOpShapeHelper(const nnvm::Symbol& subgraph_sym,
+                                         std::vector<TShape> *in_shapes,
+                                         std::vector<TShape> *out_shapes) {
   using namespace exec;
-  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -94,10 +93,15 @@ inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
   return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
 }
 
-inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
-                                  std::vector<int> *in_types,
-                                  std::vector<int> *out_types) {
-  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+inline bool DefaultSubgraphOpShape(const nnvm::NodeAttrs& attrs,
+                                   std::vector<TShape> *in_shapes,
+                                   std::vector<TShape> *out_shapes) {
+  return DefaultSubgraphOpShapeHelper(*attrs.subgraphs[0], in_shapes, out_shapes);
+}
+
+inline bool DefaultSubgraphOpTypeHelper(const nnvm::Symbol& subgraph_sym,
+                                        std::vector<int> *in_types,
+                                        std::vector<int> *out_types) {
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -137,12 +141,17 @@ inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
   return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
 }
 
-inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
-                                         const int dev_mask,
-                                         DispatchMode* dispatch_mode,
-                                         std::vector<int>* in_stypes,
-                                         std::vector<int>* out_stypes) {
-  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+inline bool DefaultSubgraphOpType(const nnvm::NodeAttrs& attrs,
+                                  std::vector<int> *in_types,
+                                  std::vector<int> *out_types) {
+  return DefaultSubgraphOpTypeHelper(*attrs.subgraphs[0], in_types, out_types);
+}
+
+inline bool DefaultSubgraphOpStorageTypeHelper(const nnvm::Symbol& subgraph_sym,
+                                               const int dev_mask,
+                                               DispatchMode* dispatch_mode,
+                                               std::vector<int>* in_stypes,
+                                               std::vector<int>* out_stypes) {
   nnvm::Graph g;
   g.outputs = subgraph_sym.outputs;
   const auto& idx_g = g.indexed_graph();
@@ -190,12 +199,21 @@ inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
   return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
 }
 
+inline bool DefaultSubgraphOpStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int>* in_stypes,
+                                         std::vector<int>* out_stypes) {
+  return DefaultSubgraphOpStorageTypeHelper(*attrs.subgraphs[0], dev_mask, dispatch_mode,
+                                            in_stypes, out_stypes);
+}
+
 inline ExecType DefaultSubgraphOpExecType(const nnvm::NodeAttrs& attrs) {
   return ExecType::kSubgraphExec;
 }
 
-inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+inline std::vector<uint32_t> DefaultSubgraphOpMutableInputsHelper(
+  const nnvm::Symbol& subgraph_sym) {
   const std::vector<std::string> input_names = subgraph_sym.ListInputNames(nnvm::Symbol::kAll);
   const std::vector<std::string> immutable_input_names =
     subgraph_sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
@@ -217,8 +235,12 @@ inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttr
   return ret;
 }
 
-inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
-  const nnvm::Symbol& subgraph_sym = *attrs.subgraphs[0];
+inline std::vector<uint32_t> DefaultSubgraphOpMutableInputs(const nnvm::NodeAttrs& attrs) {
+  return DefaultSubgraphOpMutableInputsHelper(*attrs.subgraphs[0]);
+}
+
+inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequestHelper(
+    const nnvm::Symbol& subgraph_sym) {
   static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
   std::set<ResourceRequest::Type> resource_types;
   DFSVisit(subgraph_sym.outputs, [&](const nnvm::NodePtr& node) {
@@ -231,6 +253,10 @@ inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm:
   return std::vector<ResourceRequest>(resource_types.begin(), resource_types.end());
 }
 
+inline std::vector<ResourceRequest> DefaultSubgraphOpResourceRequest(const nnvm::NodeAttrs& attrs) {
+  return DefaultSubgraphOpResourceRequestHelper(*attrs.subgraphs[0]);
+}
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/subgraph/default_subgraph_op.cc b/src/operator/subgraph/default_subgraph_op.cc
deleted file mode 100644
index d5fb7ee2db6..00000000000
--- a/src/operator/subgraph/default_subgraph_op.cc
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*/
-
-#include <mxnet/ndarray.h>
-#include "./common.h"
-#include "../../imperative/imperative_utils.h"
-#include "../../imperative/cached_op.h"
-
-namespace mxnet {
-namespace op {
-
-#define DEBUG_SUBGRAPH 0
-
-class DefaultSubgraphOperator {
- public:
-  explicit DefaultSubgraphOperator(const Symbol& sym) : subgraph_sym_(sym) {
-    subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"},
-                                            {"static_shape", "true"}}));
-  }
-
-  void Forward(const OpContext& ctx,
-               const std::vector<NDArray>& inputs,
-               const std::vector<OpReqType>& req,
-               const std::vector<NDArray>& outputs);
-  void Backward(const OpContext& ctx,
-                const std::vector<NDArray>& inputs,
-                const std::vector<OpReqType>& req,
-                const std::vector<NDArray>& outputs) {
-    LOG(FATAL) << "Not implemented";
-  }
-
- private:
-  nnvm::Symbol subgraph_sym_;
-  CachedOpPtr subgraph_exec_;
-};
-
-void DefaultSubgraphOperator::Forward(const OpContext& ctx,
-                                      const std::vector<NDArray>& inputs,
-                                      const std::vector<OpReqType>& req,
-                                      const std::vector<NDArray>& outputs) {
-  std::vector<NDArray> tmp_inputs = inputs;
-  std::vector<NDArray*> input_ptrs;
-  input_ptrs.reserve(inputs.size());
-  for (auto& nd : tmp_inputs) {
-    input_ptrs.push_back(&nd);
-  }
-  std::vector<NDArray> tmp_outputs = outputs;
-  std::vector<NDArray*> output_ptrs;
-  for (auto& nd : tmp_outputs) {
-    output_ptrs.push_back(&nd);
-  }
-#if DEBUG_SUBGRAPH
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    LOG(INFO) << "inputs[" << i << "].version = " << inputs[i].version();
-  }
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    LOG(INFO) << "outputs[" << i << "].version = " << outputs[i].version();
-  }
-#endif
-  subgraph_exec_->Forward(subgraph_exec_, input_ptrs, output_ptrs);
-}
-
-OpStatePtr CreateDefaultSubgraphOpState(const NodeAttrs& attrs,
-                                        Context ctx,
-                                        const std::vector<TShape>& in_shapes,
-                                        const std::vector<int>& in_types) {
-  return OpStatePtr::Create<DefaultSubgraphOperator>(*attrs.subgraphs[0]);
-}
-
-void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& inputs,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs) {
-  DefaultSubgraphOperator& op = state_ptr.get_state<DefaultSubgraphOperator>();
-  op.Forward(ctx, inputs, req, outputs);
-}
-
-NNVM_REGISTER_OP(_default_subgraph_op)
-.describe(R"code(_default_subgraph_op)code" ADD_FILELINE)
-.set_num_inputs(DefaultSubgraphOpNumInputs)
-.set_num_outputs(DefaultSubgraphOpNumOutputs)
-.set_attr<nnvm::FListInputNames>("FListInputNames", DefaultSubgraphOpListInputs)
-.set_attr<nnvm::FListOutputNames>("FListOutputNames", DefaultSubgraphOpListOutputs)
-.set_attr<FCreateOpState>("FCreateOpState", CreateDefaultSubgraphOpState)
-.set_attr<nnvm::FInferShape>("FInferShape", DefaultSubgraphOpShape)
-.set_attr<nnvm::FInferType>("FInferType", DefaultSubgraphOpType)
-.set_attr<FInferStorageType>("FInferStorageType", DefaultSubgraphOpStorageType)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", DefaultSubgraphOpForward)
-.set_attr<nnvm::FMutateInputs>("FMutateInputs", DefaultSubgraphOpMutableInputs)
-.set_attr<std::string>("key_var_num_args", "num_args")
-.set_attr<FExecType>("FExecType", DefaultSubgraphOpExecType)
-.add_argument("data", "NDArray-or-Symbol[]", "input data list");
-
-}  // namespace op
-}  // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_op.cu b/src/operator/subgraph/default_subgraph_op.cu
deleted file mode 100644
index 008826b21d7..00000000000
--- a/src/operator/subgraph/default_subgraph_op.cu
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Copyright (c) 2018 by Contributors
- * \file default_subgraph_op.cu
- * \brief GPU Implementation of subgraph operations
- */
-
-#include <mxnet/ndarray.h>
-#include "./common.h"
-#include "../../imperative/imperative_utils.h"
-#include "../../imperative/cached_op.h"
-
-namespace mxnet {
-namespace op {
-
-void DefaultSubgraphOpForward(const OpStatePtr& state_ptr,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& inputs,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs);
-
-NNVM_REGISTER_OP(_default_subgraph_op)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", DefaultSubgraphOpForward);
-
-}  // namespace op
-}  // namespace mxnet
diff --git a/src/operator/subgraph/default_subgraph_property.cc b/src/operator/subgraph/default_subgraph_property.cc
index c8d3e9ffd43..0152344f4d4 100644
--- a/src/operator/subgraph/default_subgraph_property.cc
+++ b/src/operator/subgraph/default_subgraph_property.cc
@@ -21,6 +21,7 @@
 #include <string>
 #include "./common.h"
 #include "./subgraph_property.h"
+#include "../../imperative/cached_op.h"
 
 namespace mxnet {
 namespace op {
@@ -51,7 +52,7 @@ class ContainOpSelector: public SubgraphSelector {
 
 /*
  * This subgraph property finds a subgraph whose nodes have only operators
- * within a set. The operators in the subgraph will be executed by _default_subgraph_op.
+ * within a set. The operators in the subgraph will be executed by _CachedOp.
  */
 class DefaultSubgraphProperty: public SubgraphProperty {
  public:
@@ -59,9 +60,13 @@ class DefaultSubgraphProperty: public SubgraphProperty {
   virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
                                            const int subgraph_id = 0) const {
     nnvm::NodePtr n = nnvm::Node::Create();
-    n->attrs.op = Op::Get("_default_subgraph_op");
-    n->attrs.name = "_default_subgraph_op" + std::to_string(subgraph_id);
+    n->attrs.op = Op::Get("_CachedOp");
+    n->attrs.name = "_CachedOp" + std::to_string(subgraph_id);
     n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym));
+
+    std::vector<std::pair<std::string, std::string> > flags{{"static_alloc", "true"}};
+    n->attrs.parsed = CachedOpPtr(new CachedOp(sym, flags));
+
     return n;
   }
   virtual SubgraphSelectorPtr CreateSubgraphSelector() const {
diff --git a/tests/python/unittest/test_subgraph.py b/tests/python/unittest/test_subgraph.py
new file mode 100644
index 00000000000..b5577d4d0ff
--- /dev/null
+++ b/tests/python/unittest/test_subgraph.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+from __future__ import print_function
+import numpy as np
+import mxnet as mx
+import copy
+import math
+import ctypes
+import random
+import itertools
+from numpy.testing import assert_allclose, assert_array_equal
+from mxnet.test_utils import *
+from mxnet.base import py_str, MXNetError, _as_list, SymbolHandle, check_call, _LIB, c_handle_array, mx_uint
+from common import setup_module, with_seed, teardown
+import unittest
+from mxnet.gluon.model_zoo.vision import get_model
+
+def make_subgraph(subg, *args):
+    js = subg.tojson()
+    return mx.sym._internal._CachedOp(*args, subgraph=js)
+
+@with_seed()
+def test_make_subgraph():
+    def make_subgraph1(stype):
+        a = mx.symbol.Variable(name='a', stype=stype)
+        b = mx.symbol.Variable(name='b', stype=stype)
+        c = a * b
+        d = c * 2
+
+        a1 = mx.symbol.Variable(name='a', stype=stype)
+        b1 = mx.symbol.Variable(name='b', stype=stype)
+        y = make_subgraph(c, a1, b1)
+        y = y * 2
+
+        s = (10, 10)
+        a_arr = mx.nd.array(np.random.normal(-0.1, 0.1, size=s),
+                ctx=default_context()).tostype(stype)
+        b_arr = mx.nd.array(np.random.normal(-0.1, 0.1, size=s),
+                ctx=default_context()).tostype(stype)
+        return (d, y, {'a': a_arr, 'b': b_arr}, {})
+
+    def create_weights(shapes, names):
+        nd_dict = {}
+        sym_dict = {}
+        assert len(shapes) == len(names)
+        for i in range(len(shapes)):
+            sym_dict[names[i]] = mx.symbol.Variable(names[i])
+            nd_dict[names[i]] = mx.nd.array(np.ones(shapes[i]), ctx=default_context())
+        return (nd_dict, sym_dict)
+
+    def make_subgraph_weight(orig, shape, stype):
+        arg_shapes, out_shapes, aux_shapes = orig.infer_shape(data=shape)
+        weight_shapes = arg_shapes[1:]
+        weight_names = orig.list_arguments()[1:]
+        weight_dict, weight_sym_dict = create_weights(weight_shapes, weight_names)
+        aux_dict, aux_sym_dict = create_weights(aux_shapes, orig.list_auxiliary_states())
+
+        input_dict = copy.deepcopy(weight_sym_dict)
+        input_dict.update(aux_sym_dict)
+        input_dict['data'] = mx.symbol.Variable('data', stype=stype)
+        input_list = []
+        for name in orig.list_inputs():
+            assert name in input_dict.keys()
+            input_list.append(input_dict[name])
+        subg = make_subgraph(orig, *input_list)
+
+        arr = mx.nd.random.uniform(-1, 1, shape=shape, ctx=default_context()).tostype(stype)
+        arg_dict = weight_dict
+        arg_dict['data'] = arr
+        return (orig, subg, arg_dict, aux_dict)
+
+    def make_subgraph2(stype, out_mean_var):
+        data = mx.symbol.Variable('data', stype=stype)
+        orig = mx.symbol.BatchNorm(data, fix_gamma=False,
+                output_mean_var=out_mean_var, name="batchnorm")
+        s = (10, 10)
+        return make_subgraph_weight(orig, s, stype)
+
+    def make_subgraph3(stype):
+        data = mx.symbol.Variable('data', stype=stype)
+        conv1 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True)
+        bn1 = mx.symbol.BatchNorm(conv1, fix_gamma=False, output_mean_var=False)
+        conv2 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True)
+        bn2 = mx.symbol.BatchNorm(conv2, fix_gamma=False, output_mean_var=False)
+        orig = bn1 + bn2
+        s = (1, 3, 32, 32)
+        return make_subgraph_weight(orig, s, stype)
+
+    def make_subgraph4(stype):
+        model = get_model('resnet18_v1')
+        model.hybridize()
+        model.initialize()
+        s = (1, 3, 32, 32)
+        data = mx.nd.random.normal(shape=s)
+        out = model(data)
+        model.export('resnet18')
+        orig = mx.sym.load('resnet18-symbol.json')
+        return make_subgraph_weight(orig, s, stype)
+
+    make_subgraphs = [make_subgraph1,
+            lambda stype: make_subgraph2(stype, False),
+            lambda stype: make_subgraph2(stype, True),
+            make_subgraph3, make_subgraph4]
+    stypes = ['default', 'row_sparse']
+    for make_subg in make_subgraphs:
+        for stype in stypes:
+            orig, subg, inputs, aux_states = make_subg(stype)
+            all_inputs = copy.deepcopy(inputs)
+            all_inputs.update(aux_states)
+            args_grad = {key : mx.nd.empty(shape=all_inputs[key].shape) for key in all_inputs.keys()}
+            e1 = orig.bind(ctx=default_context(), args=all_inputs, args_grad=args_grad,
+                    aux_states=all_inputs)
+            args_grad = {key : mx.nd.empty(shape=all_inputs[key].shape) for key in all_inputs.keys()}
+            e2 = subg.bind(ctx=default_context(), args=all_inputs, args_grad=args_grad,
+                    aux_states=all_inputs)
+            e1.forward(is_train=True)
+            e2.forward(is_train=True)
+            for i in range(len(e1.outputs)):
+                assert_almost_equal(e1.outputs[i].asnumpy(), e2.outputs[i].asnumpy(),
+                        rtol=0.001, atol=0.0001)
+
+            out_grads = [mx.nd.random.uniform(-1, 1, shape=out.shape, ctx=default_context())
+                    for out in e1.outputs]
+            e1.backward(out_grads)
+            e2.backward(out_grads)
+            for i in range(len(e1.grad_arrays)):
+                assert_almost_equal(e1.grad_arrays[i].asnumpy(), e2.grad_arrays[i].asnumpy(),
+                        rtol=0.001, atol=0.0001)
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services