Posted to commits@mxnet.apache.org by ma...@apache.org on 2018/06/16 01:27:39 UTC

[incubator-mxnet] branch revert-11313-static created (now 24f14f7)

This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a change to branch revert-11313-static
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git.


      at 24f14f7  Revert "Static alloc for hybridblock (#11313)"

This branch includes the following new commits:

     new 24f14f7  Revert "Static alloc for hybridblock (#11313)"

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


-- 
To stop receiving notification emails like this one, please contact
marcoabreu@apache.org.

[incubator-mxnet] 01/01: Revert "Static alloc for hybridblock (#11313)"

Posted by ma...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch revert-11313-static
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 24f14f7fdbef427aab25738608267c57de645987
Author: Marco de Abreu <ma...@users.noreply.github.com>
AuthorDate: Fri Jun 15 18:27:32 2018 -0700

    Revert "Static alloc for hybridblock (#11313)"
    
    This reverts commit 5431e12f11fc5446f6ec2a25098b4e3b67ee7eb3.
---
 include/mxnet/c_api.h                   |   5 +
 include/mxnet/imperative.h              |  89 ++++
 include/mxnet/ndarray.h                 |   8 -
 include/mxnet/op_attr_types.h           |  33 +-
 python/mxnet/_ctypes/ndarray.py         |  16 +-
 python/mxnet/gluon/block.py             |  74 ++--
 src/c_api/c_api_ndarray.cc              |  26 +-
 src/engine/threaded_engine.cc           |   3 +-
 src/executor/attach_op_execs_pass.cc    | 165 ++++---
 src/executor/attach_op_resource_pass.cc |  16 +-
 src/executor/exec_pass.h                |  28 +-
 src/executor/graph_executor.cc          |   2 +-
 src/imperative/cached_op.cc             | 750 ++++++--------------------------
 src/imperative/cached_op.h              | 174 --------
 src/imperative/imperative.cc            |  90 +++-
 src/imperative/imperative_utils.cc      | 120 -----
 src/imperative/imperative_utils.h       | 256 ++---------
 tests/python/unittest/test_gluon.py     |  67 +--
 18 files changed, 523 insertions(+), 1399 deletions(-)
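
For context: the docstring removed from python/mxnet/gluon/block.py below documents the feature being reverted. A minimal sketch of how those flags were passed before this revert (illustrative only, assuming a pre-revert build; after the revert, hybridize() only forwards generic **kwargs to the CachedOp):

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.HybridSequential()
    net.add(nn.Dense(10))
    net.initialize()
    # static_alloc / static_shape are the flags described in the
    # docstring that this revert removes.
    net.hybridize(static_alloc=True, static_shape=True)
    out = net(mx.nd.ones((2, 4)))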

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 4dd858a..55c26bc 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -987,6 +987,11 @@ MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
                                  int num_flags,
                                  const char** keys,
                                  const char** vals,
+                                 int num_inputs,
+                                 const char** input_names,
+                                 int num_params,
+                                 const char** param_names,
+                                 NDArrayHandle* params,
                                  CachedOpHandle *out);
 /*!
  * \brief free cached operator
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 7ea60df..758ce85 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -35,6 +35,23 @@
 #include "./ndarray.h"
 
 namespace mxnet {
+/*! \brief CachedOp Parameters */
+struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
+  uint32_t inline_limit;
+  uint32_t forward_bulk_size;
+  uint32_t backward_bulk_size;
+  DMLC_DECLARE_PARAMETER(CachedOpConfig) {
+    DMLC_DECLARE_FIELD(inline_limit)
+    .set_default(2)
+    .describe("Maximum number of operators that can be inlined.");
+    DMLC_DECLARE_FIELD(forward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during forward pass.");
+    DMLC_DECLARE_FIELD(backward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during backward pass.");
+  }
+};
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
@@ -77,6 +94,67 @@ class Imperative {
              && info.out_grads.size() == 1;
     }
   };
+  class CachedOp {
+   public:
+    CachedOp(
+        const nnvm::Symbol& sym,
+        const std::vector<std::pair<std::string, std::string> >& flags,
+        const std::vector<std::string> arg_names,
+        const std::unordered_map<std::string, std::vector<NDArray> >& params);
+    uint32_t num_inputs() {
+      return fwd_graph_.indexed_graph().input_nodes().size();
+    }
+    uint32_t num_outputs() {
+      return fwd_graph_.outputs.size();
+    }
+    uint32_t num_backward_inputs() {
+      return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
+    }
+    std::vector<bool>& save_inputs() {
+      return save_inputs_;
+    }
+    std::vector<bool>& save_outputs() {
+      return save_outputs_;
+    }
+    const std::unordered_set<uint32_t>& mutable_input_nodes() {
+      return fwd_graph_.indexed_graph().mutable_input_nodes();
+    }
+    nnvm::Graph GetForwardGraph(const bool recording,
+                                const std::vector<NDArray*>& inputs);
+    nnvm::Graph GetBackwardGraph(const OpStatePtr& state,
+                                 const std::vector<OpReqType>& reqs,
+                                 const std::vector<NDArray*>& inputs);
+    std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
+                                          const std::vector<nnvm::NodeEntry>& ograds);
+    void Forward(const std::shared_ptr<CachedOp>& op_ptr,
+                 const std::vector<NDArray*>& args,
+                 const std::vector<NDArray*>& outputs);
+    void Backward(const bool retain_graph,
+                  const OpStatePtr& state,
+                  const std::vector<NDArray*>& inputs,
+                  const std::vector<OpReqType>& reqs,
+                  const std::vector<NDArray*>& outputs);
+
+   private:
+    struct CachedOpState {
+      std::vector<NDArray> buff;
+      std::vector<OpStatePtr> states;
+    };
+    std::mutex mutex_;
+    CachedOpConfig config_;
+    nnvm::Graph fwd_graph_;
+    nnvm::Graph grad_graph_;
+    nnvm::Graph full_graph_;
+    std::unordered_map<Context, std::vector<NDArray> > params_;
+    bool inlining_;
+    std::vector<nnvm::NodeEntry> ograd_entries_;
+    std::vector<bool> curr_grad_req_;
+    std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
+    std::vector<uint32_t> fwd_args_idx_;
+    std::vector<uint32_t> fwd_params_idx_;
+    std::vector<uint32_t> bwd_input_eid_;
+    std::vector<bool> save_inputs_, save_outputs_;
+  };
   /*! \brief whether operator recording is on. */
   bool is_training() const {
     return is_train_;
@@ -144,6 +222,15 @@ class Imperative {
       uint32_t num_inputs, uint32_t num_outputs,
       std::vector<bool> *p_save_inputs,
       std::vector<bool> *p_save_outputs);
+  void RunGraph(
+      const bool retain_graph,
+      const nnvm::IndexedGraph& idx,
+      const std::vector<NDArray*> arrays,
+      size_t node_start, size_t node_end,
+      std::vector<OpReqType>&& array_reqs,
+      std::vector<uint32_t>&& ref_count,
+      std::vector<OpStatePtr> *p_states,
+      const DispatchModeVector& dispatch_modes);
   /*! \brief indicate whether is training. */
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local bool is_train_;
@@ -160,5 +247,7 @@ class Imperative {
   int backward_bulk_size_{0};
 };
 
+using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
+
 }  // namespace mxnet
 #endif  // MXNET_IMPERATIVE_H_
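
Note that the CachedOpConfig defaults above are read with dmlc::GetEnv from MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN (falling back to 15). A minimal sketch of overriding that default from Python, illustrative only; the variable has to be set before the library is loaded:

    import os

    # Set before importing mxnet so the native library sees it.
    os.environ["MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN"] = "30"

    import mxnet as mx
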
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index ae96fd8..e243eb7 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -155,14 +155,6 @@ class NDArray {
     return byte_offset_ > 0 || shape() != ptr_->storage_shape;
   }
 
-  /* \brief Check whether the two arrays are the same array */
-  inline bool IsSame(const NDArray& other) {
-    return ptr_ == other.ptr_ &&
-        shape_ == other.shape_ &&
-        byte_offset_ == other.byte_offset_ &&
-        dtype_ == other.dtype_;
-  }
-
   /*!
    * \return the shape of current NDArray.
    */
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index f4694ef..3969d84 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -126,36 +126,25 @@ class OpStatePtr {
   template<typename T, typename... Args>
   static OpStatePtr Create(Args&&... args) {
     OpStatePtr ret;
-    auto state = new T(std::forward<Args>(args)...);
-    auto var = Engine::Get()->NewVariable();
-    ret.ptr_.reset(
-      new OpState(var, state),
-      [](OpState* p) {
-        Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
-        delete reinterpret_cast<T*>(p->state);
-        delete p;
-      });
+    ret.ptr_ = std::make_shared<OpState>();
+    ret.ptr_->var_ = Engine::Get()->NewVariable();
+    ret.ptr_->state_.construct<T>(std::forward<Args>(args)...);
 
     return ret;
   }
   /* \brief Get engine variable associated with this state */
   engine::VarHandle get_var() const {
-    return ptr_->var;
+    return ptr_->var_;
   }
   /* \brief Get state of type T */
   template<typename T>
   T& get_state() const {
-    return *reinterpret_cast<T*>(ptr_->state);
+    return dmlc::get<T>(ptr_->state_);
   }
   /* \brief clear state */
   void reset() {
     ptr_.reset();
   }
-  /* \brief checks whether the managed object is managed only by the current
-            OpStatePtr instance */
-  bool unique() const {
-    return ptr_.unique();
-  }
   /* \brief Whether state is empty */
   explicit operator bool() const {
     return ptr_ ? true : false;
@@ -164,12 +153,16 @@ class OpStatePtr {
  private:
   /* \brief state structure */
   struct OpState {
-    engine::VarHandle var;
-    void* state;
-
-    OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
+    OpState() {}
     OpState(const OpState& other) = delete;
     OpState& operator=(const OpState& other) = delete;
+
+    ~OpState() {
+      Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_);
+    }
+
+    engine::VarHandle var_;
+    dmlc::any state_;
   };
   /* \brief shared pointer to state */
   std::shared_ptr<OpState> ptr_;
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index f324545..d2cae0c 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -105,14 +105,28 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
 class CachedOp(object):
     """Cached operator handle."""
     __slots__ = ["handle"]
-    def __init__(self, sym, flags=()):
+    def __init__(self, sym, flags=(), inputs=None, params=None):
         self.handle = CachedOpHandle()
+        param_names = []
+        param_arrays = []
+        if inputs is None:
+            assert params is None, "When inputs is None params must also be None."
+            inputs = sym.list_inputs()
+        elif params is not None:
+            for name, arrs in params.items():
+                param_arrays.extend(arrs)
+                param_names.extend([name] * len(arrs))
 
         check_call(_LIB.MXCreateCachedOpEx(
             sym.handle,
             len(flags),
             c_str_array([key for key, _ in flags]),
             c_str_array([str(val) for _, val in flags]),
+            len(inputs),
+            c_str_array(inputs),
+            len(param_names),
+            c_str_array(param_names),
+            c_handle_array(param_arrays),
             ctypes.byref(self.handle)))
 
     def __del__(self):
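
The constructor above matches what _build_cache() in python/mxnet/gluon/block.py passes further down: inputs is the list of data input names and params maps parameter names to per-context lists of NDArrays. A minimal sketch of constructing it directly (illustrative only; HybridBlock normally does this internally):

    import mxnet as mx

    data = mx.sym.Variable("data")
    weight = mx.sym.Variable("weight")
    fc = mx.sym.FullyConnected(data=data, weight=weight, no_bias=True, num_hidden=10)

    w = mx.nd.ones((10, 4))  # parameters are bound at construction time
    op = mx.ndarray.CachedOp(fc, flags=(), inputs=["data"], params={"weight": [w]})
    y = op(mx.nd.ones((2, 4)))  # only the data inputs are passed at call time
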
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 293fafa..3b97c05 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -502,16 +502,8 @@ class Block(object):
         ----------
         active : bool, default True
             Whether to turn hybrid on or off.
-        static_alloc : bool, default False
-            Statically allocate memory to improve speed. Memory usage may increase.
-        static_shape : bool, default False
-            Optimize for invariant input shapes between iterations. Must also
-            set static_alloc to True. Change of input shapes is still allowed
-            but slower.
-        forward_bulk_size : int, default 15
-            Segment size of bulk execution during forward pass.
-        backward_bulk_size : int, default 15
-            Segment size of bulk execution during backward pass.
+        **kwargs : string
+            Additional flags for hybridized operator.
         """
         for cld in self._children.values():
             cld.hybridize(active, **kwargs)
@@ -704,7 +696,7 @@ class HybridBlock(Block):
         self._out_format = None
         self._in_format = None
         self._active = False
-        self._flags = []
+        self._flags = {}
 
     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -731,43 +723,39 @@ class HybridBlock(Block):
         return self._cached_graph
 
     def _build_cache(self, *args):
-        data, out = self._get_graph(*args)
-        data_names = {data.name : i for i, data in enumerate(data)}
-        params = self.collect_params()
-        input_names = out.list_inputs()
+        inputs, out = self._get_graph(*args)
+        input_names = [i.name for i in inputs]
 
+        params = self.collect_params()
         param_names = set(params.keys())
-        expected_names = set(input_names)
+        expected_names = set(out.list_inputs())
         for name in expected_names:
-            assert name in param_names or name in data_names, \
+            assert name in param_names or name in input_names, \
                 "Unknown input to HybridBlock: %s"%name
 
-        used_data_names = [i for i in data_names if i in expected_names]
-        if len(used_data_names) != len(data_names):
-            unused = ', '.join(['%d-th'%i for name, i in data_names.items()
+        used_input_names = [i for i in input_names if i in expected_names]
+        if len(used_input_names) != len(input_names):
+            unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names)
                                 if name not in expected_names])
             warnings.warn("The %s input to HybridBlock is not used by any "
                           "computation. Is this intended?"%unused, stacklevel=4)
 
-        used_param_names = [i for i in param_names if i in expected_names]
+        used_param_names = set(i for i in param_names if i in expected_names)
         if len(used_param_names) != len(param_names):
-            unused = ', '.join(list(param_names - set(used_param_names)))
+            unused = ', '.join(list(param_names - used_param_names))
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
-        data_indices = []
-        param_indices = []
-        self._cached_op_args = []
-        for i, name in enumerate(input_names):
-            if name in data_names:
-                data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
-            else:
-                param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
-        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
-                self._flags
-        self._cached_op = ndarray.CachedOp(out, flags)
+        used_params = {k: params[k] for k in used_param_names}
+        try:
+            param_dict = {k: v.list_data() for k, v in used_params.items()}
+        except DeferredInitializationError:
+            self._deferred_infer_shape(*args)
+            for i in used_params.values():
+                i._finish_deferred_init()
+            param_dict = {k: v.list_data() for k, v in used_params.items()}
+
+        self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict)
 
     def _deferred_infer_shape(self, *args):
         try:
@@ -783,19 +771,7 @@ class HybridBlock(Block):
 
         args, fmt = _flatten(args, "input")
         assert fmt == self._in_format, "Invalid input format"
-        try:
-            cargs = [args[i] if is_arg else i.data()
-                     for is_arg, i in self._cached_op_args]
-        except DeferredInitializationError:
-            self._deferred_infer_shape(*args)
-            cargs = []
-            for is_arg, i in self._cached_op_args:
-                if is_arg:
-                    cargs.append(args[i])
-                else:
-                    i._finish_deferred_init()
-                    cargs.append(i.data())
-        out = self._cached_op(*cargs)
+        out = self._cached_op(*args)
         if isinstance(out, NDArray):
             out = [out]
         return _regroup(out, self._out_format)[0]
@@ -816,7 +792,7 @@ class HybridBlock(Block):
 
     def hybridize(self, active=True, **kwargs):
         self._active = active
-        self._flags = list(kwargs.items())
+        self._flags = kwargs.items()
         self._clear_cached_op()
         if active and self._forward_hooks or self._forward_pre_hooks:
             warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. '
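
With the changes above, hybridize() simply records its kwargs as flags, and the first forward call runs _build_cache(), which resolves the used inputs and parameters (finishing deferred initialization if necessary) and constructs the CachedOp. A minimal usage sketch, illustrative only:

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.HybridSequential()
    net.add(nn.Dense(10))
    net.initialize()
    net.hybridize()              # kwargs are stored as CachedOp flags
    y = net(mx.nd.ones((2, 4)))  # first call builds and invokes the CachedOp
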
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 34bd4b2..9aabe04 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -36,7 +36,6 @@
 #include "../common/utils.h"
 #include "../common/exec_utils.h"
 #include "../imperative/imperative_utils.h"
-#include "../imperative/cached_op.h"
 
 using namespace mxnet;
 
@@ -161,8 +160,12 @@ int MXCreateCachedOp(SymbolHandle handle,
   std::vector<std::string> input_names;
   input_names.reserve(inputs.size());
   for (const auto& i : inputs) input_names.push_back(i->attrs.name);
-  *out = new CachedOpPtr(new CachedOp(
-      *sym, std::vector<std::pair<std::string, std::string> >()));
+  *out = new std::shared_ptr<Imperative::CachedOp>(
+      new Imperative::CachedOp(
+        *sym,
+        std::vector<std::pair<std::string, std::string> >(),
+        input_names,
+        std::unordered_map<std::string, std::vector<NDArray> >()));
   API_END();
 }
 
@@ -170,6 +173,11 @@ int MXCreateCachedOpEx(SymbolHandle handle,
                        int num_flags,
                        const char** keys,
                        const char** vals,
+                       int num_args,
+                       const char** arg_names,
+                       int num_params,
+                       const char** param_names,
+                       NDArrayHandle* params,
                        CachedOpHandle *out) {
   nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
 
@@ -178,7 +186,17 @@ int MXCreateCachedOpEx(SymbolHandle handle,
   for (int i = 0; i < num_flags; ++i) {
     flags.push_back({keys[i], vals[i]});
   }
-  *out = new CachedOpPtr(new CachedOp(*sym, flags));
+  std::vector<std::string> args;
+  for (int i = 0; i < num_args; ++i) {
+    args.push_back(arg_names[i]);
+  }
+  std::unordered_map<std::string, std::vector<NDArray> > param_dict;
+  for (int i = 0; i < num_params; ++i) {
+    param_dict[param_names[i]].emplace_back(
+        *reinterpret_cast<NDArray*>(params[i]));
+  }
+  *out = new std::shared_ptr<Imperative::CachedOp>(
+      new Imperative::CachedOp(*sym, flags, args, param_dict));
   API_END();
 }
 
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index e70cc19..dc0436e 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -278,8 +278,6 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
 }
 
 void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) {
-  BulkFlush();
-
   ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
   OprBlock* opr_block = OprBlock::New();
   opr_block->opr = threaded_opr;
@@ -325,6 +323,7 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
         << device_count_;
   }
 #endif
+  BulkFlush();
   ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
   opr->temporary = true;
   const bool profiling = profiler_->IsProfiling(profiler::Profiler::kImperative);
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 72919d9..697e486 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -134,10 +134,6 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
     return state_.get_var();
   }
 
-  OpStatePtr state() const override {
-    return state_;
-  }
-
   explicit StatefulComputeExecutor(const OpStatePtr& state,
                                    const FStatefulCompute& fcompute,
                                    ExecType exec_type,
@@ -146,6 +142,7 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
         state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
 
  private:
+  friend Graph AttachOpExecs(Graph g);
   OpStatePtr state_;
   FStatefulCompute fcompute_;
   ExecType exec_type_;
@@ -173,16 +170,13 @@ class StatefulComputeExExecutor : public OpExecutor {
     return state_.get_var();
   }
 
-  OpStatePtr state() const override {
-    return state_;
-  }
-
   explicit StatefulComputeExExecutor(const OpStatePtr& state,
                                      const FStatefulComputeEx& fcompute,
                                      ExecType exec_type)
       : state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
 
  private:
+  friend Graph AttachOpExecs(Graph g);
   OpStatePtr state_;
   FStatefulComputeEx fcompute_;
   ExecType exec_type_;
@@ -247,15 +241,16 @@ class FComputeExExecutor : public OpExecutor {
   ExecType exec_type_;
 };
 
-void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
+// pass to attach operator executors
+Graph AttachOpExecs(Graph g) {
   using nnvm::DTypeVector;
   using nnvm::ShapeVector;
   using nnvm::FMutateInputs;
 
-  static auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
-  static auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
-  static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
-  static auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
+  auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+  auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
+  auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
+  auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
 
   const auto& vdtype = g.GetAttr<DTypeVector>("dtype");
   const auto& vshape = g.GetAttr<ShapeVector>("shape");
@@ -264,88 +259,82 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
 
   // get the graph
   const auto& idx = g.indexed_graph();
-  OpExecVector& ret = *p_ret;
+  std::vector<std::shared_ptr<OpExecutor> > ret(idx.num_nodes());
 
   // initialize the nodes
-  const auto& inode = idx[i];
-  if (inode.source->is_variable()) return;
-  const nnvm::Op *op = inode.source->op();
-  ExecType exec_type = ExecType::kSync;
-  std::vector<uint32_t> mutate_index;
-  if (fmutate_inputs.count(op)) {
-    mutate_index = fmutate_inputs[op](inode.source->attrs);
-  }
-  if (fexec_type.count(op)) {
-    exec_type = fexec_type[op](inode.source->attrs);
-  }
-  CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
-  if (fcreate_op_state.count(op)) {
-    std::vector<TShape> ishape;
-    std::vector<int> itype;
-    for (const auto& e : inode.inputs) {
-      ishape.emplace_back(vshape[idx.entry_id(e)]);
-      itype.emplace_back(vdtype[idx.entry_id(e)]);
-    }
-
-    OpStatePtr state = fcreate_op_state[op](
-        inode.source->attrs, vctx[i], ishape, itype);
-    FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
-        op, "FStatefulComputeEx", vctx[i]);
-    // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
-    if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
-      ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
-    } else {
-      FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
-          op, "FStatefulCompute", vctx[i]);
-      CHECK(fcompute != nullptr)
-          << "One of FStatefulCompute and FStatefulComputeEx must be registered "
-          << "for stateful operator " << op->name;
-      ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
-                                                         exec_type, mutate_index);
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    const auto& inode = idx[i];
+    if (inode.source->is_variable()) continue;
+    const nnvm::Op *op = inode.source->op();
+    ExecType exec_type = ExecType::kSync;
+    std::vector<uint32_t> mutate_index;
+    if (fmutate_inputs.count(op)) {
+      mutate_index = fmutate_inputs[op](inode.source->attrs);
     }
-  } else if (is_layer_backward.get(op, false)) {
-    CHECK_GE(inode.control_deps.size(), 1);
-    uint32_t fwd_id = inode.control_deps[0];
-    CHECK(vctx[fwd_id] == vctx[i]);
-    CHECK(ret[fwd_id] != nullptr);
-    FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
-        op, "FStatefulComputeEx", vctx[i]);
-    // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
-    if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
-      ret[i] = std::make_shared<StatefulComputeExExecutor>(
-          ret[fwd_id].get()->state(), fcompute_ex, exec_type);
-    } else {
-      FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
-          op, "FStatefulCompute", vctx[i]);
-      CHECK(fcompute != nullptr)
-          << "One of FStatefulCompute and FStatefulComputeEx must be registered "
-          << "for stateful operator " << op->name;
-      ret[i] = std::make_shared<StatefulComputeExecutor>(
-          ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index);
+    if (fexec_type.count(op)) {
+      exec_type = fexec_type[op](inode.source->attrs);
     }
-  } else {
-    FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
-    FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
-    if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
-      ret[i] = std::make_shared<FComputeExExecutor>(
-          inode.source->attrs, fcomp_ex, exec_type);
-    } else if (fcompute != nullptr) {
-      ret[i] = std::make_shared<FComputeExecutor>(
-          inode.source->attrs, fcompute, exec_type, mutate_index);
+    CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
+    if (fcreate_op_state.count(op)) {
+      std::vector<TShape> ishape;
+      std::vector<int> itype;
+      for (const auto& e : inode.inputs) {
+        ishape.emplace_back(vshape[idx.entry_id(e)]);
+        itype.emplace_back(vdtype[idx.entry_id(e)]);
+      }
+
+      OpStatePtr state = fcreate_op_state[op](
+          inode.source->attrs, vctx[i], ishape, itype);
+      FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
+          op, "FStatefulComputeEx", vctx[i]);
+      // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
+      if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+        ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
+      } else {
+        FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
+            op, "FStatefulCompute", vctx[i]);
+        CHECK(fcompute != nullptr)
+            << "One of FStatefulCompute and FStatefulComputeEx must be registered "
+            << "for stateful operator " << op->name;
+        ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
+                                                           exec_type, mutate_index);
+      }
+    } else if (is_layer_backward.get(op, false)) {
+      CHECK_GE(inode.control_deps.size(), 1);
+      uint32_t fwd_id = inode.control_deps[0];
+      CHECK(vctx[fwd_id] == vctx[i]);
+      CHECK(ret[fwd_id] != nullptr);
+      FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
+          op, "FStatefulComputeEx", vctx[i]);
+      // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
+      if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+        ret[i] = std::make_shared<StatefulComputeExExecutor>(
+            dynamic_cast<StatefulComputeExExecutor*>(ret[fwd_id].get())->state_,
+            fcompute_ex, exec_type);
+      } else {
+        FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
+            op, "FStatefulCompute", vctx[i]);
+        CHECK(fcompute != nullptr)
+            << "One of FStatefulCompute and FStatefulComputeEx must be registered "
+            << "for stateful operator " << op->name;
+        ret[i] = std::make_shared<StatefulComputeExecutor>(
+            dynamic_cast<StatefulComputeExecutor*>(ret[fwd_id].get())->state_,
+            fcompute, exec_type, mutate_index);
+      }
     } else {
-      LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
+      FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
+      FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
+      if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+        ret[i] = std::make_shared<FComputeExExecutor>(
+            inode.source->attrs, fcomp_ex, exec_type);
+      } else if (fcompute != nullptr) {
+        ret[i] = std::make_shared<FComputeExecutor>(
+            inode.source->attrs, fcompute, exec_type, mutate_index);
+      } else {
+        LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
+      }
     }
   }
-}
-
-
-// pass to attach operator executors
-Graph AttachOpExecs(Graph g) {
-  const auto& idx = g.indexed_graph();
-  OpExecVector ret(idx.num_nodes());
-  for (size_t i = 0; i < idx.num_nodes(); ++i) {
-    CreateOpExecs(g, &ret, i);
-  }
   g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
   return g;
 }
diff --git a/src/executor/attach_op_resource_pass.cc b/src/executor/attach_op_resource_pass.cc
index 56122cd..6818662 100644
--- a/src/executor/attach_op_resource_pass.cc
+++ b/src/executor/attach_op_resource_pass.cc
@@ -30,15 +30,12 @@
 namespace mxnet {
 namespace exec {
 
-void AttachOpResources(
-    const Graph& g,
-    const OpExecVector& op_execs,
-    size_t start_nid,
-    size_t end_nid) {
+Graph AttachOpResources(Graph g) {
   static auto& fresource =
       nnvm::Op::GetAttr<FResourceRequest>("FResourceRequest");
   static auto& fresource_ex =
       nnvm::Op::GetAttr<FResourceRequestEx>("FResourceRequestEx");
+  auto& op_execs = nnvm::get<OpExecVector>(*g.attrs.at("op_execs"));
   const auto& vctx = g.GetAttr<ContextVector>("context");
   const auto& vdispatch = g.GetAttr<DispatchModeVector>("dispatch_mode");
   const auto& dev_masks = g.GetAttr<DevMaskVector>("dev_mask");
@@ -46,7 +43,7 @@ void AttachOpResources(
   // Use global resource pool for each executor for now.
   std::map<Context, Resource> cached_temp;
   // Resource allocation
-  for (uint32_t nid = start_nid; nid < end_nid; ++nid) {
+  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
     const auto& inode = idx[nid];
     if (inode.source->is_variable()) continue;
     const Context &ctx = vctx[nid];
@@ -87,12 +84,7 @@ void AttachOpResources(
       requested.push_back(ResourceManager::Get()->Request(ctx, ResourceRequest::kTempSpace));
     }
   }
+  return g;
 }
-
-void AttachOpResources(const Graph& g) {
-  const auto& op_execs = g.GetAttr<OpExecVector>("op_execs");
-  AttachOpResources(g, op_execs, 0, g.indexed_graph().num_nodes());
-}
-
 }  // namespace exec
 }  // namespace mxnet
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 26a2491..99b1b16 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -82,10 +82,6 @@ class OpExecutor {
   virtual engine::VarHandle var() const {
     return nullptr;
   }
-  /*! \return return operator state */
-  virtual OpStatePtr state() const {
-    return OpStatePtr();
-  }
 };
 
 /*!
@@ -107,14 +103,6 @@ using ContextVector = std::vector<Context>;
 using DevMaskVector = std::vector<int>;
 
 /*!
- * \brief create OpExecutor for a node in graph
- *
- * \param g input graph
- * \param p_ret OpExecVector for input and output
- * \param i the id of the node
- */
-void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
-/*!
  * \brief Attach OpExecutor to the graph attributes.
  *
  * \param g input graph
@@ -127,20 +115,12 @@ Graph AttachOpExecs(Graph g);
  * \brief Attach Resource to the OpExecVector of the graph.
  *
  * \param g input graph need to contain op_exec attribute.
- */
-void AttachOpResources(const Graph& g);
-/*!
- * \brief Attach Resource to the OpExecVector
  *
- * \param g input graph
- * \param op_execs OpExecutor vector
- * \param start_nid starting node id
- * \param end_nid end node id
+ * \return graph with new attribute "op_exec" of type OpExecVector
+ *  The fields on the OpExecVector are not yet been setup.
  */
-void AttachOpResources(const Graph& g,
-                       const OpExecVector& op_execs,
-                       size_t start_nid,
-                       size_t end_nid);
+Graph AttachOpResources(Graph g);
+
 /*!
  * \brief Discover chance of inplace addto operators.
  *  i.e. z = plus(z, source_op), and encourage it to become z += source_op.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 831b5f9..e28867d 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -912,7 +912,7 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
   }
 
   g = AttachOpExecs(g);
-  AttachOpResources(g);
+  g = AttachOpResources(g);
   graph_ = std::move(g);
 
   if (shared_exec != nullptr) {
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index b17fae4..140b5a5 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -19,78 +19,16 @@
 #include <unordered_set>
 #include <iostream>
 #include "./imperative_utils.h"
-#include "./cached_op.h"
-#include "../executor/exec_pass.h"
-#include "../profiler/profiler.h"
-
 
 namespace mxnet {
 
 DMLC_REGISTER_PARAMETER(CachedOpConfig);
 
-struct CachedOp::GraphInfo {
-  nnvm::Graph fwd_graph;
-  nnvm::Graph full_graph;
-  std::vector<OpReqType> bwd_output_reqs;
-  std::vector<uint32_t> bwd_input_eid;
-};
-
-struct CachedOp::DynamicRuntime {
-  GraphInfo info;
-  std::vector<NDArray> buff;
-  std::vector<OpStatePtr> op_states;
-};
-
-struct CachedOp::CachedOpState {
-  CachedOpState(const Context& context_,
-                const nnvm::Graph& fwd_graph_,
-                const nnvm::Graph& full_graph_) {
-    context = context_;
-    info.fwd_graph = fwd_graph_;
-    info.full_graph = full_graph_;
-
-    size_t max_nodes = info.full_graph.indexed_graph().num_nodes();
-    size_t max_entries = info.full_graph.indexed_graph().num_node_entries();
-    info.fwd_graph.attrs["context"] = std::make_shared<dmlc::any>(
-        std::vector<Context>(info.fwd_graph.indexed_graph().num_nodes(), context));
-    info.full_graph.attrs["context"] = std::make_shared<dmlc::any>(
-        std::vector<Context>(max_nodes, context));
-
-    buff.resize(max_entries);
-    arrays.resize(max_entries);
-    array_reqs.resize(max_entries);
-    dynamic_entries.resize(max_entries, false);
-    op_states.resize(max_nodes);
-    execs.resize(max_nodes);
-    opr_segs.resize(max_nodes);
-  }
-
-  std::mutex mutex;
-  Context context;
-  GraphInfo info;
-
-  bool recording = false;
-  bool fwd_alloc = false;
-  bool bwd_alloc = false;
-  bool fwd_exec_init = false;
-  bool bwd_exec_init = false;
-
-  std::vector<NDArray> buff;
-  std::vector<NDArray*> arrays;
-  std::vector<OpReqType> array_reqs;
-
-  std::vector<OpStatePtr> op_states;
-  std::vector<std::shared_ptr<exec::OpExecutor> > execs;
-  std::vector<imperative::EngineOprSeg> opr_segs;
-
-  std::vector<bool> dynamic_entries;
-  std::multimap<size_t, NDArray> fwd_reuse_pool;
-  std::multimap<size_t, NDArray> bwd_reuse_pool;
-};
-
-CachedOp::CachedOp(
+Imperative::CachedOp::CachedOp(
     const nnvm::Symbol& sym,
-    const std::vector<std::pair<std::string, std::string> >& flags) {
+    const std::vector<std::pair<std::string, std::string> >& flags,
+    const std::vector<std::string> arg_names,
+    const std::unordered_map<std::string, std::vector<NDArray> >& params) {
   using namespace nnvm;
   using namespace imperative;
   static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
@@ -130,22 +68,34 @@ CachedOp::CachedOp(
     fwd_graph_.attrs["forward_ref_count"] =
         std::make_shared<dmlc::any>(std::move(ref_count));
 
-    inlining_ = !config_.static_alloc &&
-        (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit;
+    inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit;
   }
 
   // Set params
   {
     const auto& idx = fwd_graph_.indexed_graph();
-    if (config_.data_indices.ndim() || config_.param_indices.ndim()) {
-      CHECK_EQ(config_.data_indices.ndim() + config_.param_indices.ndim(),
-               idx.input_nodes().size());
-    } else {
-      std::vector<uint32_t> tmp;
-      for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
-        tmp.push_back(i);
+    std::unordered_map<std::string, size_t> arg_name_to_id;
+    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+      const auto& name = idx[idx.input_nodes()[i]].source->attrs.name;
+      auto iter = params.find(name);
+      if (iter == params.end()) {
+        arg_name_to_id[name] = i;
+        continue;
+      }
+      fwd_params_idx_.push_back(i);
+      for (const auto& param : iter->second) {
+        params_[param.ctx()].emplace_back(param);
       }
-      config_.data_indices.assign(tmp.begin(), tmp.end());
+    }
+
+    CHECK_EQ(arg_name_to_id.size(), arg_names.size())
+        << "CachedOp expects " << arg_name_to_id.size()
+        << " inputs, given " << arg_names.size();
+
+    for (const auto& name : arg_names) {
+      auto iter = arg_name_to_id.find(name);
+      CHECK(iter != arg_name_to_id.end()) << "Unexpected input name " << name;
+      fwd_args_idx_.push_back(iter->second);
     }
   }
 
@@ -157,14 +107,9 @@ CachedOp::CachedOp(
     }
 
     std::vector<NodeEntry> xs;
-    const auto& idx = fwd_graph_.indexed_graph();
-    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
-      auto nid = idx.input_nodes()[i];
-      if (idx.mutable_input_nodes().count(nid)) continue;
-      fwd_input_to_grad_output_[i] = xs.size();
-      xs.emplace_back(NodeEntry{idx[nid].weak_ref.lock(), 0, 0});
-    }
-
+    std::vector<NodePtr> args = sym.ListInputs(Symbol::kReadOnlyArgs);
+    xs.reserve(args.size());
+    for (const auto& i : args) xs.emplace_back(NodeEntry{i, 0, 0});
     CHECK_GT(xs.size(), 0)
         << "There are no inputs in computation graph that require gradients.";
 
@@ -180,7 +125,7 @@ CachedOp::CachedOp(
     size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();
 
     full_graph_.outputs = fwd_graph_.outputs;
-    bwd_output_reqs_ = std::vector<OpReqType>(grad_graph_.outputs.size(), kWriteTo);
+    curr_grad_req_ = std::vector<bool>(grad_graph_.outputs.size(), true);
     for (const auto& i : grad_graph_.outputs) full_graph_.outputs.emplace_back(i);
     const auto& idx = full_graph_.indexed_graph();
 
@@ -224,10 +169,7 @@ CachedOp::CachedOp(
   }
 }
 
-CachedOp::~CachedOp() {
-}
-
-std::vector<nnvm::NodeEntry> CachedOp::Gradient(
+std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
     const nnvm::NodePtr& node,
     const std::vector<nnvm::NodeEntry>& ograds) {
   using namespace nnvm;
@@ -264,15 +206,13 @@ std::vector<nnvm::NodeEntry> CachedOp::Gradient(
   return ret;
 }
 
-
-bool CachedOp::SetForwardGraph(
-    GraphInfo* info,
-    const bool recording,
-    const std::vector<NDArray*>& inputs) {
+nnvm::Graph Imperative::CachedOp::GetForwardGraph(
+    const bool recording, const std::vector<NDArray*>& inputs) {
   using namespace nnvm;
   using namespace imperative;
+  std::lock_guard<std::mutex> lock(mutex_);
   CHECK_EQ(inputs.size(), num_inputs());
-  nnvm::Graph& g = info->fwd_graph;
+  nnvm::Graph& g = fwd_graph_;
 
   ShapeVector shape_inputs;
   DTypeVector dtype_inputs;
@@ -297,22 +237,18 @@ bool CachedOp::SetForwardGraph(
     g.attrs.erase("forward_mem_plan");
     g.attrs.erase("full_mem_plan");
   } else if (g.attrs.count(recording ? "full_mem_plan" : "forward_mem_plan")) {
-    return true;
+    return g;
   }
 
   const auto& idx = g.indexed_graph();
 
   StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
+  for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
   const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
   CHECK_EQ(stypes.size(), storage.size());
   for (size_t i = 0; i < stypes.size(); i++) {
-    if (stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID;
-  }
-  for (const auto i : idx.input_nodes()) {
-    storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
-  }
-  for (size_t i = 0; i < idx.outputs().size(); ++i) {
-    storage[idx.entry_id(idx.outputs()[i])] = exec::kExternalStorageID;
+    if (stypes[i] != kDefaultStorage)
+      storage[i] = exec::kDynamicStorageID;
   }
 
   auto mem_plan = PlanMemory(
@@ -321,50 +257,51 @@ bool CachedOp::SetForwardGraph(
   g.attrs[recording ? "full_mem_plan" : "forward_mem_plan"] =
       std::make_shared<dmlc::any>(std::move(mem_plan));
 
-  return false;
+  return g;
 }
 
-bool CachedOp::SetBackwardGraph(
-    GraphInfo* info,
+nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
+    const OpStatePtr& op_state,
     const std::vector<OpReqType>& reqs,
-    const std::vector<NDArray*>& inputs,
-    bool detect_inplace_addto) {
+    const std::vector<NDArray*>& inputs) {
   using namespace nnvm;
   using namespace imperative;
   std::lock_guard<std::mutex> lock(mutex_);
-  Context default_ctx = inputs[0]->ctx();
-  nnvm::Graph& g = info->full_graph;
-
-  if (info->bwd_output_reqs != reqs) {
-    info->bwd_output_reqs = reqs;
-    info->bwd_input_eid.clear();
+  nnvm::Graph& g = full_graph_;
+  auto& state = op_state.get_state<CachedOpState>();
+  bool req_match = true;
+  for (size_t i = 0; i < reqs.size(); ++i) {
+    if (curr_grad_req_[i] != (reqs[i] != kNullOp)) {
+      curr_grad_req_[i] = reqs[i] != kNullOp;
+      req_match = false;
+    }
+  }
+  if (!req_match) {
     g = nnvm::Graph();
     g.outputs = fwd_graph_.outputs;
     for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
-      if (info->bwd_output_reqs[i] == kNullOp) continue;
-      g.outputs.emplace_back(grad_graph_.outputs[i]);
+      if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]);
     }
-    g.attrs["context"] = std::make_shared<dmlc::any>(
-        std::vector<Context>(g.indexed_graph().num_nodes(), default_ctx));
+    bwd_input_eid_.clear();
   }
 
   const auto& idx = g.indexed_graph();
 
-  if (info->bwd_input_eid.size() != inputs.size()) {
-    info->bwd_input_eid.clear();
+  if (bwd_input_eid_.size() != inputs.size()) {
+    bwd_input_eid_.clear();
     for (const auto& i : bwd_ograd_dep_) {
       auto eid = idx.entry_id(ograd_entries_[i]);
-      info->bwd_input_eid.push_back(eid);
+      bwd_input_eid_.push_back(eid);
     }
     for (const auto& i : bwd_in_dep_) {
       auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      info->bwd_input_eid.push_back(eid);
+      bwd_input_eid_.push_back(eid);
     }
     for (const auto& i : bwd_out_dep_) {
       auto eid = idx.entry_id(idx.outputs()[i]);
-      info->bwd_input_eid.push_back(eid);
+      bwd_input_eid_.push_back(eid);
     }
-    CHECK_EQ(inputs.size(), info->bwd_input_eid.size());
+    CHECK_EQ(inputs.size(), bwd_input_eid_.size());
   }
 
   size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
@@ -375,22 +312,25 @@ bool CachedOp::SetBackwardGraph(
     for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
       for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)];
     }
-    for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[info->bwd_input_eid[i]];
+    for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[bwd_input_eid_[i]];
     for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)];
     g.attrs["backward_ref_count"] = std::make_shared<dmlc::any>(std::move(ref_count));
   }
 
-  auto shapes = info->fwd_graph.GetAttr<ShapeVector>("shape");
-  shapes.resize(idx.num_node_entries(), TShape());
-  auto dtypes = info->fwd_graph.GetAttr<DTypeVector>("dtype");
-  dtypes.resize(idx.num_node_entries(), -1);
-  auto stypes = info->fwd_graph.GetAttr<StorageTypeVector>("storage_type");
-  stypes.resize(idx.num_node_entries(), -1);
+  ShapeVector shapes(idx.num_node_entries(), TShape());
+  DTypeVector dtypes(idx.num_node_entries(), -1);
+  StorageTypeVector stypes(idx.num_node_entries(), -1);
+
+  for (size_t i = 0; i < num_forward_entries; ++i) {
+    shapes[i] = state.buff[i].shape();
+    dtypes[i] = state.buff[i].dtype();
+    stypes[i] = state.buff[i].storage_type();
+  }
 
   for (size_t i = 0; i < inputs.size(); ++i) {
-    shapes[info->bwd_input_eid[i]] = inputs[i]->shape();
-    dtypes[info->bwd_input_eid[i]] = inputs[i]->dtype();
-    stypes[info->bwd_input_eid[i]] = inputs[i]->storage_type();
+    shapes[bwd_input_eid_[i]] = inputs[i]->shape();
+    dtypes[bwd_input_eid_[i]] = inputs[i]->dtype();
+    stypes[bwd_input_eid_[i]] = inputs[i]->storage_type();
   }
 
   std::pair<uint32_t, uint32_t> node_range, entry_range;
@@ -402,353 +342,79 @@ bool CachedOp::SetBackwardGraph(
                               node_range, entry_range);
   match &= CheckAndInferType(&g, std::move(dtypes), false,
                              node_range, entry_range);
-  exec::DevMaskVector dev_mask(idx.num_nodes(), default_ctx.dev_mask());
+  exec::DevMaskVector dev_mask(idx.num_nodes(), inputs[0]->ctx().dev_mask());
   match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(stypes),
                                     false, node_range, entry_range);
 
   if (!match) {
     g.attrs.erase("backward_mem_plan");
   } else if (g.attrs.count("backward_mem_plan")) {
-    return true;
+    return g;
   }
 
   StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
-  const auto& bwd_stypes = g.GetAttr<StorageTypeVector>("storage_type");
-  for (size_t i = 0; i < bwd_stypes.size(); i++) {
-    if (bwd_stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID;
-  }
   for (size_t i = 0; i < num_forward_entries; ++i) storage[i] = exec::kExternalStorageID;
   for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
   for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID;
+  for (size_t i = 0; i < stypes.size(); i++) {
+    if (stypes[i] != kDefaultStorage)
+      storage[i] = exec::kDynamicStorageID;
+  }
 
   auto mem_plan = PlanMemory(
       &g, std::move(storage), g.GetAttr<std::vector<uint32_t> >("backward_ref_count"),
-      {num_forward_nodes, idx.num_nodes()},
-      {num_forward_entries, idx.num_node_entries()},
-      detect_inplace_addto);
+      {num_forward_nodes, idx.num_nodes()}, {num_forward_entries, idx.num_node_entries()});
   g.attrs["backward_mem_plan"] = std::make_shared<dmlc::any>(std::move(mem_plan));
 
-  return false;
-}
-
-OpStatePtr CachedOp::GetCachedOpState(
-    const Context& ctx) {
-  std::lock_guard<std::mutex> lock(mutex_);
-  for (const auto& i : cached_op_states_[ctx]) {
-    // only create one state per device when not using static memory
-    if (!config_.static_alloc || i.unique()) {
-      return i;
-    }
-  }
-  auto state_ptr = OpStatePtr::Create<CachedOpState>(ctx, fwd_graph_, full_graph_);
-
-  cached_op_states_[ctx].push_back(state_ptr);
-  return state_ptr;
-}
-
-void CachedOp::StaticAllocMemory(
-    const OpStatePtr& state_ptr,
-    bool recording,
-    bool keep_fwd) {
-  using namespace nnvm;
-  using namespace imperative;
-
-  auto& state = state_ptr.get_state<CachedOpState>();
-  const auto& default_ctx = state.context;
-  nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
-  const auto& idx = g.indexed_graph();
-  const auto& vstorage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
-  const auto& mem_plan = g.GetAttr<MemoryPlanVector>(
-      keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : "forward_mem_plan"));
-  std::vector<int> addto_entry;
-  if (g.attrs.count("addto_entry")) {
-    addto_entry = g.GetAttr<std::vector<int> >("addto_entry");
-  }
-  size_t start_eid =
-      keep_fwd ? state.info.fwd_graph.indexed_graph().num_node_entries() : 0;
-  size_t end_eid = idx.num_node_entries();
-
-  if (!keep_fwd) state.fwd_alloc = false;
-  state.bwd_alloc = false;
-  for (size_t i = start_eid; i < state.buff.size(); ++i) {
-    state.buff[i] = NDArray();
-    state.arrays[i] = &state.buff[i];
-    state.array_reqs[i] = kNullOp;
-    state.dynamic_entries[i] = false;
-  }
-
-  for (auto i : idx.input_nodes()) {
-    auto eid = idx.entry_id(i, 0);
-    if (eid >= start_eid) state.dynamic_entries[eid] = true;
-  }
-  for (auto i : idx.outputs()) {
-    auto eid = idx.entry_id(i);
-    if (eid >= start_eid) state.dynamic_entries[eid] = true;
-  }
-
-  for (size_t i = start_eid; i < end_eid; ++i) {
-    if (addto_entry.size() && addto_entry[i]) {
-      state.array_reqs[i] = kAddTo;
-    } else if (vstorage_inplace[i] >= 0) {
-      state.array_reqs[i] = kWriteInplace;
-    } else if (vstorage_inplace[i] == -2) {
-      // -2 indicate that the entry is never referenced.
-      state.array_reqs[i] = kNullOp;
-    } else {
-      state.array_reqs[i] = kWriteTo;
-    }
-  }
-
-  auto& reuse_pool = keep_fwd ? state.bwd_reuse_pool : state.fwd_reuse_pool;
-  reuse_pool = imperative::AllocateMemory(
-      g, idx, default_ctx, start_eid, end_eid, mem_plan,
-      state.arrays, &state.array_reqs, std::move(reuse_pool));
-
-  state.recording = recording;
-  if (keep_fwd) {
-    state.bwd_alloc = true;
-  } else {
-    state.fwd_alloc = true;
-  }
+  return g;
 }
 
-void CachedOp::StaticInitExec(
-    const OpStatePtr& state_ptr,
-    bool recording,
-    bool keep_fwd) {
+void Imperative::CachedOp::Forward(
+    const std::shared_ptr<CachedOp>& op_ptr,
+    const std::vector<NDArray*>& args,
+    const std::vector<NDArray*>& outputs) {
   using namespace nnvm;
   using namespace imperative;
+  static const auto cached_op = nnvm::Op::Get("_CachedOp");
 
-  auto& state = state_ptr.get_state<CachedOpState>();
-  const auto& default_ctx = state.context;
-  nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
-  const auto& idx = g.indexed_graph();
-  std::vector<int> skip_plus_node;
-  if (g.attrs.count("skip_plus_node")) {
-    skip_plus_node = g.GetAttr<std::vector<int> >("skip_plus_node");
-  }
-  size_t start_nid =
-      keep_fwd ? state.info.fwd_graph.indexed_graph().num_nodes() : 0;
-  size_t end_nid = idx.num_nodes();
-
-  if (!keep_fwd) state.fwd_exec_init = false;
-  state.bwd_exec_init = false;
-
-  for (size_t i = start_nid; i < state.execs.size(); ++i) {
-    state.execs[i].reset();
-    state.opr_segs[i] = EngineOprSeg();
-  }
-
-  if (!config_.static_shape) {
-    for (size_t i = start_nid; i < end_nid; ++i) {
-      state.opr_segs[i].next_nid = i + 1;
-      state.opr_segs[i].skip = skip_plus_node.size() && skip_plus_node[i];
-    }
-  } else {
-    for (size_t i = start_nid; i < end_nid; ++i) {
-      exec::CreateOpExecs(g, &state.execs, i);
-    }
-    exec::AttachOpResources(g, state.execs, start_nid, end_nid);
-
-    for (size_t i = start_nid; i < end_nid; ++i) {
-      bool skip = idx[i].source->is_variable();
-      for (size_t j = 0; !skip && j < idx[i].inputs.size(); ++j) {
-        skip = state.dynamic_entries[idx.entry_id(idx[i].inputs[j])];
-      }
-      for (size_t j = 0; !skip && j < idx[i].source->num_outputs(); ++j) {
-        skip = state.dynamic_entries[idx.entry_id(i, j)];
-      }
-      if (skip) continue;
-      SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs);
-    }
+  CHECK_EQ(args.size(), fwd_args_idx_.size())
+      << "CachedOp requires " << fwd_args_idx_.size()
+      << " inputs but got " << args.size();
 
-    size_t bulk_size = idx.num_nodes();
-    std::unordered_set<uint32_t> excludes;
-    if (recording || keep_fwd) {
-      bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size;
-      for (const auto& i : idx.outputs()) excludes.insert(idx.entry_id(i));
-      for (const auto& i : idx.input_nodes()) excludes.insert(idx.entry_id(i, 0));
-    }
+  Context default_ctx = args[0]->ctx();
 
-    CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, excludes,
-                      state.execs, skip_plus_node, &state.opr_segs);
-  }
 
-  if (keep_fwd) {
-    state.bwd_exec_init = true;
-  } else {
-    state.fwd_exec_init = true;
+  std::vector<NDArray*> inputs(num_inputs());
+  for (index_t i = 0; i < fwd_args_idx_.size(); ++i) {
+    inputs[fwd_args_idx_[i]] = args[i];
   }
-}
-
-void CachedOp::StaticRunOps(
-    const Context& default_ctx,
-    const nnvm::Graph& g,
-    const OpStatePtr& state_ptr,
-    size_t start_nid,
-    size_t end_nid) {
-  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
-  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
-
-  bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
-  bool is_training = Imperative::Get()->is_training();
-  auto& state = state_ptr.get_state<CachedOpState>();
-  const auto& idx = g.indexed_graph();
-  const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
-  const auto& op_execs = state.execs;
-
-  std::vector<NDArray*> ndinputs, ndoutputs;
-  nnvm::ShapeVector arg_shapes;
-  nnvm::DTypeVector arg_dtypes;
-  std::vector<OpReqType> req;
+  if (fwd_params_idx_.size()) {
+    CHECK(params_.find(default_ctx) != params_.end())
+        << "CachedOp is not initialized on context " << default_ctx;
 
-  for (size_t i = start_nid; config_.static_shape && i < end_nid; ++i) {
-    if (op_execs[i]) op_execs[i]->op_ctx.is_train = is_training;
-  }
-
-  for (size_t i = start_nid; i < end_nid; i = state.opr_segs[i].next_nid) {
-    const auto& opr_seg = state.opr_segs[i];
-    if (opr_seg.skip) continue;
-    if (opr_seg.opr != nullptr) {
-      Engine::Get()->Push(opr_seg.opr.get(), default_ctx, 0, profiling);
-    } else {
-      const nnvm::IndexedGraph::Node& node = idx[i];
-      if (node.source->is_variable()) continue;
-      auto num_outputs = node.source->num_outputs();
-      ndinputs.clear();
-      ndinputs.reserve(node.inputs.size());
-      for (const auto& j : node.inputs) {
-        ndinputs.emplace_back(state.arrays[idx.entry_id(j)]);
-        CHECK(!ndinputs.back()->is_none());
-      }
-      ndoutputs.clear();
-      ndoutputs.reserve(num_outputs);
-      req.clear();
-      req.reserve(num_outputs);
-      for (size_t j = 0; j < num_outputs; ++j) {
-        size_t eid = idx.entry_id(i, j);
-        ndoutputs.emplace_back(state.arrays[eid]);
-        req.push_back(state.array_reqs[eid]);
-        CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
-      }
-      const DispatchMode dispatch_mode = dispatch_modes[i];
-      if (createop.count(node.source->op())) {
-        arg_shapes.clear();
-        arg_dtypes.clear();
-        arg_shapes.reserve(ndinputs.size());
-        arg_dtypes.reserve(ndinputs.size());
-        for (size_t i = 0; i < ndinputs.size(); ++i) {
-          arg_shapes.emplace_back(ndinputs[i]->shape());
-          arg_dtypes.emplace_back(ndinputs[i]->dtype());
-        }
-        state.op_states[i] = createop[node.source->op()](
-            node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
-        Imperative::Get()->InvokeOp(
-            default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
-            dispatch_mode, state.op_states[i]);
-      } else if (is_layer_backward.get(node.source->op(), false)) {
-        nnvm::Node* fwd_node = node.source->control_deps[0].get();
-        auto fwd_node_id = idx.node_id(fwd_node);
-        Imperative::Get()->InvokeOp(
-            default_ctx, node.source->attrs, ndinputs, ndoutputs,
-            req, dispatch_mode, state.op_states[fwd_node_id]);
-      } else {
-        Imperative::Get()->InvokeOp(
-            default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
-            dispatch_mode);
-      }
+    for (size_t i = 0; i < fwd_params_idx_.size(); ++i) {
+      inputs[fwd_params_idx_[i]] = &params_[default_ctx][i];
     }
   }
-}
-
-OpStatePtr CachedOp::StaticForward(
-    const Context& default_ctx,
-    const std::vector<NDArray*>& inputs,
-    const std::vector<NDArray*>& outputs) {
-  using namespace nnvm;
-  using namespace imperative;
 
+  // Initialize
   bool recording = Imperative::Get()->is_recording();
-  auto state_ptr = GetCachedOpState(default_ctx);
-  auto& state = state_ptr.get_state<CachedOpState>();
-  std::lock_guard<std::mutex> lock(state.mutex);
-
-  bool match = SetForwardGraph(&state.info, recording, inputs);
-  match = match && state.recording == recording;
-
-  nnvm::Graph& g = state.info.fwd_graph;
+  nnvm::Graph g = GetForwardGraph(recording, inputs);
   const auto& idx = g.indexed_graph();
-  if (!state.fwd_alloc || !match)  {
-    StaticAllocMemory(state_ptr, recording, false);
-  }
-
-  if (config_.static_shape) {
-    for (auto i : config_.param_indices) {
-      auto nid = idx.input_nodes()[i];
-      if (!state.arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
-        match = false;
-        auto ptr = &state.buff[idx.entry_id(nid, 0)];
-        CHECK_EQ(state.arrays[idx.entry_id(nid, 0)], ptr);
-        *state.arrays[idx.entry_id(nid, 0)] = *inputs[i];
-        state.dynamic_entries[idx.entry_id(nid, 0)] = false;
-      }
-    }
-    for (auto i : config_.data_indices) {
-      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      state.arrays[eid] = inputs[i];
-    }
-  } else {
-    for (size_t i = 0; i < num_inputs(); ++i) {
-      auto nid = idx.input_nodes()[i];
-      state.arrays[idx.entry_id(nid, 0)] = inputs[i];
-    }
-  }
-
-  if (!state.fwd_exec_init || !match) {
-    StaticInitExec(state_ptr, recording, false);
-  }
-
-  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
-  const auto& shapes = g.GetAttr<ShapeVector>("shape");
-  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  size_t num_inputs = idx.input_nodes().size();
 
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    auto eid = idx.entry_id(idx.outputs()[i]);
-    state.arrays[eid] = outputs[i];
-    if (!outputs[i]->is_none()) continue;
-    *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
-                          shapes[eid], default_ctx, true, dtypes[eid]);
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    CHECK_EQ(inputs[i]->ctx(), default_ctx)
+        << "CachedOp requires all inputs to live on the same context. But "
+        << idx[idx.input_nodes()[0]].source->attrs.name << " is on " << default_ctx
+        << " while " << idx[idx.input_nodes()[i]].source->attrs.name << " is on "
+        << inputs[i]->ctx();
   }
 
-  StaticRunOps(default_ctx, g, state_ptr, 0, idx.num_nodes());
-
-  return recording ? state_ptr : OpStatePtr();
-}
-
-
-OpStatePtr CachedOp::DynamicForward(
-    const Context& default_ctx,
-    const std::vector<NDArray*>& inputs,
-    const std::vector<NDArray*>& outputs) {
-  using namespace nnvm;
-  using namespace imperative;
-
-  // Initialize
-  bool recording = Imperative::Get()->is_recording();
-  auto op_state = OpStatePtr::Create<DynamicRuntime>();
-  auto& runtime = op_state.get_state<DynamicRuntime>();
-  {
-    auto state_ptr = GetCachedOpState(default_ctx);
-    auto& state = state_ptr.get_state<CachedOpState>();
-    std::lock_guard<std::mutex> lock(state.mutex);
-    SetForwardGraph(&state.info, recording, inputs);
-    runtime.info.fwd_graph = state.info.fwd_graph;
-  }
-  nnvm::Graph& g = runtime.info.fwd_graph;
-  const auto& idx = g.indexed_graph();
-  size_t num_inputs = idx.input_nodes().size();
-  auto& buff = runtime.buff;
-  auto& states = runtime.op_states;
+  auto op_state_ptr = OpStatePtr::Create<CachedOpState>();
+  auto& cached_op_state = op_state_ptr.get_state<CachedOpState>();
+  auto& buff = cached_op_state.buff;
+  auto& states = cached_op_state.states;
 
   // Allocate entries
   states.resize(idx.num_nodes());
@@ -780,98 +446,57 @@ OpStatePtr CachedOp::DynamicForward(
   AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
                  mem_plan, arrays, &array_reqs);
 
-  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
-  const auto& shapes = g.GetAttr<ShapeVector>("shape");
-  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
-
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    auto eid = idx.entry_id(idx.outputs()[i]);
-    arrays[eid] = outputs[i];
-    if (!outputs[i]->is_none()) continue;
-    *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
-                          shapes[eid], default_ctx, true, dtypes[eid]);
-  }
-
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
   if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
+  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
 
-  RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
-           std::move(ref_count), &states, dispatch_modes);
+  Imperative::Get()->RunGraph(
+      false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
+      std::move(ref_count), &states, dispatch_modes);
 
+  Engine::Get()->set_bulk_size(prev_bulk_size);
   Imperative::Get()->set_is_recording(recording);
 
-  return op_state;
-}
-
-void CachedOp::Forward(
-    const std::shared_ptr<CachedOp>& op_ptr,
-    const std::vector<NDArray*>& inputs,
-    const std::vector<NDArray*>& outputs) {
-  static const auto cached_op = nnvm::Op::Get("_CachedOp");
-
-  CHECK_EQ(inputs.size(), num_inputs());
-
-  Context default_ctx = inputs[0]->ctx();
-
-  const auto& idx = fwd_graph_.indexed_graph();
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    CHECK_EQ(inputs[i]->ctx(), default_ctx)
-        << "CachedOp requires all inputs to live on the same context. But "
-        << idx[idx.input_nodes()[0]].source->attrs.name
-        << " is on " << default_ctx << " while "
-        << idx[idx.input_nodes()[i]].source->attrs.name
-        << " is on " << inputs[i]->ctx();
-  }
-
-  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
-
-  OpStatePtr op_state;
-  if (config_.static_alloc) {
-    op_state = StaticForward(default_ctx, inputs, outputs);
-  } else {
-    op_state = DynamicForward(default_ctx, inputs, outputs);
+  for (size_t i = 0; i < idx.num_node_entries(); ++i) {
+    if (arrays[i] == &buff[i]) continue;
+    buff[i].shape_ = arrays[i]->shape_;
+    buff[i].dtype_ = arrays[i]->dtype_;
+    buff[i].storage_type_ = arrays[i]->storage_type_;
   }
 
-  Engine::Get()->set_bulk_size(prev_bulk_size);
-
-  if (Imperative::Get()->is_recording() && !inlining_) {
+  if (recording && !inlining_) {
     nnvm::NodeAttrs attrs;
     attrs.op = cached_op;
     attrs.name = "_cachedop";
     attrs.parsed = op_ptr;
     Imperative::Get()->RecordOp(
-        std::move(attrs), inputs, outputs, op_state,
+        std::move(attrs), inputs, outputs, op_state_ptr,
         &save_inputs(), &save_outputs());
   }
 }
 
 
-void CachedOp::DynamicBackward(
+void Imperative::CachedOp::Backward(
     const bool retain_graph,
-    const OpStatePtr& op_state,
+    const OpStatePtr& state,
     const std::vector<NDArray*>& inputs,
     const std::vector<OpReqType>& reqs,
     const std::vector<NDArray*>& outputs) {
   using namespace nnvm;
   using namespace imperative;
+  CHECK(!Imperative::Get()->is_recording())
+      << "CachedOp does not support higher order gradients. "
+      << "If you want to do backward with create_graph=True please "
+      << "do not use hybridize.";
 
   // Initialize
-  Context default_ctx = outputs[0]->ctx();
-  auto& runtime = op_state.get_state<DynamicRuntime>();
-  {
-    auto state_ptr = GetCachedOpState(default_ctx);
-    auto& state = state_ptr.get_state<CachedOpState>();
-    std::lock_guard<std::mutex> lock(state.mutex);
-    state.info.fwd_graph = runtime.info.fwd_graph;
-    SetBackwardGraph(&state.info, reqs, inputs);
-    runtime.info.full_graph = state.info.full_graph;
-    runtime.info.bwd_input_eid = state.info.bwd_input_eid;
-  }
-  nnvm::Graph& g = runtime.info.full_graph;
+  nnvm::Graph g = GetBackwardGraph(state, reqs, inputs);
   const auto& idx = g.indexed_graph();
-  auto& buff = runtime.buff;
-  auto& states = runtime.op_states;
+
+  auto& cached_op_state = state.get_state<CachedOpState>();
+  auto& buff = cached_op_state.buff;
+  auto& states = cached_op_state.states;
 
   size_t num_forward_outputs = fwd_graph_.outputs.size();
   size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
@@ -881,7 +506,7 @@ void CachedOp::DynamicBackward(
   arrays.reserve(buff.size());
   for (size_t i = 0; i < buff.size(); ++i) arrays.push_back(&buff[i]);
   for (size_t i = 0; i < inputs.size(); ++i) {
-    arrays[runtime.info.bwd_input_eid[i]] = inputs[i];
+    arrays[bwd_input_eid_[i]] = inputs[i];
   }
   for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) {
     if (reqs[i] == kNullOp) continue;
@@ -905,14 +530,20 @@ void CachedOp::DynamicBackward(
     if (ref_count[i] == 0) array_reqs[i] = kNullOp;
   }
 
+  Context default_ctx = outputs[0]->ctx();
   const auto& mem_plan = g.GetAttr<MemoryPlanVector >("backward_mem_plan");
   AllocateMemory(g, idx, default_ctx, num_forward_entries, idx.num_node_entries(),
                  mem_plan, arrays, &array_reqs);
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
-  RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
-           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
+  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
+
+  Imperative::Get()->RunGraph(
+      retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
+      std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
+
+  Engine::Get()->set_bulk_size(prev_bulk_size);
 
   if (retain_graph) {
     buff.resize(num_forward_entries);
@@ -922,99 +553,6 @@ void CachedOp::DynamicBackward(
   }
 }
 
-void CachedOp::StaticBackward(
-    const bool retain_graph,
-    const OpStatePtr& state_ptr,
-    const std::vector<NDArray*>& inputs,
-    const std::vector<OpReqType>& reqs,
-    const std::vector<NDArray*>& outputs) {
-  using namespace nnvm;
-  using namespace imperative;
-
-  Context default_ctx = outputs[0]->ctx();
-
-  auto& state = state_ptr.get_state<CachedOpState>();
-  std::lock_guard<std::mutex> lock(state.mutex);
-
-  bool match = SetBackwardGraph(&state.info, reqs, inputs, true);
-
-  nnvm::Graph& g = state.info.full_graph;
-  const auto& idx = g.indexed_graph();
-  auto num_forward_nodes = state.info.fwd_graph.indexed_graph().num_nodes();
-
-  if (!state.bwd_alloc || !match) {
-    StaticAllocMemory(state_ptr, true, true);
-  }
-
-  if (config_.static_shape) {
-    for (auto i : config_.param_indices) {
-      const auto iter = fwd_input_to_grad_output_.find(i);
-      if (iter == fwd_input_to_grad_output_.end()) continue;
-      auto entry = grad_graph_.outputs[iter->second];
-      if (!idx.exist(entry.node.get())) continue;
-      auto eid = idx.entry_id(entry);
-      if (!state.arrays[eid]->IsSame(*outputs[iter->second]) ||
-          !(state.array_reqs[eid] == reqs[iter->second])) {
-        match = false;
-        state.array_reqs[eid] = reqs[iter->second];
-        *state.arrays[eid] = *outputs[iter->second];
-        state.dynamic_entries[eid] = false;
-      }
-    }
-    for (auto i : config_.data_indices) {
-      const auto iter = fwd_input_to_grad_output_.find(i);
-      if (iter == fwd_input_to_grad_output_.end()) continue;
-      auto entry = grad_graph_.outputs[iter->second];
-      if (!idx.exist(entry.node.get())) continue;
-      auto eid = idx.entry_id(entry);
-      state.array_reqs[eid] = reqs[iter->second];
-      state.arrays[eid] = outputs[iter->second];
-    }
-  } else {
-    for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
-      auto entry = grad_graph_.outputs[i];
-      if (!idx.exist(entry.node.get())) continue;
-      auto eid = idx.entry_id(entry);
-      state.array_reqs[eid] = reqs[i];
-      state.arrays[eid] = outputs[i];
-    }
-  }
-
-  if (!state.bwd_exec_init || !match) {
-    StaticInitExec(state_ptr, true, true);
-  }
-
-  for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
-    auto eid = state.info.bwd_input_eid[i];
-    if (state.dynamic_entries[eid]) state.arrays[eid] = inputs[i];
-  }
-
-  StaticRunOps(default_ctx, g, state_ptr, num_forward_nodes, idx.num_nodes());
-}
-
-void CachedOp::Backward(
-    const bool retain_graph,
-    const OpStatePtr& state,
-    const std::vector<NDArray*>& inputs,
-    const std::vector<OpReqType>& reqs,
-    const std::vector<NDArray*>& outputs) {
-  using namespace imperative;
-  CHECK(!Imperative::Get()->is_recording())
-      << "CachedOp does not support higher order gradients. "
-      << "If you want to do backward with create_graph=True please "
-      << "do not use hybridize.";
-
-  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
-
-  if (config_.static_alloc) {
-    StaticBackward(retain_graph, state, inputs, reqs, outputs);
-  } else {
-    DynamicBackward(retain_graph, state, inputs, reqs, outputs);
-  }
-
-  Engine::Get()->set_bulk_size(prev_bulk_size);
-}
-
 
 NNVM_REGISTER_OP(_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs) {
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
deleted file mode 100644
index 60a40c5..0000000
--- a/src/imperative/cached_op.h
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef MXNET_IMPERATIVE_CACHED_OP_H_
-#define MXNET_IMPERATIVE_CACHED_OP_H_
-
-#include <mxnet/imperative.h>
-#include <vector>
-#include <atomic>
-#include <utility>
-#include <string>
-#include <unordered_map>
-
-namespace mxnet {
-/*! \brief CachedOp Parameters */
-struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
-  uint32_t inline_limit;
-  uint32_t forward_bulk_size;
-  uint32_t backward_bulk_size;
-  bool static_alloc;
-  bool static_shape;
-  nnvm::Tuple<uint32_t> data_indices;
-  nnvm::Tuple<uint32_t> param_indices;
-  DMLC_DECLARE_PARAMETER(CachedOpConfig) {
-    DMLC_DECLARE_FIELD(static_alloc)
-    .set_default(false)
-    .describe("Statically allocate memory to improve speed. "
-              "Memory usage may increase.");
-    DMLC_DECLARE_FIELD(static_shape)
-    .set_default(false)
-    .describe("Optimize for invariant input shapes between iterations. "
-              "Must also set static_alloc to True. "
-              "Change of input shapes is still allowed but slower.");
-    DMLC_DECLARE_FIELD(inline_limit)
-    .set_default(2)
-    .describe("Maximum number of operators that can be inlined.");
-    DMLC_DECLARE_FIELD(forward_bulk_size)
-    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
-    .describe("Segment size of bulk execution during forward pass.");
-    DMLC_DECLARE_FIELD(backward_bulk_size)
-    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
-    .describe("Segment size of bulk execution during backward pass.");
-    DMLC_DECLARE_FIELD(data_indices)
-    .set_default(nnvm::Tuple<uint32_t>())
-    .describe("Position of argument variables.");
-    DMLC_DECLARE_FIELD(param_indices)
-    .set_default(nnvm::Tuple<uint32_t>())
-    .describe("Position of parameters.");
-  }
-};
-
-class CachedOp {
- public:
-  CachedOp(
-      const nnvm::Symbol& sym,
-      const std::vector<std::pair<std::string, std::string> >& flags);
-  ~CachedOp();
-  uint32_t num_inputs() {
-    return fwd_graph_.indexed_graph().input_nodes().size();
-  }
-  uint32_t num_outputs() {
-    return fwd_graph_.outputs.size();
-  }
-  uint32_t num_backward_inputs() {
-    return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
-  }
-  std::vector<bool>& save_inputs() {
-    return save_inputs_;
-  }
-  std::vector<bool>& save_outputs() {
-    return save_outputs_;
-  }
-  const std::unordered_set<uint32_t>& mutable_input_nodes() {
-    return fwd_graph_.indexed_graph().mutable_input_nodes();
-  }
-  std::vector<nnvm::NodeEntry> Gradient(
-      const nnvm::NodePtr& node,
-      const std::vector<nnvm::NodeEntry>& ograds);
-  void Forward(
-      const std::shared_ptr<CachedOp>& op_ptr,
-      const std::vector<NDArray*>& inputs,
-      const std::vector<NDArray*>& outputs);
-  void Backward(
-      const bool retain_graph,
-      const OpStatePtr& state,
-      const std::vector<NDArray*>& inputs,
-      const std::vector<OpReqType>& reqs,
-      const std::vector<NDArray*>& outputs);
-
- private:
-  struct GraphInfo;
-  struct DynamicRuntime;
-  struct CachedOpState;
-
-  OpStatePtr GetCachedOpState(const Context& ctx);
-  bool SetForwardGraph(
-      GraphInfo* info,
-      const bool recording,
-      const std::vector<NDArray*>& inputs);
-  bool SetBackwardGraph(
-      GraphInfo* info,
-      const std::vector<OpReqType>& reqs,
-      const std::vector<NDArray*>& inputs,
-      bool detect_inplace_addto = false);
-  OpStatePtr DynamicForward(
-      const Context& default_ctx,
-      const std::vector<NDArray*>& inputs,
-      const std::vector<NDArray*>& outputs);
-  void DynamicBackward(
-      const bool retain_graph,
-      const OpStatePtr& op_state,
-      const std::vector<NDArray*>& inputs,
-      const std::vector<OpReqType>& reqs,
-      const std::vector<NDArray*>& outputs);
-  void StaticAllocMemory(
-      const OpStatePtr& state_ptr,
-      bool recording,
-      bool keep_fwd);
-  void StaticInitExec(
-      const OpStatePtr& state_ptr,
-      bool recording,
-      bool keep_fwd);
-  void StaticRunOps(
-      const Context& default_ctx,
-      const nnvm::Graph& g,
-      const OpStatePtr& state_ptr,
-      size_t start_nid,
-      size_t end_nid);
-  OpStatePtr StaticForward(
-      const Context& default_ctx,
-      const std::vector<NDArray*>& inputs,
-      const std::vector<NDArray*>& outputs);
-  void StaticBackward(
-      const bool retain_graph,
-      const OpStatePtr& state_ptr,
-      const std::vector<NDArray*>& inputs,
-      const std::vector<OpReqType>& reqs,
-      const std::vector<NDArray*>& outputs);
-
-  CachedOpConfig config_;
-  nnvm::Graph fwd_graph_;
-  nnvm::Graph grad_graph_;
-  nnvm::Graph full_graph_;
-  bool inlining_;
-  std::vector<nnvm::NodeEntry> ograd_entries_;
-  std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
-  std::unordered_map<uint32_t, uint32_t> fwd_input_to_grad_output_;
-  std::vector<bool> save_inputs_, save_outputs_;
-  std::vector<OpReqType> bwd_output_reqs_;
-
-  std::mutex mutex_;
-  std::unordered_map<Context, std::vector<OpStatePtr> > cached_op_states_;
-};
-
-using CachedOpPtr = std::shared_ptr<CachedOp>;
-
-}  // namespace mxnet
-#endif  // MXNET_IMPERATIVE_CACHED_OP_H_
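    (Editor's note on the header removed above: CachedOpConfig is the parameter
    struct that Gluon populates from the keyword flags given to hybridize(), and
    the tests deleted at the bottom of this patch exercise exactly that path. The
    following is only a minimal illustrative sketch of the Python-side usage being
    reverted here; it assumes a build that still contains #11313, and the small
    network is made up for illustration, not taken from the patch.)

        import mxnet as mx
        from mxnet import gluon

        # Build a small network and hybridize it with the flags declared by the
        # (now removed) CachedOpConfig. Per the .describe() text above,
        # static_shape additionally requires static_alloc=True.
        net = gluon.nn.HybridSequential(prefix='net_')
        with net.name_scope():
            net.add(gluon.nn.Dense(16, activation='relu'))
            net.add(gluon.nn.Dense(2))
        net.initialize()
        net.hybridize(static_alloc=True, static_shape=True)  # flags removed by this revert

        x = mx.nd.random.uniform(shape=(4, 8))
        y = net(x)  # first call builds and caches the graph with static planning
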
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index e165425..7caf305 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -19,7 +19,6 @@
 #include <unordered_set>
 #include <iostream>
 #include "./imperative_utils.h"
-#include "./cached_op.h"
 
 namespace mxnet {
 #if DMLC_CXX11_THREAD_LOCAL
@@ -267,6 +266,95 @@ void Imperative::RecordOp(
   }
 }
 
+void Imperative::RunGraph(
+    const bool retain_graph,
+    const nnvm::IndexedGraph& idx,
+    const std::vector<NDArray*> arrays,
+    size_t node_start, size_t node_end,
+    std::vector<OpReqType>&& array_reqs,
+    std::vector<uint32_t>&& ref_count,
+    std::vector<OpStatePtr> *p_states,
+    const DispatchModeVector &dispatch_modes) {
+  using namespace nnvm;
+  using namespace imperative;
+  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
+  static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
+
+  std::vector<OpStatePtr>& states = *p_states;
+  bool recording = is_recording();
+
+  std::vector<NDArray*> ndinputs, ndoutputs;
+  ShapeVector arg_shapes;
+  DTypeVector arg_dtypes;
+  std::vector<OpReqType> req;
+
+  for (size_t i = node_start; i < node_end; ++i) {
+    const nnvm::IndexedGraph::Node& node = idx[i];
+    if (node.source->op() == nullptr) continue;
+    auto num_outputs = node.source->num_outputs();
+    ndinputs.clear();
+    ndinputs.reserve(node.inputs.size());
+    for (const auto& j : node.inputs) {
+      ndinputs.emplace_back(arrays[idx.entry_id(j)]);
+      CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
+    }
+    ndoutputs.clear();
+    ndoutputs.reserve(num_outputs);
+    req.clear();
+    req.reserve(num_outputs);
+    for (size_t j = 0; j < num_outputs; ++j) {
+      size_t eid = idx.entry_id(i, j);
+      ndoutputs.emplace_back(arrays[eid]);
+      req.push_back(array_reqs[eid]);
+      CHECK(!ndoutputs.back()->is_none());
+    }
+    const Context& ctx = ndoutputs[0]->ctx();
+    const DispatchMode dispatch_mode = dispatch_modes[i];
+    if (node.source->op() == bwd_cached_op) {
+      const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
+      nnvm::Node* fwd_node = node.source->control_deps[0].get();
+      auto fwd_node_id = idx.node_id(fwd_node);
+      cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
+    } else if (createop.count(node.source->op())) {
+      arg_shapes.clear();
+      arg_dtypes.clear();
+      arg_shapes.reserve(ndinputs.size());
+      arg_dtypes.reserve(ndinputs.size());
+      for (size_t i = 0; i < ndinputs.size(); ++i) {
+        arg_shapes.emplace_back(ndinputs[i]->shape());
+        arg_dtypes.emplace_back(ndinputs[i]->dtype());
+      }
+      states[i] = createop[node.source->op()](
+          node.source->attrs, ctx, arg_shapes, arg_dtypes);
+      InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
+      if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
+    } else if (is_layer_backward.get(node.source->op(), false)) {
+      nnvm::Node* fwd_node = node.source->control_deps[0].get();
+      auto fwd_node_id = idx.node_id(fwd_node);
+      InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
+               req, dispatch_mode, states[fwd_node_id]);
+      if (recording) {
+        RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
+      }
+    } else {
+      InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
+      if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
+    }
+
+    for (const auto& j : node.inputs) {
+      size_t eid = idx.entry_id(j);
+      --ref_count[eid];
+      if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
+    }
+    for (size_t j = 0; j < ndoutputs.size(); ++j) {
+      size_t eid = idx.entry_id(i, j);
+      if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
+    }
+  }
+}
+
+
 std::vector<NDArray*> Imperative::Backward(
     const std::vector<NDArray*>& outputs,
     const std::vector<NDArray*>& ograds,
diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc
deleted file mode 100644
index 464aefc..0000000
--- a/src/imperative/imperative_utils.cc
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include "./imperative_utils.h"
-#include "./cached_op.h"
-
-namespace mxnet {
-namespace imperative {
-void RunGraph(
-    const bool retain_graph,
-    const nnvm::IndexedGraph& idx,
-    const std::vector<NDArray*> arrays,
-    size_t node_start, size_t node_end,
-    std::vector<OpReqType>&& array_reqs,
-    std::vector<uint32_t>&& ref_count,
-    std::vector<OpStatePtr> *p_states,
-    const DispatchModeVector &dispatch_modes) {
-  using namespace nnvm;
-  using namespace imperative;
-  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
-  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
-  static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
-
-  const auto imp = Imperative::Get();
-
-  std::vector<OpStatePtr>& states = *p_states;
-  bool recording = imp->is_recording();
-
-  std::vector<NDArray*> ndinputs, ndoutputs;
-  ShapeVector arg_shapes;
-  DTypeVector arg_dtypes;
-  std::vector<OpReqType> req;
-
-  for (size_t i = node_start; i < node_end; ++i) {
-    const nnvm::IndexedGraph::Node& node = idx[i];
-    if (node.source->op() == nullptr) continue;
-    auto num_outputs = node.source->num_outputs();
-    ndinputs.clear();
-    ndinputs.reserve(node.inputs.size());
-    for (const auto& j : node.inputs) {
-      ndinputs.emplace_back(arrays[idx.entry_id(j)]);
-      CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
-    }
-    ndoutputs.clear();
-    ndoutputs.reserve(num_outputs);
-    req.clear();
-    req.reserve(num_outputs);
-    for (size_t j = 0; j < num_outputs; ++j) {
-      size_t eid = idx.entry_id(i, j);
-      ndoutputs.emplace_back(arrays[eid]);
-      req.push_back(array_reqs[eid]);
-      CHECK(array_reqs[eid] == kNullOp || !ndoutputs.back()->is_none());
-    }
-    const Context& ctx = ndoutputs[0]->ctx();
-    const DispatchMode dispatch_mode = dispatch_modes[i];
-    if (node.source->op() == bwd_cached_op) {
-      const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
-      nnvm::Node* fwd_node = node.source->control_deps[0].get();
-      auto fwd_node_id = idx.node_id(fwd_node);
-      cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
-    } else if (createop.count(node.source->op())) {
-      arg_shapes.clear();
-      arg_dtypes.clear();
-      arg_shapes.reserve(ndinputs.size());
-      arg_dtypes.reserve(ndinputs.size());
-      for (size_t i = 0; i < ndinputs.size(); ++i) {
-        arg_shapes.emplace_back(ndinputs[i]->shape());
-        arg_dtypes.emplace_back(ndinputs[i]->dtype());
-      }
-      states[i] = createop[node.source->op()](
-          node.source->attrs, ctx, arg_shapes, arg_dtypes);
-      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
-      if (recording) {
-        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
-      }
-    } else if (is_layer_backward.get(node.source->op(), false)) {
-      nnvm::Node* fwd_node = node.source->control_deps[0].get();
-      auto fwd_node_id = idx.node_id(fwd_node);
-      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
-               req, dispatch_mode, states[fwd_node_id]);
-      if (recording) {
-        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
-      }
-    } else {
-      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
-      if (recording) {
-        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
-      }
-    }
-
-    for (const auto& j : node.inputs) {
-      size_t eid = idx.entry_id(j);
-      --ref_count[eid];
-      if (ref_count[eid] == 0) *arrays[eid] = NDArray();
-    }
-    for (size_t j = 0; j < ndoutputs.size(); ++j) {
-      size_t eid = idx.entry_id(i, j);
-      if (ref_count[eid] == 0) *arrays[eid] = NDArray();
-    }
-  }
-}
-
-}  // namespace imperative
-}  // namespace mxnet
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 726531d..06b7e05 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -23,7 +23,6 @@
 #include <utility>
 #include <algorithm>
 #include <vector>
-#include <map>
 #include <string>
 #include "../executor/graph_executor.h"
 #include "../executor/exec_pass.h"
@@ -39,24 +38,11 @@ namespace mxnet {
 namespace imperative {
 
 struct MemoryPlanInfo {
-  int storage_id;
-  uint32_t root;
+  uint32_t sid;
   size_t size;
   bool inplace;
 };
 
-struct EngineOprDeleter {
-  void operator()(engine::Opr* handle) {
-    Engine::Get()->DeleteOperator(handle);
-  }
-};
-
-struct EngineOprSeg {
-  bool skip;
-  size_t next_nid;
-  std::unique_ptr<engine::Opr, EngineOprDeleter> opr;
-};
-
 using MemoryPlanVector = std::vector<MemoryPlanInfo>;
 
 inline Context GetContext(const nnvm::NodeAttrs& attrs,
@@ -729,12 +715,10 @@ inline std::vector<Context> PlaceDevice(const nnvm::IndexedGraph& idx) {
 
 
 inline MemoryPlanVector PlanMemory(
-    nnvm::Graph* p_g,
-    nnvm::StorageVector&& storage,
+    nnvm::Graph* p_g, nnvm::StorageVector&& storage,
     const std::vector<uint32_t>& ref_count,
     const std::pair<uint32_t, uint32_t>& node_range = {0, 0},
-    const std::pair<uint32_t, uint32_t>& entry_range = {0, 0},
-    bool detect_inplace_addto = false) {
+    const std::pair<uint32_t, uint32_t>& entry_range = {0, 0}) {
   using namespace nnvm;
   nnvm::Graph& g = *p_g;
   const auto& idx = g.indexed_graph();
@@ -744,31 +728,31 @@ inline MemoryPlanVector PlanMemory(
   g.attrs["ref_count"] = std::make_shared<dmlc::any>(ref_count);
   g.attrs["storage"] = std::make_shared<dmlc::any>(std::move(storage));
   g = nnvm::ApplyPass(g, "PlanMemory");
-  if (detect_inplace_addto) g = exec::DetectInplaceAddTo(g);
 
   const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
   const auto& shapes = g.GetAttr<ShapeVector>("shape");
-  const auto& storage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
-  const auto& storage_ids = g.GetAttr<StorageVector>("storage_id");
+  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  auto storage_ids = g.MoveCopyAttr<StorageVector>("storage_id");
+  auto storage_inplace = g.MoveCopyAttr<std::vector<int> >("storage_inplace_index");
   uint32_t entry_start = entry_range.first;
   uint32_t entry_end =
       entry_range.second > entry_start ? entry_range.second : idx.num_node_entries();
   MemoryPlanVector mem_plan(idx.num_node_entries());
-  std::unordered_map<int, uint32_t> sid_to_root;
+  std::unordered_map<int, uint32_t> sid_to_loc;
 
   for (uint32_t i = entry_start; i < entry_end; ++i) {
+    if (stypes[i] != kDefaultStorage) continue;
     if (storage_ids[i] < 0) {
-      mem_plan[i] = {storage_ids[i], i, 0, false};
-    } else if (!sid_to_root.count(storage_ids[i])) {
+      mem_plan[i] = {i, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), false};
+    } else if (!sid_to_loc.count(storage_ids[i])) {
       CHECK_LT(storage_inplace[i], 0);
-      sid_to_root[storage_ids[i]] = i;
-      mem_plan[i] = {storage_ids[i], i,
-                     mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(),
-                     false};
+      sid_to_loc[storage_ids[i]] = i;
+      mem_plan[i].sid = i;
+      mem_plan[i].size = mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size();
     } else {
-      uint32_t root = sid_to_root[storage_ids[i]];
-      mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0};
-      mem_plan[root].size = std::max(mem_plan[root].size,
+      uint32_t loc = sid_to_loc[storage_ids[i]];
+      mem_plan[i] = {loc, 0, storage_inplace[i] >= 0};
+      mem_plan[loc].size = std::max(mem_plan[loc].size,
           mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size());
     }
   }
@@ -777,213 +761,39 @@ inline MemoryPlanVector PlanMemory(
 }
 
 
-inline std::multimap<size_t, NDArray> AllocateMemory(
-    const nnvm::Graph& g,
-    const nnvm::IndexedGraph& idx,
-    const Context& default_ctx,
-    const uint32_t entry_start, const uint32_t entry_end,
-    const MemoryPlanVector& mem_plan,
-    const std::vector<NDArray*>& arrays,
-    std::vector<OpReqType> *array_reqs,
-    std::multimap<size_t, NDArray>&& pool = std::multimap<size_t, NDArray>()) {
+inline void AllocateMemory(const nnvm::Graph& g,
+                    const nnvm::IndexedGraph& idx,
+                    const Context& default_ctx,
+                    const uint32_t entry_start, const uint32_t entry_end,
+                    const MemoryPlanVector& mem_plan,
+                    const std::vector<NDArray*>& arrays,
+                    std::vector<OpReqType> *array_reqs) {
   using namespace nnvm;
   const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
   const auto& shapes = g.GetAttr<ShapeVector>("shape");
   const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
 
-  std::multimap<size_t, NDArray> new_pool;
-
   for (uint32_t i = entry_start; i < entry_end; ++i) {
-    if (mem_plan[i].storage_id == exec::kExternalStorageID) continue;
-    CHECK(arrays[i]->is_none());
-    if (mem_plan[i].storage_id == exec::kDynamicStorageID) {
-      *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
-                           shapes[i], default_ctx, true, dtypes[i]);
-      continue;
-    }
-    CHECK_EQ(stypes[i], kDefaultStorage);
-    if (mem_plan[i].root == i) {
-      CHECK_GT(mem_plan[i].size, 0);
-      auto iter = pool.lower_bound(mem_plan[i].size);
-      if (iter != pool.end()) {
-        *arrays[i] = iter->second.AsArray(shapes[i], dtypes[i]);
-        new_pool.insert(*iter);
-        pool.erase(iter);
-      } else {
+    if (!arrays[i]->is_none()) continue;
+    if (stypes[i] == kDefaultStorage) {
+      if (mem_plan[i].sid == i) {
+        CHECK_GT(mem_plan[i].size, 0);
         NDArray buff(TShape({static_cast<nnvm::dim_t>(mem_plan[i].size)}),
                      default_ctx, true, mshadow::kUint8);
         *arrays[i] = buff.AsArray(shapes[i], dtypes[i]);
-        new_pool.insert({mem_plan[i].size, buff});
-      }
-    } else {
-      CHECK_GE(mem_plan[mem_plan[i].root].storage_id, 0);
-      *arrays[i] = arrays[mem_plan[i].root]->AsArray(shapes[i], dtypes[i]);
-      if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
-        array_reqs->at(i) = kWriteInplace;
-      }
-    }
-  }
-
-  return new_pool;
-}
-
-inline void SetupOpExec(
-    const nnvm::Graph& g,
-    size_t nid,
-    const std::shared_ptr<exec::OpExecutor>& exec,
-    const std::vector<NDArray*> arrays,
-    const std::vector<OpReqType> array_reqs) {
-  const auto& idx = g.indexed_graph();
-  const auto& inode = idx[nid];
-  CHECK_EQ(exec->in_array.size(), 0U);
-  CHECK_EQ(exec->out_array.size(), 0U);
-  for (const auto& e : inode.inputs) {
-    CHECK(!arrays[idx.entry_id(e)]->is_none()) << inode.source->attrs.name;
-    exec->in_array.push_back(*arrays[idx.entry_id(e)]);
-  }
-  for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
-    uint32_t eid = idx.entry_id(nid, index);
-    CHECK(!arrays[eid]->is_none()) << inode.source->attrs.name;
-    exec->out_array.push_back(*arrays[eid]);
-    exec->req.push_back(array_reqs[eid]);
-  }
-
-  exec->Setup();
-}
-
-inline Engine::OprHandle CreateEngineOp(
-    const Context& default_ctx,
-    const std::vector<std::shared_ptr<exec::OpExecutor> >& execs) {
-  CHECK_GT(execs.size(), 0);
-  std::vector<Engine::VarHandle> use_vars, mutate_vars;
-
-  for (const auto& exec : execs) {
-    CHECK_GT(exec->out_array.size(), 0);
-    CHECK(execs.size() == 1 || exec->exec_type() == ExecType::kSync);
-
-    // the variables
-    for (const auto& nd : exec->in_array) {
-      use_vars.push_back(nd.var());
-    }
-    for (auto& r : exec->op_ctx.requested) {
-      mutate_vars.push_back(r.var);
-    }
-    for (auto& nd : exec->out_array) {
-      mutate_vars.push_back(nd.var());
-    }
-    if (exec->var() != nullptr) {
-      mutate_vars.push_back(exec->var());
-    }
-  }
-
-  // dedup vars
-  Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);
-  bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask;
-  bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync;
-
-  auto exec_fun = [execs, is_async, is_gpu] (
-      RunContext ctx, Engine::CallbackOnComplete on_complete) {
-    if (is_async) {
-      execs[0]->op_ctx.async_on_complete = on_complete;
-    }
-    for (const auto& exec : execs) exec->Run(ctx, is_gpu);
-    // call on complete only if it is async op
-    if (!is_async) {
-      if (is_gpu) {
-      #if MXNET_USE_CUDA
-        // Wait GPU kernel to finish.
-        ctx.get_stream<gpu>()->Wait();
-      #else
-        LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
-      #endif
-      }
-      on_complete();
-    }
-  };
-
-  return Engine::Get()->NewOperator(
-      exec_fun, use_vars, mutate_vars, FnProperty::kNormal);
-}
-
-inline void CreateEngineOpSeg(
-    const nnvm::IndexedGraph& idx,
-    const Context default_ctx,
-    const size_t start_nid,
-    const size_t end_nid,
-    const size_t bulk_size,
-    const std::unordered_set<uint32_t>& excludes,
-    const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
-    const std::vector<int> skip_plus_node,
-    std::vector<EngineOprSeg> *opr_segs) {
-  size_t seg_start = start_nid;
-  std::vector<std::shared_ptr<exec::OpExecutor> > seg_execs;
-  for (size_t nid = start_nid; nid < end_nid; ++nid) {
-    const auto& node = idx[nid];
-    if (node.source->is_variable()) continue;
-    if (skip_plus_node.size() && skip_plus_node[nid]) continue;
-    auto& exec = execs[nid];
-    bool is_async = exec->exec_type() != ExecType::kSync;
-    bool valid = exec->out_array.size() > 0;
-
-    // Stop at async nodes and invalid node (due to input/output is not allocated)
-    bool stop = is_async || !valid || seg_execs.size() >= bulk_size;
-    for (size_t i = 0; i < node.inputs.size() && !stop; ++i) {
-      if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true;
-    }
-    auto num_outputs = node.source->num_outputs();
-    for (size_t i = 0; i < num_outputs && !stop; ++i) {
-      if (excludes.count(idx.entry_id(nid, i))) stop = true;
-    }
-
-    // Create opr segment for previous nodes.
-    if (stop && nid > seg_start) {
-      auto& seg = (*opr_segs)[seg_start];
-      if (seg_execs.size()) {
-        seg = EngineOprSeg{false, nid};
-        seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
       } else {
-        seg = EngineOprSeg{true, nid, nullptr};
+        *arrays[i] = arrays[mem_plan[i].sid]->AsArray(shapes[i], dtypes[i]);
+        if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
+          array_reqs->at(i) = kWriteInplace;
+        }
       }
-      seg_start = nid;
-      seg_execs.clear();
-    }
-
-    seg_execs.push_back(exec);
-
-    auto& seg = (*opr_segs)[nid];
-    if (is_async) {
-      seg = EngineOprSeg{false, nid + 1};
-      seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
-      seg_execs.clear();
-      seg_start = nid + 1;
-    } else if (!valid) {
-      seg = EngineOprSeg{false, nid + 1, nullptr};
-      seg_execs.clear();
-      seg_start = nid + 1;
-    }
-  }
-  // The last segment
-  if (end_nid > seg_start) {
-    auto& seg = (*opr_segs)[seg_start];
-    if (seg_execs.size()) {
-      seg = EngineOprSeg{false, end_nid};
-      seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
     } else {
-      seg = EngineOprSeg{true, end_nid, nullptr};
+      *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
+                           shapes[i], default_ctx, true, dtypes[i]);
     }
   }
 }
 
-
-void RunGraph(const bool retain_graph,
-              const nnvm::IndexedGraph& idx,
-              const std::vector<NDArray*> arrays,
-              size_t node_start, size_t node_end,
-              std::vector<OpReqType>&& array_reqs,
-              std::vector<uint32_t>&& ref_count,
-              std::vector<OpStatePtr> *p_states,
-              const DispatchModeVector &dispatch_modes);
-
 }  // namespace imperative
 }  // namespace mxnet
 
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 5701a5d..451fde2 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -22,7 +22,6 @@ from mxnet.test_utils import assert_almost_equal
 from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
 from common import setup_module, with_seed, assertRaises, teardown
 import numpy as np
-from numpy.testing import assert_array_equal
 from nose.tools import raises, assert_raises
 from copy import deepcopy
 import warnings
@@ -1125,6 +1124,7 @@ def test_hybrid_multi_context():
     net.hybridize()
     net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy()
 
+
 @with_seed()
 def test_zero_grad():
     data = mx.nd.random.uniform(shape=(3,3))
@@ -1137,60 +1137,6 @@ def test_zero_grad():
     grad = net.collect_params()['test_zero_grad_weight'].grad()
     assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)
 
-def check_hybrid_static_memory(**kwargs):
-    x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
-    x.attach_grad()
-
-    net1 = gluon.model_zoo.vision.get_resnet(
-        1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
-    net2 = gluon.model_zoo.vision.get_resnet(
-        1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
-    net2.hybridize(**kwargs)
-    net1(x)
-    net2(x)
-
-    def test(net, x):
-        with mx.autograd.record():
-            y = net(x) + net(x)
-            y.backward()
-
-        grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'}
-
-        return y, grads
-
-    y1, grads1 = test(net1, x)
-    y2, grads2 = test(net2, x)
-
-    assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
-    for key in grads1:
-        assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)
-
-def test_hybrid_static_memory():
-    check_hybrid_static_memory()
-    check_hybrid_static_memory(static_alloc=True)
-    check_hybrid_static_memory(static_alloc=True, static_shape=True)
-
-def check_hybrid_static_memory_switching(**kwargs):
-    net = gluon.model_zoo.vision.get_resnet(
-        1, 18, pretrained=True, ctx=mx.context.current_context())
-    net.hybridize(**kwargs)
-
-    x = mx.nd.random.uniform(shape=(4, 3, 32, 32))
-    net(x)
-    with mx.autograd.record():
-        y = net(x)
-        y.backward()
-    x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
-    net(x)
-    with mx.autograd.record():
-        y = net(x)
-        y.backward()
-    mx.nd.waitall()
-
-def test_hybrid_static_memory_switching():
-    check_hybrid_static_memory_switching()
-    check_hybrid_static_memory_switching(static_alloc=True)
-    check_hybrid_static_memory_switching(static_alloc=True, static_shape=True)
 
 @with_seed()
 def test_hook():
@@ -1285,17 +1231,6 @@ def test_legacy_save_params():
     model.load_params('test.params', ctx=mx.cpu())
 
 
-def test_hybrid_static_memory_recording():
-    net = gluon.model_zoo.vision.get_resnet(
-        1, 18, pretrained=True, ctx=mx.context.current_context())
-    net.hybridize(static_alloc=True)
-
-    x = mx.nd.random.uniform(shape=(1, 3, 32, 32))
-    with mx.autograd.record(True):
-        net(x)
-    net(x)
-
-
 if __name__ == '__main__':
     import nose
     nose.runmodule()
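
    (Editor's note on the tests removed above: they all follow the same pattern --
    build two identically-initialized networks, hybridize only one of them with the
    flags under test, then compare outputs and gradients after a forward/backward
    pass. Below is a condensed sketch of that pattern using a small throw-away
    network and an illustrative parameter file name instead of the pretrained
    resnet in the deleted tests; again it is only meaningful on a build that still
    carries #11313.)

        import mxnet as mx
        from mxnet import gluon, autograd
        from mxnet.test_utils import assert_almost_equal

        def build_net():
            # Small stand-in for the resnet18_v1 used by the removed tests.
            net = gluon.nn.HybridSequential(prefix='net_')
            with net.name_scope():
                net.add(gluon.nn.Dense(16, activation='relu'))
                net.add(gluon.nn.Dense(4))
            return net

        x = mx.nd.random.uniform(shape=(2, 8))
        x.attach_grad()

        net1 = build_net()
        net1.initialize()
        net1(x)  # trigger deferred initialization before saving
        net1.save_params('static_check.params')  # illustrative temp file name

        net2 = build_net()
        net2.load_params('static_check.params', ctx=mx.cpu())
        net2.hybridize(static_alloc=True, static_shape=True)  # flags removed by this revert
        net2(x)

        def run(net):
            with autograd.record():
                y = net(x) + net(x)
                y.backward()
            return y, {k: v.grad() for k, v in net.collect_params().items()
                       if v.grad_req != 'null'}

        y1, g1 = run(net1)
        y2, g2 = run(net2)
        assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
        for k in g1:
            assert_almost_equal(g1[k].asnumpy(), g2[k].asnumpy(), rtol=1e-3, atol=1e-5)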
