You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2018/06/16 01:27:40 UTC

[incubator-mxnet] 01/01: Revert "Static alloc for hybridblock (#11313)"

This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch revert-11313-static
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 24f14f7fdbef427aab25738608267c57de645987
Author: Marco de Abreu <ma...@users.noreply.github.com>
AuthorDate: Fri Jun 15 18:27:32 2018 -0700

    Revert "Static alloc for hybridblock (#11313)"
    
    This reverts commit 5431e12f11fc5446f6ec2a25098b4e3b67ee7eb3.
---
 include/mxnet/c_api.h                   |   5 +
 include/mxnet/imperative.h              |  89 ++++
 include/mxnet/ndarray.h                 |   8 -
 include/mxnet/op_attr_types.h           |  33 +-
 python/mxnet/_ctypes/ndarray.py         |  16 +-
 python/mxnet/gluon/block.py             |  74 ++--
 src/c_api/c_api_ndarray.cc              |  26 +-
 src/engine/threaded_engine.cc           |   3 +-
 src/executor/attach_op_execs_pass.cc    | 165 ++++---
 src/executor/attach_op_resource_pass.cc |  16 +-
 src/executor/exec_pass.h                |  28 +-
 src/executor/graph_executor.cc          |   2 +-
 src/imperative/cached_op.cc             | 750 ++++++--------------------------
 src/imperative/cached_op.h              | 174 --------
 src/imperative/imperative.cc            |  90 +++-
 src/imperative/imperative_utils.cc      | 120 -----
 src/imperative/imperative_utils.h       | 256 ++---------
 tests/python/unittest/test_gluon.py     |  67 +--
 18 files changed, 523 insertions(+), 1399 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 4dd858a..55c26bc 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -987,6 +987,11 @@ MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
                                  int num_flags,
                                  const char** keys,
                                  const char** vals,
+                                 int num_inputs,
+                                 const char** input_names,
+                                 int num_params,
+                                 const char** param_names,
+                                 NDArrayHandle* params,
                                  CachedOpHandle *out);
 /*!
  * \brief free cached operator
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 7ea60df..758ce85 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -35,6 +35,23 @@
 #include "./ndarray.h"
 
 namespace mxnet {
+/*! \brief CachedOp Parameters */
+struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
+  uint32_t inline_limit;
+  uint32_t forward_bulk_size;
+  uint32_t backward_bulk_size;
+  DMLC_DECLARE_PARAMETER(CachedOpConfig) {
+    DMLC_DECLARE_FIELD(inline_limit)
+    .set_default(2)
+    .describe("Maximum number of operators that can be inlined.");
+    DMLC_DECLARE_FIELD(forward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during forward pass.");
+    DMLC_DECLARE_FIELD(backward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during backward pass.");
+  }
+};
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
@@ -77,6 +94,67 @@ class Imperative {
              && info.out_grads.size() == 1;
     }
   };
+  class CachedOp {
+   public:
+    CachedOp(
+        const nnvm::Symbol& sym,
+        const std::vector<std::pair<std::string, std::string> >& flags,
+        const std::vector<std::string> arg_names,
+        const std::unordered_map<std::string, std::vector<NDArray> >& params);
+    uint32_t num_inputs() {
+      return fwd_graph_.indexed_graph().input_nodes().size();
+    }
+    uint32_t num_outputs() {
+      return fwd_graph_.outputs.size();
+    }
+    uint32_t num_backward_inputs() {
+      return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
+    }
+    std::vector<bool>& save_inputs() {
+      return save_inputs_;
+    }
+    std::vector<bool>& save_outputs() {
+      return save_outputs_;
+    }
+    const std::unordered_set<uint32_t>& mutable_input_nodes() {
+      return fwd_graph_.indexed_graph().mutable_input_nodes();
+    }
+    nnvm::Graph GetForwardGraph(const bool recording,
+                                const std::vector<NDArray*>& inputs);
+    nnvm::Graph GetBackwardGraph(const OpStatePtr& state,
+                                 const std::vector<OpReqType>& reqs,
+                                 const std::vector<NDArray*>& inputs);
+    std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
+                                          const std::vector<nnvm::NodeEntry>& ograds);
+    void Forward(const std::shared_ptr<CachedOp>& op_ptr,
+                 const std::vector<NDArray*>& args,
+                 const std::vector<NDArray*>& outputs);
+    void Backward(const bool retain_graph,
+                  const OpStatePtr& state,
+                  const std::vector<NDArray*>& inputs,
+                  const std::vector<OpReqType>& reqs,
+                  const std::vector<NDArray*>& outputs);
+
+   private:
+    struct CachedOpState {
+      std::vector<NDArray> buff;
+      std::vector<OpStatePtr> states;
+    };
+    std::mutex mutex_;
+    CachedOpConfig config_;
+    nnvm::Graph fwd_graph_;
+    nnvm::Graph grad_graph_;
+    nnvm::Graph full_graph_;
+    std::unordered_map<Context, std::vector<NDArray> > params_;
+    bool inlining_;
+    std::vector<nnvm::NodeEntry> ograd_entries_;
+    std::vector<bool> curr_grad_req_;
+    std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
+    std::vector<uint32_t> fwd_args_idx_;
+    std::vector<uint32_t> fwd_params_idx_;
+    std::vector<uint32_t> bwd_input_eid_;
+    std::vector<bool> save_inputs_, save_outputs_;
+  };
   /*! \brief whether operator recording is on. */
   bool is_training() const {
     return is_train_;
@@ -144,6 +222,15 @@ class Imperative {
       uint32_t num_inputs, uint32_t num_outputs,
       std::vector<bool> *p_save_inputs,
       std::vector<bool> *p_save_outputs);
+  void RunGraph(
+      const bool retain_graph,
+      const nnvm::IndexedGraph& idx,
+      const std::vector<NDArray*> arrays,
+      size_t node_start, size_t node_end,
+      std::vector<OpReqType>&& array_reqs,
+      std::vector<uint32_t>&& ref_count,
+      std::vector<OpStatePtr> *p_states,
+      const DispatchModeVector& dispatch_modes);
   /*! \brief indicate whether is training. */
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local bool is_train_;
@@ -160,5 +247,7 @@ class Imperative {
   int backward_bulk_size_{0};
 };
 
+using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
+
 }  // namespace mxnet
 #endif  // MXNET_IMPERATIVE_H_
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index ae96fd8..e243eb7 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -155,14 +155,6 @@ class NDArray {
     return byte_offset_ > 0 || shape() != ptr_->storage_shape;
   }
 
-  /* \brief Check whether the two arrays are the same array */
-  inline bool IsSame(const NDArray& other) {
-    return ptr_ == other.ptr_ &&
-        shape_ == other.shape_ &&
-        byte_offset_ == other.byte_offset_ &&
-        dtype_ == other.dtype_;
-  }
-
   /*!
    * \return the shape of current NDArray.
    */
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index f4694ef..3969d84 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -126,36 +126,25 @@ class OpStatePtr {
   template<typename T, typename... Args>
   static OpStatePtr Create(Args&&... args) {
     OpStatePtr ret;
-    auto state = new T(std::forward<Args>(args)...);
-    auto var = Engine::Get()->NewVariable();
-    ret.ptr_.reset(
-      new OpState(var, state),
-      [](OpState* p) {
-        Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
-        delete reinterpret_cast<T*>(p->state);
-        delete p;
-      });
+    ret.ptr_ = std::make_shared<OpState>();
+    ret.ptr_->var_ = Engine::Get()->NewVariable();
+    ret.ptr_->state_.construct<T>(std::forward<Args>(args)...);
 
     return ret;
   }
   /* \brief Get engine variable associated with this state */
   engine::VarHandle get_var() const {
-    return ptr_->var;
+    return ptr_->var_;
   }
   /* \brief Get state of type T */
   template<typename T>
   T& get_state() const {
-    return *reinterpret_cast<T*>(ptr_->state);
+    return dmlc::get<T>(ptr_->state_);
   }
   /* \brief clear state */
   void reset() {
     ptr_.reset();
   }
-  /* \brief checks whether the managed object is managed only by the current
-            OpStatePtr instance */
-  bool unique() const {
-    return ptr_.unique();
-  }
   /* \brief Whether state is empty */
   explicit operator bool() const {
     return ptr_ ? true : false;
@@ -164,12 +153,16 @@ class OpStatePtr {
  private:
   /* \brief state structure */
   struct OpState {
-    engine::VarHandle var;
-    void* state;
-
-    OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
+    OpState() {}
     OpState(const OpState& other) = delete;
     OpState& operator=(const OpState& other) = delete;
+
+    ~OpState() {
+      Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_);
+    }
+
+    engine::VarHandle var_;
+    dmlc::any state_;
   };
   /* \brief shared pointer to state */
   std::shared_ptr<OpState> ptr_;
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index f324545..d2cae0c 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -105,14 +105,28 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
 class CachedOp(object):
     """Cached operator handle."""
     __slots__ = ["handle"]
-    def __init__(self, sym, flags=()):
+    def __init__(self, sym, flags=(), inputs=None, params=None):
         self.handle = CachedOpHandle()
+        param_names = []
+        param_arrays = []
+        if inputs is None:
+            assert params is None, "When inputs is None params must also be None."
+            inputs = sym.list_inputs()
+        elif params is not None:
+            for name, arrs in params.items():
+                param_arrays.extend(arrs)
+                param_names.extend([name] * len(arrs))
 
         check_call(_LIB.MXCreateCachedOpEx(
             sym.handle,
             len(flags),
             c_str_array([key for key, _ in flags]),
             c_str_array([str(val) for _, val in flags]),
+            len(inputs),
+            c_str_array(inputs),
+            len(param_names),
+            c_str_array(param_names),
+            c_handle_array(param_arrays),
             ctypes.byref(self.handle)))
 
     def __del__(self):
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 293fafa..3b97c05 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -502,16 +502,8 @@ class Block(object):
         ----------
         active : bool, default True
             Whether to turn hybrid on or off.
-        static_alloc : bool, default False
-            Statically allocate memory to improve speed. Memory usage may increase.
-        static_shape : bool, default False
-            Optimize for invariant input shapes between iterations. Must also
-            set static_alloc to True. Change of input shapes is still allowed
-            but slower.
-        forward_bulk_size : int, default 15
-            Segment size of bulk execution during forward pass.
-        backward_bulk_size : int, default 15
-            Segment size of bulk execution during backward pass.
+        **kwargs : string
+            Additional flags for hybridized operator.
         """
         for cld in self._children.values():
             cld.hybridize(active, **kwargs)
@@ -704,7 +696,7 @@ class HybridBlock(Block):
         self._out_format = None
         self._in_format = None
         self._active = False
-        self._flags = []
+        self._flags = {}
 
     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -731,43 +723,39 @@ class HybridBlock(Block):
         return self._cached_graph
 
     def _build_cache(self, *args):
-        data, out = self._get_graph(*args)
-        data_names = {data.name : i for i, data in enumerate(data)}
-        params = self.collect_params()
-        input_names = out.list_inputs()
+        inputs, out = self._get_graph(*args)
+        input_names = [i.name for i in inputs]
 
+        params = self.collect_params()
         param_names = set(params.keys())
-        expected_names = set(input_names)
+        expected_names = set(out.list_inputs())
         for name in expected_names:
-            assert name in param_names or name in data_names, \
+            assert name in param_names or name in input_names, \
                 "Unknown input to HybridBlock: %s"%name
 
-        used_data_names = [i for i in data_names if i in expected_names]
-        if len(used_data_names) != len(data_names):
-            unused = ', '.join(['%d-th'%i for name, i in data_names.items()
+        used_input_names = [i for i in input_names if i in expected_names]
+        if len(used_input_names) != len(input_names):
+            unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names)
                                 if name not in expected_names])
             warnings.warn("The %s input to HybridBlock is not used by any "
                           "computation. Is this intended?"%unused, stacklevel=4)
 
-        used_param_names = [i for i in param_names if i in expected_names]
+        used_param_names = set(i for i in param_names if i in expected_names)
         if len(used_param_names) != len(param_names):
-            unused = ', '.join(list(param_names - set(used_param_names)))
+            unused = ', '.join(list(param_names - used_param_names))
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
-        data_indices = []
-        param_indices = []
-        self._cached_op_args = []
-        for i, name in enumerate(input_names):
-            if name in data_names:
-                data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
-            else:
-                param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
-        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
-                self._flags
-        self._cached_op = ndarray.CachedOp(out, flags)
+        used_params = {k: params[k] for k in used_param_names}
+        try:
+            param_dict = {k: v.list_data() for k, v in used_params.items()}
+        except DeferredInitializationError:
+            self._deferred_infer_shape(*args)
+            for i in used_params.values():
+                i._finish_deferred_init()
+            param_dict = {k: v.list_data() for k, v in used_params.items()}
+
+        self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict)
 
     def _deferred_infer_shape(self, *args):
         try:
@@ -783,19 +771,7 @@ class HybridBlock(Block):
 
         args, fmt = _flatten(args, "input")
         assert fmt == self._in_format, "Invalid input format"
-        try:
-            cargs = [args[i] if is_arg else i.data()
-                     for is_arg, i in self._cached_op_args]
-        except DeferredInitializationError:
-            self._deferred_infer_shape(*args)
-            cargs = []
-            for is_arg, i in self._cached_op_args:
-                if is_arg:
-                    cargs.append(args[i])
-                else:
-                    i._finish_deferred_init()
-                    cargs.append(i.data())
-        out = self._cached_op(*cargs)
+        out = self._cached_op(*args)
         if isinstance(out, NDArray):
             out = [out]
         return _regroup(out, self._out_format)[0]
@@ -816,7 +792,7 @@ class HybridBlock(Block):
 
     def hybridize(self, active=True, **kwargs):
         self._active = active
-        self._flags = list(kwargs.items())
+        self._flags = kwargs.items()
         self._clear_cached_op()
         if active and self._forward_hooks or self._forward_pre_hooks:
             warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. '
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 34bd4b2..9aabe04 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -36,7 +36,6 @@
 #include "../common/utils.h"
 #include "../common/exec_utils.h"
 #include "../imperative/imperative_utils.h"
-#include "../imperative/cached_op.h"
 
 using namespace mxnet;
 
@@ -161,8 +160,12 @@ int MXCreateCachedOp(SymbolHandle handle,
   std::vector<std::string> input_names;
   input_names.reserve(inputs.size());
   for (const auto& i : inputs) input_names.push_back(i->attrs.name);
-  *out = new CachedOpPtr(new CachedOp(
-      *sym, std::vector<std::pair<std::string, std::string> >()));
+  *out = new std::shared_ptr<Imperative::CachedOp>(
+      new Imperative::CachedOp(
+        *sym,
+        std::vector<std::pair<std::string, std::string> >(),
+        input_names,
+        std::unordered_map<std::string, std::vector<NDArray> >()));
   API_END();
 }
 
@@ -170,6 +173,11 @@ int MXCreateCachedOpEx(SymbolHandle handle,
                        int num_flags,
                        const char** keys,
                        const char** vals,
+                       int num_args,
+                       const char** arg_names,
+                       int num_params,
+                       const char** param_names,
+                       NDArrayHandle* params,
                        CachedOpHandle *out) {
   nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
 
@@ -178,7 +186,17 @@ int MXCreateCachedOpEx(SymbolHandle handle,
   for (int i = 0; i < num_flags; ++i) {
     flags.push_back({keys[i], vals[i]});
   }
-  *out = new CachedOpPtr(new CachedOp(*sym, flags));
+  std::vector<std::string> args;
+  for (int i = 0; i < num_args; ++i) {
+    args.push_back(arg_names[i]);
+  }
+  std::unordered_map<std::string, std::vector<NDArray> > param_dict;
+  for (int i = 0; i < num_params; ++i) {
+    param_dict[param_names[i]].emplace_back(
+        *reinterpret_cast<NDArray*>(params[i]));
+  }
+  *out = new std::shared_ptr<Imperative::CachedOp>(
+      new Imperative::CachedOp(*sym, flags, args, param_dict));
   API_END();
 }
 
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index e70cc19..dc0436e 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -278,8 +278,6 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
 }
 
 void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) {
-  BulkFlush();
-
   ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
   OprBlock* opr_block = OprBlock::New();
   opr_block->opr = threaded_opr;
@@ -325,6 +323,7 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
         << device_count_;
   }
 #endif
+  BulkFlush();
   ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
   opr->temporary = true;
   const bool profiling = profiler_->IsProfiling(profiler::Profiler::kImperative);
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 72919d9..697e486 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -134,10 +134,6 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
     return state_.get_var();
   }
 
-  OpStatePtr state() const override {
-    return state_;
-  }
-
   explicit StatefulComputeExecutor(const OpStatePtr& state,
                                    const FStatefulCompute& fcompute,
                                    ExecType exec_type,
@@ -146,6 +142,7 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
         state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
 
  private:
+  friend Graph AttachOpExecs(Graph g);
   OpStatePtr state_;
   FStatefulCompute fcompute_;
   ExecType exec_type_;
@@ -173,16 +170,13 @@ class StatefulComputeExExecutor : public OpExecutor {
     return state_.get_var();
   }
 
-  OpStatePtr state() const override {
-    return state_;
-  }
-
   explicit StatefulComputeExExecutor(const OpStatePtr& state,
                                      const FStatefulComputeEx& fcompute,
                                      ExecType exec_type)
       : state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
 
  private:
+  friend Graph AttachOpExecs(Graph g);
   OpStatePtr state_;
   FStatefulComputeEx fcompute_;
   ExecType exec_type_;
@@ -247,15 +241,16 @@ class FComputeExExecutor : public OpExecutor {
   ExecType exec_type_;
 };
 
-void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
+// pass to attach operator executors
+Graph AttachOpExecs(Graph g) {
   using nnvm::DTypeVector;
   using nnvm::ShapeVector;
   using nnvm::FMutateInputs;
 
-  static auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
-  static auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
-  static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
-  static auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
+  auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+  auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
+  auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
+  auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
 
   const auto& vdtype = g.GetAttr<DTypeVector>("dtype");
   const auto& vshape = g.GetAttr<ShapeVector>("shape");
@@ -264,88 +259,82 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
 
   // get the graph
   const auto& idx = g.indexed_graph();
-  OpExecVector& ret = *p_ret;
+  std::vector<std::shared_ptr<OpExecutor> > ret(idx.num_nodes());
 
   // initialize the nodes
-  const auto& inode = idx[i];
-  if (inode.source->is_variable()) return;
-  const nnvm::Op *op = inode.source->op();
-  ExecType exec_type = ExecType::kSync;
-  std::vector<uint32_t> mutate_index;
-  if (fmutate_inputs.count(op)) {
-    mutate_index = fmutate_inputs[op](inode.source->attrs);
-  }
-  if (fexec_type.count(op)) {
-    exec_type = fexec_type[op](inode.source->attrs);
-  }
-  CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
-  if (fcreate_op_state.count(op)) {
-    std::vector<TShape> ishape;
-    std::vector<int> itype;
-    for (const auto& e : inode.inputs) {
-      ishape.emplace_back(vshape[idx.entry_id(e)]);
-      itype.emplace_back(vdtype[idx.entry_id(e)]);
-    }
-
-    OpStatePtr state = fcreate_op_state[op](
-        inode.source->attrs, vctx[i], ishape, itype);
-    FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
-        op, "FStatefulComputeEx", vctx[i]);
-    // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
-    if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
-      ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
-    } else {
-      FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
-          op, "FStatefulCompute", vctx[i]);
-      CHECK(fcompute != nullptr)
-          << "One of FStatefulCompute and FStatefulComputeEx must be registered "
-          << "for stateful operator " << op->name;
-      ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
-                                                         exec_type, mutate_index);
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    const auto& inode = idx[i];
+    if (inode.source->is_variable()) continue;
+    const nnvm::Op *op = inode.source->op();
+    ExecType exec_type = ExecType::kSync;
+    std::vector<uint32_t> mutate_index;
+    if (fmutate_inputs.count(op)) {
+      mutate_index = fmutate_inputs[op](inode.source->attrs);
     }
-  } else if (is_layer_backward.get(op, false)) {
-    CHECK_GE(inode.control_deps.size(), 1);
-    uint32_t fwd_id = inode.control_deps[0];
-    CHECK(vctx[fwd_id] == vctx[i]);
-    CHECK(ret[fwd_id] != nullptr);
-    FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
-        op, "FStatefulComputeEx", vctx[i]);
-    // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
-    if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
-      ret[i] = std::make_shared<StatefulComputeExExecutor>(
-          ret[fwd_id].get()->state(), fcompute_ex, exec_type);
-    } else {
-      FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
-          op, "FStatefulCompute", vctx[i]);
-      CHECK(fcompute != nullptr)
-          << "One of FStatefulCompute and FStatefulComputeEx must be registered "
-          << "for stateful operator " << op->name;
-      ret[i] = std::make_shared<StatefulComputeExecutor>(
-          ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index);
+    if (fexec_type.count(op)) {
+      exec_type = fexec_type[op](inode.source->attrs);
     }
-  } else {
-    FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
-    FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
-    if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
-      ret[i] = std::make_shared<FComputeExExecutor>(
-          inode.source->attrs, fcomp_ex, exec_type);
-    } else if (fcompute != nullptr) {
-      ret[i] = std::make_shared<FComputeExecutor>(
-          inode.source->attrs, fcompute, exec_type, mutate_index);
+    CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
+    if (fcreate_op_state.count(op)) {
+      std::vector<TShape> ishape;
+      std::vector<int> itype;
+      for (const auto& e : inode.inputs) {
+        ishape.emplace_back(vshape[idx.entry_id(e)]);
+        itype.emplace_back(vdtype[idx.entry_id(e)]);
+      }
+
+      OpStatePtr state = fcreate_op_state[op](
+          inode.source->attrs, vctx[i], ishape, itype);
+      FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
+          op, "FStatefulComputeEx", vctx[i]);
+      // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
+      if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+        ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
+      } else {
+        FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
+            op, "FStatefulCompute", vctx[i]);
+        CHECK(fcompute != nullptr)
+            << "One of FStatefulCompute and FStatefulComputeEx must be registered "
+            << "for stateful operator " << op->name;
+        ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
+                                                           exec_type, mutate_index);
+      }
+    } else if (is_layer_backward.get(op, false)) {
+      CHECK_GE(inode.control_deps.size(), 1);
+      uint32_t fwd_id = inode.control_deps[0];
+      CHECK(vctx[fwd_id] == vctx[i]);
+      CHECK(ret[fwd_id] != nullptr);
+      FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
+          op, "FStatefulComputeEx", vctx[i]);
+      // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
+      if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+        ret[i] = std::make_shared<StatefulComputeExExecutor>(
+            dynamic_cast<StatefulComputeExExecutor*>(ret[fwd_id].get())->state_,
+            fcompute_ex, exec_type);
+      } else {
+        FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
+            op, "FStatefulCompute", vctx[i]);
+        CHECK(fcompute != nullptr)
+            << "One of FStatefulCompute and FStatefulComputeEx must be registered "
+            << "for stateful operator " << op->name;
+        ret[i] = std::make_shared<StatefulComputeExecutor>(
+            dynamic_cast<StatefulComputeExecutor*>(ret[fwd_id].get())->state_,
+            fcompute, exec_type, mutate_index);
+      }
     } else {
-      LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
+      FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
+      FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
+      if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+        ret[i] = std::make_shared<FComputeExExecutor>(
+            inode.source->attrs, fcomp_ex, exec_type);
+      } else if (fcompute != nullptr) {
+        ret[i] = std::make_shared<FComputeExecutor>(
+            inode.source->attrs, fcompute, exec_type, mutate_index);
+      } else {
+        LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
+      }
     }
   }
-}
-
-
-// pass to attach operator executors
-Graph AttachOpExecs(Graph g) {
-  const auto& idx = g.indexed_graph();
-  OpExecVector ret(idx.num_nodes());
-  for (size_t i = 0; i < idx.num_nodes(); ++i) {
-    CreateOpExecs(g, &ret, i);
-  }
   g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
   return g;
 }
diff --git a/src/executor/attach_op_resource_pass.cc b/src/executor/attach_op_resource_pass.cc
index 56122cd..6818662 100644
--- a/src/executor/attach_op_resource_pass.cc
+++ b/src/executor/attach_op_resource_pass.cc
@@ -30,15 +30,12 @@
 namespace mxnet {
 namespace exec {
 
-void AttachOpResources(
-    const Graph& g,
-    const OpExecVector& op_execs,
-    size_t start_nid,
-    size_t end_nid) {
+Graph AttachOpResources(Graph g) {
   static auto& fresource =
       nnvm::Op::GetAttr<FResourceRequest>("FResourceRequest");
   static auto& fresource_ex =
       nnvm::Op::GetAttr<FResourceRequestEx>("FResourceRequestEx");
+  auto& op_execs = nnvm::get<OpExecVector>(*g.attrs.at("op_execs"));
   const auto& vctx = g.GetAttr<ContextVector>("context");
   const auto& vdispatch = g.GetAttr<DispatchModeVector>("dispatch_mode");
   const auto& dev_masks = g.GetAttr<DevMaskVector>("dev_mask");
@@ -46,7 +43,7 @@ void AttachOpResources(
   // Use global resource pool for each executor for now.
   std::map<Context, Resource> cached_temp;
   // Resource allocation
-  for (uint32_t nid = start_nid; nid < end_nid; ++nid) {
+  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
     const auto& inode = idx[nid];
     if (inode.source->is_variable()) continue;
     const Context &ctx = vctx[nid];
@@ -87,12 +84,7 @@ void AttachOpResources(
       requested.push_back(ResourceManager::Get()->Request(ctx, ResourceRequest::kTempSpace));
     }
   }
+  return g;
 }
-
-void AttachOpResources(const Graph& g) {
-  const auto& op_execs = g.GetAttr<OpExecVector>("op_execs");
-  AttachOpResources(g, op_execs, 0, g.indexed_graph().num_nodes());
-}
-
 }  // namespace exec
 }  // namespace mxnet
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 26a2491..99b1b16 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -82,10 +82,6 @@ class OpExecutor {
   virtual engine::VarHandle var() const {
     return nullptr;
   }
-  /*! \return return operator state */
-  virtual OpStatePtr state() const {
-    return OpStatePtr();
-  }
 };
 
 /*!
@@ -107,14 +103,6 @@ using ContextVector = std::vector<Context>;
 using DevMaskVector = std::vector<int>;
 
 /*!
- * \brief create OpExecutor for a node in graph
- *
- * \param g input graph
- * \param p_ret OpExecVector for input and output
- * \param i the id of the node
- */
-void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
-/*!
  * \brief Attach OpExecutor to the graph attributes.
  *
  * \param g input graph
@@ -127,20 +115,12 @@ Graph AttachOpExecs(Graph g);
  * \brief Attach Resource to the OpExecVector of the graph.
  *
  * \param g input graph need to contain op_exec attribute.
- */
-void AttachOpResources(const Graph& g);
-/*!
- * \brief Attach Resource to the OpExecVector
  *
- * \param g input graph
- * \param op_execs OpExecutor vector
- * \param start_nid starting node id
- * \param end_nid end node id
+ * \return graph with new attribute "op_exec" of type OpExecVector
+ *  The fields on the OpExecVector have not yet been set up.
  */
-void AttachOpResources(const Graph& g,
-                       const OpExecVector& op_execs,
-                       size_t start_nid,
-                       size_t end_nid);
+Graph AttachOpResources(Graph g);
+
 /*!
  * \brief Discover chance of inplace addto operators.
  *  i.e. z = plus(z, source_op), and encourage it to become z += source_op.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 831b5f9..e28867d 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -912,7 +912,7 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
   }
 
   g = AttachOpExecs(g);
-  AttachOpResources(g);
+  g = AttachOpResources(g);
   graph_ = std::move(g);
 
   if (shared_exec != nullptr) {
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index b17fae4..140b5a5 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -19,78 +19,16 @@
 #include <unordered_set>
 #include <iostream>
 #include "./imperative_utils.h"
-#include "./cached_op.h"
-#include "../executor/exec_pass.h"
-#include "../profiler/profiler.h"
-
 
 namespace mxnet {
 
 DMLC_REGISTER_PARAMETER(CachedOpConfig);
 
-struct CachedOp::GraphInfo {
-  nnvm::Graph fwd_graph;
-  nnvm::Graph full_graph;
-  std::vector<OpReqType> bwd_output_reqs;
-  std::vector<uint32_t> bwd_input_eid;
-};
-
-struct CachedOp::DynamicRuntime {
-  GraphInfo info;
-  std::vector<NDArray> buff;
-  std::vector<OpStatePtr> op_states;
-};
-
-struct CachedOp::CachedOpState {
-  CachedOpState(const Context& context_,
-                const nnvm::Graph& fwd_graph_,
-                const nnvm::Graph& full_graph_) {
-    context = context_;
-    info.fwd_graph = fwd_graph_;
-    info.full_graph = full_graph_;
-
-    size_t max_nodes = info.full_graph.indexed_graph().num_nodes();
-    size_t max_entries = info.full_graph.indexed_graph().num_node_entries();
-    info.fwd_graph.attrs["context"] = std::make_shared<dmlc::any>(
-        std::vector<Context>(info.fwd_graph.indexed_graph().num_nodes(), context));
-    info.full_graph.attrs["context"] = std::make_shared<dmlc::any>(
-        std::vector<Context>(max_nodes, context));
-
-    buff.resize(max_entries);
-    arrays.resize(max_entries);
-    array_reqs.resize(max_entries);
-    dynamic_entries.resize(max_entries, false);
-    op_states.resize(max_nodes);
-    execs.resize(max_nodes);
-    opr_segs.resize(max_nodes);
-  }
-
-  std::mutex mutex;
-  Context context;
-  GraphInfo info;
-
-  bool recording = false;
-  bool fwd_alloc = false;
-  bool bwd_alloc = false;
-  bool fwd_exec_init = false;
-  bool bwd_exec_init = false;
-
-  std::vector<NDArray> buff;
-  std::vector<NDArray*> arrays;
-  std::vector<OpReqType> array_reqs;
-
-  std::vector<OpStatePtr> op_states;
-  std::vector<std::shared_ptr<exec::OpExecutor> > execs;
-  std::vector<imperative::EngineOprSeg> opr_segs;
-
-  std::vector<bool> dynamic_entries;
-  std::multimap<size_t, NDArray> fwd_reuse_pool;
-  std::multimap<size_t, NDArray> bwd_reuse_pool;
-};
-
-CachedOp::CachedOp(
+Imperative::CachedOp::CachedOp(
     const nnvm::Symbol& sym,
-    const std::vector<std::pair<std::string, std::string> >& flags) {
+    const std::vector<std::pair<std::string, std::string> >& flags,
+    const std::vector<std::string> arg_names,
+    const std::unordered_map<std::string, std::vector<NDArray> >& params) {
   using namespace nnvm;
   using namespace imperative;
   static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
@@ -130,22 +68,34 @@ CachedOp::CachedOp(
     fwd_graph_.attrs["forward_ref_count"] =
         std::make_shared<dmlc::any>(std::move(ref_count));
 
-    inlining_ = !config_.static_alloc &&
-        (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit;
+    inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit;
   }
 
   // Set params
   {
     const auto& idx = fwd_graph_.indexed_graph();
-    if (config_.data_indices.ndim() || config_.param_indices.ndim()) {
-      CHECK_EQ(config_.data_indices.ndim() + config_.param_indices.ndim(),
-               idx.input_nodes().size());
-    } else {
-      std::vector<uint32_t> tmp;
-      for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
-        tmp.push_back(i);
+    std::unordered_map<std::string, size_t> arg_name_to_id;
+    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+      const auto& name = idx[idx.input_nodes()[i]].source->attrs.name;
+      auto iter = params.find(name);
+      if (iter == params.end()) {
+        arg_name_to_id[name] = i;
+        continue;
+      }
+      fwd_params_idx_.push_back(i);
+      for (const auto& param : iter->second) {
+        params_[param.ctx()].emplace_back(param);
       }
-      config_.data_indices.assign(tmp.begin(), tmp.end());
+    }
+
+    CHECK_EQ(arg_name_to_id.size(), arg_names.size())
+        << "CachedOp expects " << arg_name_to_id.size()
+        << " inputs, given " << arg_names.size();
+
+    for (const auto& name : arg_names) {
+      auto iter = arg_name_to_id.find(name);
+      CHECK(iter != arg_name_to_id.end()) << "Unexpected input name " << name;
+      fwd_args_idx_.push_back(iter->second);
     }
   }
 
@@ -157,14 +107,9 @@ CachedOp::CachedOp(
     }
 
     std::vector<NodeEntry> xs;
-    const auto& idx = fwd_graph_.indexed_graph();
-    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
-      auto nid = idx.input_nodes()[i];
-      if (idx.mutable_input_nodes().count(nid)) continue;
-      fwd_input_to_grad_output_[i] = xs.size();
-      xs.emplace_back(NodeEntry{idx[nid].weak_ref.lock(), 0, 0});
-    }
-
+    std::vector<NodePtr> args = sym.ListInputs(Symbol::kReadOnlyArgs);
+    xs.reserve(args.size());
+    for (const auto& i : args) xs.emplace_back(NodeEntry{i, 0, 0});
     CHECK_GT(xs.size(), 0)
         << "There are no inputs in computation graph that require gradients.";
 
@@ -180,7 +125,7 @@ CachedOp::CachedOp(
     size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();
 
     full_graph_.outputs = fwd_graph_.outputs;
-    bwd_output_reqs_ = std::vector<OpReqType>(grad_graph_.outputs.size(), kWriteTo);
+    curr_grad_req_ = std::vector<bool>(grad_graph_.outputs.size(), true);
     for (const auto& i : grad_graph_.outputs) full_graph_.outputs.emplace_back(i);
     const auto& idx = full_graph_.indexed_graph();
 
@@ -224,10 +169,7 @@ CachedOp::CachedOp(
   }
 }
 
-CachedOp::~CachedOp() {
-}
-
-std::vector<nnvm::NodeEntry> CachedOp::Gradient(
+std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
     const nnvm::NodePtr& node,
     const std::vector<nnvm::NodeEntry>& ograds) {
   using namespace nnvm;
@@ -264,15 +206,13 @@ std::vector<nnvm::NodeEntry> CachedOp::Gradient(
   return ret;
 }
 
-
-bool CachedOp::SetForwardGraph(
-    GraphInfo* info,
-    const bool recording,
-    const std::vector<NDArray*>& inputs) {
+nnvm::Graph Imperative::CachedOp::GetForwardGraph(
+    const bool recording, const std::vector<NDArray*>& inputs) {
   using namespace nnvm;
   using namespace imperative;
+  std::lock_guard<std::mutex> lock(mutex_);
   CHECK_EQ(inputs.size(), num_inputs());
-  nnvm::Graph& g = info->fwd_graph;
+  nnvm::Graph& g = fwd_graph_;
 
   ShapeVector shape_inputs;
   DTypeVector dtype_inputs;
@@ -297,22 +237,18 @@ bool CachedOp::SetForwardGraph(
     g.attrs.erase("forward_mem_plan");
     g.attrs.erase("full_mem_plan");
   } else if (g.attrs.count(recording ? "full_mem_plan" : "forward_mem_plan")) {
-    return true;
+    return g;
   }
 
   const auto& idx = g.indexed_graph();
 
   StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
+  for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
   const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
   CHECK_EQ(stypes.size(), storage.size());
   for (size_t i = 0; i < stypes.size(); i++) {
-    if (stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID;
-  }
-  for (const auto i : idx.input_nodes()) {
-    storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
-  }
-  for (size_t i = 0; i < idx.outputs().size(); ++i) {
-    storage[idx.entry_id(idx.outputs()[i])] = exec::kExternalStorageID;
+    if (stypes[i] != kDefaultStorage)
+      storage[i] = exec::kDynamicStorageID;
   }
 
   auto mem_plan = PlanMemory(
@@ -321,50 +257,51 @@ bool CachedOp::SetForwardGraph(
   g.attrs[recording ? "full_mem_plan" : "forward_mem_plan"] =
       std::make_shared<dmlc::any>(std::move(mem_plan));
 
-  return false;
+  return g;
 }
 
-bool CachedOp::SetBackwardGraph(
-    GraphInfo* info,
+nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
+    const OpStatePtr& op_state,
     const std::vector<OpReqType>& reqs,
-    const std::vector<NDArray*>& inputs,
-    bool detect_inplace_addto) {
+    const std::vector<NDArray*>& inputs) {
   using namespace nnvm;
   using namespace imperative;
   std::lock_guard<std::mutex> lock(mutex_);
-  Context default_ctx = inputs[0]->ctx();
-  nnvm::Graph& g = info->full_graph;
-
-  if (info->bwd_output_reqs != reqs) {
-    info->bwd_output_reqs = reqs;
-    info->bwd_input_eid.clear();
+  nnvm::Graph& g = full_graph_;
+  auto& state = op_state.get_state<CachedOpState>();
+  bool req_match = true;
+  for (size_t i = 0; i < reqs.size(); ++i) {
+    if (curr_grad_req_[i] != (reqs[i] != kNullOp)) {
+      curr_grad_req_[i] = reqs[i] != kNullOp;
+      req_match = false;
+    }
+  }
+  if (!req_match) {
     g = nnvm::Graph();
     g.outputs = fwd_graph_.outputs;
     for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
-      if (info->bwd_output_reqs[i] == kNullOp) continue;
-      g.outputs.emplace_back(grad_graph_.outputs[i]);
+      if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]);
     }
-    g.attrs["context"] = std::make_shared<dmlc::any>(
-        std::vector<Context>(g.indexed_graph().num_nodes(), default_ctx));
+    bwd_input_eid_.clear();
   }
 
   const auto& idx = g.indexed_graph();
 
-  if (info->bwd_input_eid.size() != inputs.size()) {
-    info->bwd_input_eid.clear();
+  if (bwd_input_eid_.size() != inputs.size()) {
+    bwd_input_eid_.clear();
     for (const auto& i : bwd_ograd_dep_) {
       auto eid = idx.entry_id(ograd_entries_[i]);
-      info->bwd_input_eid.push_back(eid);
+      bwd_input_eid_.push_back(eid);
     }
     for (const auto& i : bwd_in_dep_) {
       auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      info->bwd_input_eid.push_back(eid);
+      bwd_input_eid_.push_back(eid);
     }
     for (const auto& i : bwd_out_dep_) {
       auto eid = idx.entry_id(idx.outputs()[i]);
-      info->bwd_input_eid.push_back(eid);
+      bwd_input_eid_.push_back(eid);
     }
-    CHECK_EQ(inputs.size(), info->bwd_input_eid.size());
+    CHECK_EQ(inputs.size(), bwd_input_eid_.size());
   }
 
   size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
@@ -375,22 +312,25 @@ bool CachedOp::SetBackwardGraph(
     for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
       for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)];
     }
-    for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[info->bwd_input_eid[i]];
+    for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[bwd_input_eid_[i]];
     for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)];
     g.attrs["backward_ref_count"] = std::make_shared<dmlc::any>(std::move(ref_count));
   }
 
-  auto shapes = info->fwd_graph.GetAttr<ShapeVector>("shape");
-  shapes.resize(idx.num_node_entries(), TShape());
-  auto dtypes = info->fwd_graph.GetAttr<DTypeVector>("dtype");
-  dtypes.resize(idx.num_node_entries(), -1);
-  auto stypes = info->fwd_graph.GetAttr<StorageTypeVector>("storage_type");
-  stypes.resize(idx.num_node_entries(), -1);
+  ShapeVector shapes(idx.num_node_entries(), TShape());
+  DTypeVector dtypes(idx.num_node_entries(), -1);
+  StorageTypeVector stypes(idx.num_node_entries(), -1);
+
+  for (size_t i = 0; i < num_forward_entries; ++i) {
+    shapes[i] = state.buff[i].shape();
+    dtypes[i] = state.buff[i].dtype();
+    stypes[i] = state.buff[i].storage_type();
+  }
 
   for (size_t i = 0; i < inputs.size(); ++i) {
-    shapes[info->bwd_input_eid[i]] = inputs[i]->shape();
-    dtypes[info->bwd_input_eid[i]] = inputs[i]->dtype();
-    stypes[info->bwd_input_eid[i]] = inputs[i]->storage_type();
+    shapes[bwd_input_eid_[i]] = inputs[i]->shape();
+    dtypes[bwd_input_eid_[i]] = inputs[i]->dtype();
+    stypes[bwd_input_eid_[i]] = inputs[i]->storage_type();
   }
 
   std::pair<uint32_t, uint32_t> node_range, entry_range;
@@ -402,353 +342,79 @@ bool CachedOp::SetBackwardGraph(
                               node_range, entry_range);
   match &= CheckAndInferType(&g, std::move(dtypes), false,
                              node_range, entry_range);
-  exec::DevMaskVector dev_mask(idx.num_nodes(), default_ctx.dev_mask());
+  exec::DevMaskVector dev_mask(idx.num_nodes(), inputs[0]->ctx().dev_mask());
   match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(stypes),
                                     false, node_range, entry_range);
 
   if (!match) {
     g.attrs.erase("backward_mem_plan");
   } else if (g.attrs.count("backward_mem_plan")) {
-    return true;
+    return g;
   }
 
   StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
-  const auto& bwd_stypes = g.GetAttr<StorageTypeVector>("storage_type");
-  for (size_t i = 0; i < bwd_stypes.size(); i++) {
-    if (bwd_stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID;
-  }
   for (size_t i = 0; i < num_forward_entries; ++i) storage[i] = exec::kExternalStorageID;
   for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
   for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID;
+  for (size_t i = 0; i < stypes.size(); i++) {
+    if (stypes[i] != kDefaultStorage)
+      storage[i] = exec::kDynamicStorageID;
+  }
 
   auto mem_plan = PlanMemory(
       &g, std::move(storage), g.GetAttr<std::vector<uint32_t> >("backward_ref_count"),
-      {num_forward_nodes, idx.num_nodes()},
-      {num_forward_entries, idx.num_node_entries()},
-      detect_inplace_addto);
+      {num_forward_nodes, idx.num_nodes()}, {num_forward_entries, idx.num_node_entries()});
   g.attrs["backward_mem_plan"] = std::make_shared<dmlc::any>(std::move(mem_plan));
 
-  return false;
-}
-
-OpStatePtr CachedOp::GetCachedOpState(
-    const Context& ctx) {
-  std::lock_guard<std::mutex> lock(mutex_);
-  for (const auto& i : cached_op_states_[ctx]) {
-    // only create one state per device when not using static memory
-    if (!config_.static_alloc || i.unique()) {
-      return i;
-    }
-  }
-  auto state_ptr = OpStatePtr::Create<CachedOpState>(ctx, fwd_graph_, full_graph_);
-
-  cached_op_states_[ctx].push_back(state_ptr);
-  return state_ptr;
-}
-
-void CachedOp::StaticAllocMemory(
-    const OpStatePtr& state_ptr,
-    bool recording,
-    bool keep_fwd) {
-  using namespace nnvm;
-  using namespace imperative;
-
-  auto& state = state_ptr.get_state<CachedOpState>();
-  const auto& default_ctx = state.context;
-  nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
-  const auto& idx = g.indexed_graph();
-  const auto& vstorage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
-  const auto& mem_plan = g.GetAttr<MemoryPlanVector>(
-      keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : "forward_mem_plan"));
-  std::vector<int> addto_entry;
-  if (g.attrs.count("addto_entry")) {
-    addto_entry = g.GetAttr<std::vector<int> >("addto_entry");
-  }
-  size_t start_eid =
-      keep_fwd ? state.info.fwd_graph.indexed_graph().num_node_entries() : 0;
-  size_t end_eid = idx.num_node_entries();
-
-  if (!keep_fwd) state.fwd_alloc = false;
-  state.bwd_alloc = false;
-  for (size_t i = start_eid; i < state.buff.size(); ++i) {
-    state.buff[i] = NDArray();
-    state.arrays[i] = &state.buff[i];
-    state.array_reqs[i] = kNullOp;
-    state.dynamic_entries[i] = false;
-  }
-
-  for (auto i : idx.input_nodes()) {
-    auto eid = idx.entry_id(i, 0);
-    if (eid >= start_eid) state.dynamic_entries[eid] = true;
-  }
-  for (auto i : idx.outputs()) {
-    auto eid = idx.entry_id(i);
-    if (eid >= start_eid) state.dynamic_entries[eid] = true;
-  }
-
-  for (size_t i = start_eid; i < end_eid; ++i) {
-    if (addto_entry.size() && addto_entry[i]) {
-      state.array_reqs[i] = kAddTo;
-    } else if (vstorage_inplace[i] >= 0) {
-      state.array_reqs[i] = kWriteInplace;
-    } else if (vstorage_inplace[i] == -2) {
-      // -2 indicate that the entry is never referenced.
-      state.array_reqs[i] = kNullOp;
-    } else {
-      state.array_reqs[i] = kWriteTo;
-    }
-  }
-
-  auto& reuse_pool = keep_fwd ? state.bwd_reuse_pool : state.fwd_reuse_pool;
-  reuse_pool = imperative::AllocateMemory(
-      g, idx, default_ctx, start_eid, end_eid, mem_plan,
-      state.arrays, &state.array_reqs, std::move(reuse_pool));
-
-  state.recording = recording;
-  if (keep_fwd) {
-    state.bwd_alloc = true;
-  } else {
-    state.fwd_alloc = true;
-  }
+  return g;
 }
 
-void CachedOp::StaticInitExec(
-    const OpStatePtr& state_ptr,
-    bool recording,
-    bool keep_fwd) {
+void Imperative::CachedOp::Forward(
+    const std::shared_ptr<CachedOp>& op_ptr,
+    const std::vector<NDArray*>& args,
+    const std::vector<NDArray*>& outputs) {
   using namespace nnvm;
   using namespace imperative;
+  static const auto cached_op = nnvm::Op::Get("_CachedOp");
 
-  auto& state = state_ptr.get_state<CachedOpState>();
-  const auto& default_ctx = state.context;
-  nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
-  const auto& idx = g.indexed_graph();
-  std::vector<int> skip_plus_node;
-  if (g.attrs.count("skip_plus_node")) {
-    skip_plus_node = g.GetAttr<std::vector<int> >("skip_plus_node");
-  }
-  size_t start_nid =
-      keep_fwd ? state.info.fwd_graph.indexed_graph().num_nodes() : 0;
-  size_t end_nid = idx.num_nodes();
-
-  if (!keep_fwd) state.fwd_exec_init = false;
-  state.bwd_exec_init = false;
-
-  for (size_t i = start_nid; i < state.execs.size(); ++i) {
-    state.execs[i].reset();
-    state.opr_segs[i] = EngineOprSeg();
-  }
-
-  if (!config_.static_shape) {
-    for (size_t i = start_nid; i < end_nid; ++i) {
-      state.opr_segs[i].next_nid = i + 1;
-      state.opr_segs[i].skip = skip_plus_node.size() && skip_plus_node[i];
-    }
-  } else {
-    for (size_t i = start_nid; i < end_nid; ++i) {
-      exec::CreateOpExecs(g, &state.execs, i);
-    }
-    exec::AttachOpResources(g, state.execs, start_nid, end_nid);
-
-    for (size_t i = start_nid; i < end_nid; ++i) {
-      bool skip = idx[i].source->is_variable();
-      for (size_t j = 0; !skip && j < idx[i].inputs.size(); ++j) {
-        skip = state.dynamic_entries[idx.entry_id(idx[i].inputs[j])];
-      }
-      for (size_t j = 0; !skip && j < idx[i].source->num_outputs(); ++j) {
-        skip = state.dynamic_entries[idx.entry_id(i, j)];
-      }
-      if (skip) continue;
-      SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs);
-    }
+  CHECK_EQ(args.size(), fwd_args_idx_.size())
+      << "CachedOp requires " << fwd_args_idx_.size()
+      << " inputs but got " << args.size();
 
-    size_t bulk_size = idx.num_nodes();
-    std::unordered_set<uint32_t> excludes;
-    if (recording || keep_fwd) {
-      bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size;
-      for (const auto& i : idx.outputs()) excludes.insert(idx.entry_id(i));
-      for (const auto& i : idx.input_nodes()) excludes.insert(idx.entry_id(i, 0));
-    }
+  Context default_ctx = args[0]->ctx();
 
-    CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, excludes,
-                      state.execs, skip_plus_node, &state.opr_segs);
-  }
 
-  if (keep_fwd) {
-    state.bwd_exec_init = true;
-  } else {
-    state.fwd_exec_init = true;
+  std::vector<NDArray*> inputs(num_inputs());
+  for (index_t i = 0; i < fwd_args_idx_.size(); ++i) {
+    inputs[fwd_args_idx_[i]] = args[i];
   }
-}
-
-void CachedOp::StaticRunOps(
-    const Context& default_ctx,
-    const nnvm::Graph& g,
-    const OpStatePtr& state_ptr,
-    size_t start_nid,
-    size_t end_nid) {
-  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
-  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
-
-  bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
-  bool is_training = Imperative::Get()->is_training();
-  auto& state = state_ptr.get_state<CachedOpState>();
-  const auto& idx = g.indexed_graph();
-  const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
-  const auto& op_execs = state.execs;
-
-  std::vector<NDArray*> ndinputs, ndoutputs;
-  nnvm::ShapeVector arg_shapes;
-  nnvm::DTypeVector arg_dtypes;
-  std::vector<OpReqType> req;
+  if (fwd_params_idx_.size()) {
+    CHECK(params_.find(default_ctx) != params_.end())
+        << "CachedOp is not initialized on context " << default_ctx;
 
-  for (size_t i = start_nid; config_.static_shape && i < end_nid; ++i) {
-    if (op_execs[i]) op_execs[i]->op_ctx.is_train = is_training;
-  }
-
-  for (size_t i = start_nid; i < end_nid; i = state.opr_segs[i].next_nid) {
-    const auto& opr_seg = state.opr_segs[i];
-    if (opr_seg.skip) continue;
-    if (opr_seg.opr != nullptr) {
-      Engine::Get()->Push(opr_seg.opr.get(), default_ctx, 0, profiling);
-    } else {
-      const nnvm::IndexedGraph::Node& node = idx[i];
-      if (node.source->is_variable()) continue;
-      auto num_outputs = node.source->num_outputs();
-      ndinputs.clear();
-      ndinputs.reserve(node.inputs.size());
-      for (const auto& j : node.inputs) {
-        ndinputs.emplace_back(state.arrays[idx.entry_id(j)]);
-        CHECK(!ndinputs.back()->is_none());
-      }
-      ndoutputs.clear();
-      ndoutputs.reserve(num_outputs);
-      req.clear();
-      req.reserve(num_outputs);
-      for (size_t j = 0; j < num_outputs; ++j) {
-        size_t eid = idx.entry_id(i, j);
-        ndoutputs.emplace_back(state.arrays[eid]);
-        req.push_back(state.array_reqs[eid]);
-        CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
-      }
-      const DispatchMode dispatch_mode = dispatch_modes[i];
-      if (createop.count(node.source->op())) {
-        arg_shapes.clear();
-        arg_dtypes.clear();
-        arg_shapes.reserve(ndinputs.size());
-        arg_dtypes.reserve(ndinputs.size());
-        for (size_t i = 0; i < ndinputs.size(); ++i) {
-          arg_shapes.emplace_back(ndinputs[i]->shape());
-          arg_dtypes.emplace_back(ndinputs[i]->dtype());
-        }
-        state.op_states[i] = createop[node.source->op()](
-            node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
-        Imperative::Get()->InvokeOp(
-            default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
-            dispatch_mode, state.op_states[i]);
-      } else if (is_layer_backward.get(node.source->op(), false)) {
-        nnvm::Node* fwd_node = node.source->control_deps[0].get();
-        auto fwd_node_id = idx.node_id(fwd_node);
-        Imperative::Get()->InvokeOp(
-            default_ctx, node.source->attrs, ndinputs, ndoutputs,
-            req, dispatch_mode, state.op_states[fwd_node_id]);
-      } else {
-        Imperative::Get()->InvokeOp(
-            default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
-            dispatch_mode);
-      }
+    for (size_t i = 0; i < fwd_params_idx_.size(); ++i) {
+      inputs[fwd_params_idx_[i]] = &params_[default_ctx][i];
     }
   }
-}
-
-OpStatePtr CachedOp::StaticForward(
-    const Context& default_ctx,
-    const std::vector<NDArray*>& inputs,
-    const std::vector<NDArray*>& outputs) {
-  using namespace nnvm;
-  using namespace imperative;
 
+  // Initialize
   bool recording = Imperative::Get()->is_recording();
-  auto state_ptr = GetCachedOpState(default_ctx);
-  auto& state = state_ptr.get_state<CachedOpState>();
-  std::lock_guard<std::mutex> lock(state.mutex);
-
-  bool match = SetForwardGraph(&state.info, recording, inputs);
-  match = match && state.recording == recording;
-
-  nnvm::Graph& g = state.info.fwd_graph;
+  nnvm::Graph g = GetForwardGraph(recording, inputs);
   const auto& idx = g.indexed_graph();
-  if (!state.fwd_alloc || !match)  {
-    StaticAllocMemory(state_ptr, recording, false);
-  }
-
-  if (config_.static_shape) {
-    for (auto i : config_.param_indices) {
-      auto nid = idx.input_nodes()[i];
-      if (!state.arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
-        match = false;
-        auto ptr = &state.buff[idx.entry_id(nid, 0)];
-        CHECK_EQ(state.arrays[idx.entry_id(nid, 0)], ptr);
-        *state.arrays[idx.entry_id(nid, 0)] = *inputs[i];
-        state.dynamic_entries[idx.entry_id(nid, 0)] = false;
-      }
-    }
-    for (auto i : config_.data_indices) {
-      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      state.arrays[eid] = inputs[i];
-    }
-  } else {
-    for (size_t i = 0; i < num_inputs(); ++i) {
-      auto nid = idx.input_nodes()[i];
-      state.arrays[idx.entry_id(nid, 0)] = inputs[i];
-    }
-  }
-
-  if (!state.fwd_exec_init || !match) {
-    StaticInitExec(state_ptr, recording, false);
-  }
-
-  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
-  const auto& shapes = g.GetAttr<ShapeVector>("shape");
-  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  size_t num_inputs = idx.input_nodes().size();
 
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    auto eid = idx.entry_id(idx.outputs()[i]);
-    state.arrays[eid] = outputs[i];
-    if (!outputs[i]->is_none()) continue;
-    *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
-                          shapes[eid], default_ctx, true, dtypes[eid]);
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    CHECK_EQ(inputs[i]->ctx(), default_ctx)
+        << "CachedOp requires all inputs to live on the same context. But "
+        << idx[idx.input_nodes()[0]].source->attrs.name << " is on " << default_ctx
+        << " while " << idx[idx.input_nodes()[i]].source->attrs.name << " is on "
+        << inputs[i]->ctx();
   }
 
-  StaticRunOps(default_ctx, g, state_ptr, 0, idx.num_nodes());
-
-  return recording ? state_ptr : OpStatePtr();
-}
-
-
-OpStatePtr CachedOp::DynamicForward(
-    const Context& default_ctx,
-    const std::vector<NDArray*>& inputs,
-    const std::vector<NDArray*>& outputs) {
-  using namespace nnvm;
-  using namespace imperative;
-
-  // Initialize
-  bool recording = Imperative::Get()->is_recording();
-  auto op_state = OpStatePtr::Create<DynamicRuntime>();
-  auto& runtime = op_state.get_state<DynamicRuntime>();
-  {
-    auto state_ptr = GetCachedOpState(default_ctx);
-    auto& state = state_ptr.get_state<CachedOpState>();
-    std::lock_guard<std::mutex> lock(state.mutex);
-    SetForwardGraph(&state.info, recording, inputs);
-    runtime.info.fwd_graph = state.info.fwd_graph;
-  }
-  nnvm::Graph& g = runtime.info.fwd_graph;
-  const auto& idx = g.indexed_graph();
-  size_t num_inputs = idx.input_nodes().size();
-  auto& buff = runtime.buff;
-  auto& states = runtime.op_states;
+  auto op_state_ptr = OpStatePtr::Create<CachedOpState>();
+  auto& cached_op_state = op_state_ptr.get_state<CachedOpState>();
+  auto& buff = cached_op_state.buff;
+  auto& states = cached_op_state.states;
 
   // Allocate entries
   states.resize(idx.num_nodes());
@@ -780,98 +446,57 @@ OpStatePtr CachedOp::DynamicForward(
   AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
                  mem_plan, arrays, &array_reqs);
 
-  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
-  const auto& shapes = g.GetAttr<ShapeVector>("shape");
-  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
-
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    auto eid = idx.entry_id(idx.outputs()[i]);
-    arrays[eid] = outputs[i];
-    if (!outputs[i]->is_none()) continue;
-    *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
-                          shapes[eid], default_ctx, true, dtypes[eid]);
-  }
-
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
   if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
+  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
 
-  RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
-           std::move(ref_count), &states, dispatch_modes);
+  Imperative::Get()->RunGraph(
+      false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
+      std::move(ref_count), &states, dispatch_modes);
 
+  Engine::Get()->set_bulk_size(prev_bulk_size);
   Imperative::Get()->set_is_recording(recording);
 
-  return op_state;
-}
-
-void CachedOp::Forward(
-    const std::shared_ptr<CachedOp>& op_ptr,
-    const std::vector<NDArray*>& inputs,
-    const std::vector<NDArray*>& outputs) {
-  static const auto cached_op = nnvm::Op::Get("_CachedOp");
-
-  CHECK_EQ(inputs.size(), num_inputs());
-
-  Context default_ctx = inputs[0]->ctx();
-
-  const auto& idx = fwd_graph_.indexed_graph();
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    CHECK_EQ(inputs[i]->ctx(), default_ctx)
-        << "CachedOp requires all inputs to live on the same context. But "
-        << idx[idx.input_nodes()[0]].source->attrs.name
-        << " is on " << default_ctx << " while "
-        << idx[idx.input_nodes()[i]].source->attrs.name
-        << " is on " << inputs[i]->ctx();
-  }
-
-  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
-
-  OpStatePtr op_state;
-  if (config_.static_alloc) {
-    op_state = StaticForward(default_ctx, inputs, outputs);
-  } else {
-    op_state = DynamicForward(default_ctx, inputs, outputs);
+  for (size_t i = 0; i < idx.num_node_entries(); ++i) {
+    if (arrays[i] == &buff[i]) continue;
+    buff[i].shape_ = arrays[i]->shape_;
+    buff[i].dtype_ = arrays[i]->dtype_;
+    buff[i].storage_type_ = arrays[i]->storage_type_;
   }
 
-  Engine::Get()->set_bulk_size(prev_bulk_size);
-
-  if (Imperative::Get()->is_recording() && !inlining_) {
+  if (recording && !inlining_) {
     nnvm::NodeAttrs attrs;
     attrs.op = cached_op;
     attrs.name = "_cachedop";
     attrs.parsed = op_ptr;
     Imperative::Get()->RecordOp(
-        std::move(attrs), inputs, outputs, op_state,
+        std::move(attrs), inputs, outputs, op_state_ptr,
         &save_inputs(), &save_outputs());
   }
 }
 
 
-void CachedOp::DynamicBackward(
+void Imperative::CachedOp::Backward(
     const bool retain_graph,
-    const OpStatePtr& op_state,
+    const OpStatePtr& state,
     const std::vector<NDArray*>& inputs,
     const std::vector<OpReqType>& reqs,
     const std::vector<NDArray*>& outputs) {
   using namespace nnvm;
   using namespace imperative;
+  CHECK(!Imperative::Get()->is_recording())
+      << "CachedOp does not support higher order gradients. "
+      << "If you want to do backward with create_graph=True please "
+      << "do not use hybridize.";
 
   // Initialize
-  Context default_ctx = outputs[0]->ctx();
-  auto& runtime = op_state.get_state<DynamicRuntime>();
-  {
-    auto state_ptr = GetCachedOpState(default_ctx);
-    auto& state = state_ptr.get_state<CachedOpState>();
-    std::lock_guard<std::mutex> lock(state.mutex);
-    state.info.fwd_graph = runtime.info.fwd_graph;
-    SetBackwardGraph(&state.info, reqs, inputs);
-    runtime.info.full_graph = state.info.full_graph;
-    runtime.info.bwd_input_eid = state.info.bwd_input_eid;
-  }
-  nnvm::Graph& g = runtime.info.full_graph;
+  nnvm::Graph g = GetBackwardGraph(state, reqs, inputs);
   const auto& idx = g.indexed_graph();
-  auto& buff = runtime.buff;
-  auto& states = runtime.op_states;
+
+  auto& cached_op_state = state.get_state<CachedOpState>();
+  auto& buff = cached_op_state.buff;
+  auto& states = cached_op_state.states;
 
   size_t num_forward_outputs = fwd_graph_.outputs.size();
   size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
@@ -881,7 +506,7 @@ void CachedOp::DynamicBackward(
   arrays.reserve(buff.size());
   for (size_t i = 0; i < buff.size(); ++i) arrays.push_back(&buff[i]);
   for (size_t i = 0; i < inputs.size(); ++i) {
-    arrays[runtime.info.bwd_input_eid[i]] = inputs[i];
+    arrays[bwd_input_eid_[i]] = inputs[i];
   }
   for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) {
     if (reqs[i] == kNullOp) continue;
@@ -905,14 +530,20 @@ void CachedOp::DynamicBackward(
     if (ref_count[i] == 0) array_reqs[i] = kNullOp;
   }
 
+  Context default_ctx = outputs[0]->ctx();
   const auto& mem_plan = g.GetAttr<MemoryPlanVector >("backward_mem_plan");
   AllocateMemory(g, idx, default_ctx, num_forward_entries, idx.num_node_entries(),
                  mem_plan, arrays, &array_reqs);
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
-  RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
-           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
+  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
+
+  Imperative::Get()->RunGraph(
+      retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
+      std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
+
+  Engine::Get()->set_bulk_size(prev_bulk_size);
 
   if (retain_graph) {
     buff.resize(num_forward_entries);
@@ -922,99 +553,6 @@ void CachedOp::DynamicBackward(
   }
 }
 
-void CachedOp::StaticBackward(
-    const bool retain_graph,
-    const OpStatePtr& state_ptr,
-    const std::vector<NDArray*>& inputs,
-    const std::vector<OpReqType>& reqs,
-    const std::vector<NDArray*>& outputs) {
-  using namespace nnvm;
-  using namespace imperative;
-
-  Context default_ctx = outputs[0]->ctx();
-
-  auto& state = state_ptr.get_state<CachedOpState>();
-  std::lock_guard<std::mutex> lock(state.mutex);
-
-  bool match = SetBackwardGraph(&state.info, reqs, inputs, true);
-
-  nnvm::Graph& g = state.info.full_graph;
-  const auto& idx = g.indexed_graph();
-  auto num_forward_nodes = state.info.fwd_graph.indexed_graph().num_nodes();
-
-  if (!state.bwd_alloc || !match) {
-    StaticAllocMemory(state_ptr, true, true);
-  }
-
-  if (config_.static_shape) {
-    for (auto i : config_.param_indices) {
-      const auto iter = fwd_input_to_grad_output_.find(i);
-      if (iter == fwd_input_to_grad_output_.end()) continue;
-      auto entry = grad_graph_.outputs[iter->second];
-      if (!idx.exist(entry.node.get())) continue;
-      auto eid = idx.entry_id(entry);
-      if (!state.arrays[eid]->IsSame(*outputs[iter->second]) ||
-          !(state.array_reqs[eid] == reqs[iter->second])) {
-        match = false;
-        state.array_reqs[eid] = reqs[iter->second];
-        *state.arrays[eid] = *outputs[iter->second];
-        state.dynamic_entries[eid] = false;
-      }
-    }
-    for (auto i : config_.data_indices) {
-      const auto iter = fwd_input_to_grad_output_.find(i);
-      if (iter == fwd_input_to_grad_output_.end()) continue;
-      auto entry = grad_graph_.outputs[iter->second];
-      if (!idx.exist(entry.node.get())) continue;
-      auto eid = idx.entry_id(entry);
-      state.array_reqs[eid] = reqs[iter->second];
-      state.arrays[eid] = outputs[iter->second];
-    }
-  } else {
-    for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
-      auto entry = grad_graph_.outputs[i];
-      if (!idx.exist(entry.node.get())) continue;
-      auto eid = idx.entry_id(entry);
-      state.array_reqs[eid] = reqs[i];
-      state.arrays[eid] = outputs[i];
-    }
-  }
-
-  if (!state.bwd_exec_init || !match) {
-    StaticInitExec(state_ptr, true, true);
-  }
-
-  for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
-    auto eid = state.info.bwd_input_eid[i];
-    if (state.dynamic_entries[eid]) state.arrays[eid] = inputs[i];
-  }
-
-  StaticRunOps(default_ctx, g, state_ptr, num_forward_nodes, idx.num_nodes());
-}
-
-void CachedOp::Backward(
-    const bool retain_graph,
-    const OpStatePtr& state,
-    const std::vector<NDArray*>& inputs,
-    const std::vector<OpReqType>& reqs,
-    const std::vector<NDArray*>& outputs) {
-  using namespace imperative;
-  CHECK(!Imperative::Get()->is_recording())
-      << "CachedOp does not support higher order gradients. "
-      << "If you want to do backward with create_graph=True please "
-      << "do not use hybridize.";
-
-  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
-
-  if (config_.static_alloc) {
-    StaticBackward(retain_graph, state, inputs, reqs, outputs);
-  } else {
-    DynamicBackward(retain_graph, state, inputs, reqs, outputs);
-  }
-
-  Engine::Get()->set_bulk_size(prev_bulk_size);
-}
-
 
 NNVM_REGISTER_OP(_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs) {
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
deleted file mode 100644
index 60a40c5..0000000
--- a/src/imperative/cached_op.h
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef MXNET_IMPERATIVE_CACHED_OP_H_
-#define MXNET_IMPERATIVE_CACHED_OP_H_
-
-#include <mxnet/imperative.h>
-#include <vector>
-#include <atomic>
-#include <utility>
-#include <string>
-#include <unordered_map>
-
-namespace mxnet {
-/*! \brief CachedOp Parameters */
-struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
-  uint32_t inline_limit;
-  uint32_t forward_bulk_size;
-  uint32_t backward_bulk_size;
-  bool static_alloc;
-  bool static_shape;
-  nnvm::Tuple<uint32_t> data_indices;
-  nnvm::Tuple<uint32_t> param_indices;
-  DMLC_DECLARE_PARAMETER(CachedOpConfig) {
-    DMLC_DECLARE_FIELD(static_alloc)
-    .set_default(false)
-    .describe("Statically allocate memory to improve speed. "
-              "Memory usage may increase.");
-    DMLC_DECLARE_FIELD(static_shape)
-    .set_default(false)
-    .describe("Optimize for invariant input shapes between iterations. "
-              "Must also set static_alloc to True. "
-              "Change of input shapes is still allowed but slower.");
-    DMLC_DECLARE_FIELD(inline_limit)
-    .set_default(2)
-    .describe("Maximum number of operators that can be inlined.");
-    DMLC_DECLARE_FIELD(forward_bulk_size)
-    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
-    .describe("Segment size of bulk execution during forward pass.");
-    DMLC_DECLARE_FIELD(backward_bulk_size)
-    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
-    .describe("Segment size of bulk execution during backward pass.");
-    DMLC_DECLARE_FIELD(data_indices)
-    .set_default(nnvm::Tuple<uint32_t>())
-    .describe("Position of argument variables.");
-    DMLC_DECLARE_FIELD(param_indices)
-    .set_default(nnvm::Tuple<uint32_t>())
-    .describe("Position of parameters.");
-  }
-};
-
-class CachedOp {
- public:
-  CachedOp(
-      const nnvm::Symbol& sym,
-      const std::vector<std::pair<std::string, std::string> >& flags);
-  ~CachedOp();
-  uint32_t num_inputs() {
-    return fwd_graph_.indexed_graph().input_nodes().size();
-  }
-  uint32_t num_outputs() {
-    return fwd_graph_.outputs.size();
-  }
-  uint32_t num_backward_inputs() {
-    return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
-  }
-  std::vector<bool>& save_inputs() {
-    return save_inputs_;
-  }
-  std::vector<bool>& save_outputs() {
-    return save_outputs_;
-  }
-  const std::unordered_set<uint32_t>& mutable_input_nodes() {
-    return fwd_graph_.indexed_graph().mutable_input_nodes();
-  }
-  std::vector<nnvm::NodeEntry> Gradient(
-      const nnvm::NodePtr& node,
-      const std::vector<nnvm::NodeEntry>& ograds);
-  void Forward(
-      const std::shared_ptr<CachedOp>& op_ptr,
-      const std::vector<NDArray*>& inputs,
-      const std::vector<NDArray*>& outputs);
-  void Backward(
-      const bool retain_graph,
-      const OpStatePtr& state,
-      const std::vector<NDArray*>& inputs,
-      const std::vector<OpReqType>& reqs,
-      const std::vector<NDArray*>& outputs);
-
- private:
-  struct GraphInfo;
-  struct DynamicRuntime;
-  struct CachedOpState;
-
-  OpStatePtr GetCachedOpState(const Context& ctx);
-  bool SetForwardGraph(
-      GraphInfo* info,
-      const bool recording,
-      const std::vector<NDArray*>& inputs);
-  bool SetBackwardGraph(
-      GraphInfo* info,
-      const std::vector<OpReqType>& reqs,
-      const std::vector<NDArray*>& inputs,
-      bool detect_inplace_addto = false);
-  OpStatePtr DynamicForward(
-      const Context& default_ctx,
-      const std::vector<NDArray*>& inputs,
-      const std::vector<NDArray*>& outputs);
-  void DynamicBackward(
-      const bool retain_graph,
-      const OpStatePtr& op_state,
-      const std::vector<NDArray*>& inputs,
-      const std::vector<OpReqType>& reqs,
-      const std::vector<NDArray*>& outputs);
-  void StaticAllocMemory(
-      const OpStatePtr& state_ptr,
-      bool recording,
-      bool keep_fwd);
-  void StaticInitExec(
-      const OpStatePtr& state_ptr,
-      bool recording,
-      bool keep_fwd);
-  void StaticRunOps(
-      const Context& default_ctx,
-      const nnvm::Graph& g,
-      const OpStatePtr& state_ptr,
-      size_t start_nid,
-      size_t end_nid);
-  OpStatePtr StaticForward(
-      const Context& default_ctx,
-      const std::vector<NDArray*>& inputs,
-      const std::vector<NDArray*>& outputs);
-  void StaticBackward(
-      const bool retain_graph,
-      const OpStatePtr& state_ptr,
-      const std::vector<NDArray*>& inputs,
-      const std::vector<OpReqType>& reqs,
-      const std::vector<NDArray*>& outputs);
-
-  CachedOpConfig config_;
-  nnvm::Graph fwd_graph_;
-  nnvm::Graph grad_graph_;
-  nnvm::Graph full_graph_;
-  bool inlining_;
-  std::vector<nnvm::NodeEntry> ograd_entries_;
-  std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
-  std::unordered_map<uint32_t, uint32_t> fwd_input_to_grad_output_;
-  std::vector<bool> save_inputs_, save_outputs_;
-  std::vector<OpReqType> bwd_output_reqs_;
-
-  std::mutex mutex_;
-  std::unordered_map<Context, std::vector<OpStatePtr> > cached_op_states_;
-};
-
-using CachedOpPtr = std::shared_ptr<CachedOp>;
-
-}  // namespace mxnet
-#endif  // MXNET_IMPERATIVE_CACHED_OP_H_
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index e165425..7caf305 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -19,7 +19,6 @@
 #include <unordered_set>
 #include <iostream>
 #include "./imperative_utils.h"
-#include "./cached_op.h"
 
 namespace mxnet {
 #if DMLC_CXX11_THREAD_LOCAL
@@ -267,6 +266,95 @@ void Imperative::RecordOp(
   }
 }
 
+void Imperative::RunGraph(
+    const bool retain_graph,
+    const nnvm::IndexedGraph& idx,
+    const std::vector<NDArray*> arrays,
+    size_t node_start, size_t node_end,
+    std::vector<OpReqType>&& array_reqs,
+    std::vector<uint32_t>&& ref_count,
+    std::vector<OpStatePtr> *p_states,
+    const DispatchModeVector &dispatch_modes) {
+  using namespace nnvm;
+  using namespace imperative;
+  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
+  static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
+
+  std::vector<OpStatePtr>& states = *p_states;
+  bool recording = is_recording();
+
+  std::vector<NDArray*> ndinputs, ndoutputs;
+  ShapeVector arg_shapes;
+  DTypeVector arg_dtypes;
+  std::vector<OpReqType> req;
+
+  for (size_t i = node_start; i < node_end; ++i) {
+    const nnvm::IndexedGraph::Node& node = idx[i];
+    if (node.source->op() == nullptr) continue;
+    auto num_outputs = node.source->num_outputs();
+    ndinputs.clear();
+    ndinputs.reserve(node.inputs.size());
+    for (const auto& j : node.inputs) {
+      ndinputs.emplace_back(arrays[idx.entry_id(j)]);
+      CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
+    }
+    ndoutputs.clear();
+    ndoutputs.reserve(num_outputs);
+    req.clear();
+    req.reserve(num_outputs);
+    for (size_t j = 0; j < num_outputs; ++j) {
+      size_t eid = idx.entry_id(i, j);
+      ndoutputs.emplace_back(arrays[eid]);
+      req.push_back(array_reqs[eid]);
+      CHECK(!ndoutputs.back()->is_none());
+    }
+    const Context& ctx = ndoutputs[0]->ctx();
+    const DispatchMode dispatch_mode = dispatch_modes[i];
+    if (node.source->op() == bwd_cached_op) {
+      const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
+      nnvm::Node* fwd_node = node.source->control_deps[0].get();
+      auto fwd_node_id = idx.node_id(fwd_node);
+      cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
+    } else if (createop.count(node.source->op())) {
+      arg_shapes.clear();
+      arg_dtypes.clear();
+      arg_shapes.reserve(ndinputs.size());
+      arg_dtypes.reserve(ndinputs.size());
+      for (size_t i = 0; i < ndinputs.size(); ++i) {
+        arg_shapes.emplace_back(ndinputs[i]->shape());
+        arg_dtypes.emplace_back(ndinputs[i]->dtype());
+      }
+      states[i] = createop[node.source->op()](
+          node.source->attrs, ctx, arg_shapes, arg_dtypes);
+      InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
+      if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
+    } else if (is_layer_backward.get(node.source->op(), false)) {
+      nnvm::Node* fwd_node = node.source->control_deps[0].get();
+      auto fwd_node_id = idx.node_id(fwd_node);
+      InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
+               req, dispatch_mode, states[fwd_node_id]);
+      if (recording) {
+        RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
+      }
+    } else {
+      InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
+      if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
+    }
+
+    for (const auto& j : node.inputs) {
+      size_t eid = idx.entry_id(j);
+      --ref_count[eid];
+      if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
+    }
+    for (size_t j = 0; j < ndoutputs.size(); ++j) {
+      size_t eid = idx.entry_id(i, j);
+      if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
+    }
+  }
+}
+
+
 std::vector<NDArray*> Imperative::Backward(
     const std::vector<NDArray*>& outputs,
     const std::vector<NDArray*>& ograds,
diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc
deleted file mode 100644
index 464aefc..0000000
--- a/src/imperative/imperative_utils.cc
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include "./imperative_utils.h"
-#include "./cached_op.h"
-
-namespace mxnet {
-namespace imperative {
-void RunGraph(
-    const bool retain_graph,
-    const nnvm::IndexedGraph& idx,
-    const std::vector<NDArray*> arrays,
-    size_t node_start, size_t node_end,
-    std::vector<OpReqType>&& array_reqs,
-    std::vector<uint32_t>&& ref_count,
-    std::vector<OpStatePtr> *p_states,
-    const DispatchModeVector &dispatch_modes) {
-  using namespace nnvm;
-  using namespace imperative;
-  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
-  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
-  static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
-
-  const auto imp = Imperative::Get();
-
-  std::vector<OpStatePtr>& states = *p_states;
-  bool recording = imp->is_recording();
-
-  std::vector<NDArray*> ndinputs, ndoutputs;
-  ShapeVector arg_shapes;
-  DTypeVector arg_dtypes;
-  std::vector<OpReqType> req;
-
-  for (size_t i = node_start; i < node_end; ++i) {
-    const nnvm::IndexedGraph::Node& node = idx[i];
-    if (node.source->op() == nullptr) continue;
-    auto num_outputs = node.source->num_outputs();
-    ndinputs.clear();
-    ndinputs.reserve(node.inputs.size());
-    for (const auto& j : node.inputs) {
-      ndinputs.emplace_back(arrays[idx.entry_id(j)]);
-      CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
-    }
-    ndoutputs.clear();
-    ndoutputs.reserve(num_outputs);
-    req.clear();
-    req.reserve(num_outputs);
-    for (size_t j = 0; j < num_outputs; ++j) {
-      size_t eid = idx.entry_id(i, j);
-      ndoutputs.emplace_back(arrays[eid]);
-      req.push_back(array_reqs[eid]);
-      CHECK(array_reqs[eid] == kNullOp || !ndoutputs.back()->is_none());
-    }
-    const Context& ctx = ndoutputs[0]->ctx();
-    const DispatchMode dispatch_mode = dispatch_modes[i];
-    if (node.source->op() == bwd_cached_op) {
-      const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
-      nnvm::Node* fwd_node = node.source->control_deps[0].get();
-      auto fwd_node_id = idx.node_id(fwd_node);
-      cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
-    } else if (createop.count(node.source->op())) {
-      arg_shapes.clear();
-      arg_dtypes.clear();
-      arg_shapes.reserve(ndinputs.size());
-      arg_dtypes.reserve(ndinputs.size());
-      for (size_t i = 0; i < ndinputs.size(); ++i) {
-        arg_shapes.emplace_back(ndinputs[i]->shape());
-        arg_dtypes.emplace_back(ndinputs[i]->dtype());
-      }
-      states[i] = createop[node.source->op()](
-          node.source->attrs, ctx, arg_shapes, arg_dtypes);
-      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
-      if (recording) {
-        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
-      }
-    } else if (is_layer_backward.get(node.source->op(), false)) {
-      nnvm::Node* fwd_node = node.source->control_deps[0].get();
-      auto fwd_node_id = idx.node_id(fwd_node);
-      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
-               req, dispatch_mode, states[fwd_node_id]);
-      if (recording) {
-        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
-      }
-    } else {
-      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
-      if (recording) {
-        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
-      }
-    }
-
-    for (const auto& j : node.inputs) {
-      size_t eid = idx.entry_id(j);
-      --ref_count[eid];
-      if (ref_count[eid] == 0) *arrays[eid] = NDArray();
-    }
-    for (size_t j = 0; j < ndoutputs.size(); ++j) {
-      size_t eid = idx.entry_id(i, j);
-      if (ref_count[eid] == 0) *arrays[eid] = NDArray();
-    }
-  }
-}
-
-}  // namespace imperative
-}  // namespace mxnet
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 726531d..06b7e05 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -23,7 +23,6 @@
 #include <utility>
 #include <algorithm>
 #include <vector>
-#include <map>
 #include <string>
 #include "../executor/graph_executor.h"
 #include "../executor/exec_pass.h"
@@ -39,24 +38,11 @@ namespace mxnet {
 namespace imperative {
 
 struct MemoryPlanInfo {
-  int storage_id;
-  uint32_t root;
+  uint32_t sid;
   size_t size;
   bool inplace;
 };
 
-struct EngineOprDeleter {
-  void operator()(engine::Opr* handle) {
-    Engine::Get()->DeleteOperator(handle);
-  }
-};
-
-struct EngineOprSeg {
-  bool skip;
-  size_t next_nid;
-  std::unique_ptr<engine::Opr, EngineOprDeleter> opr;
-};
-
 using MemoryPlanVector = std::vector<MemoryPlanInfo>;
 
 inline Context GetContext(const nnvm::NodeAttrs& attrs,
@@ -729,12 +715,10 @@ inline std::vector<Context> PlaceDevice(const nnvm::IndexedGraph& idx) {
 
 
 inline MemoryPlanVector PlanMemory(
-    nnvm::Graph* p_g,
-    nnvm::StorageVector&& storage,
+    nnvm::Graph* p_g, nnvm::StorageVector&& storage,
     const std::vector<uint32_t>& ref_count,
     const std::pair<uint32_t, uint32_t>& node_range = {0, 0},
-    const std::pair<uint32_t, uint32_t>& entry_range = {0, 0},
-    bool detect_inplace_addto = false) {
+    const std::pair<uint32_t, uint32_t>& entry_range = {0, 0}) {
   using namespace nnvm;
   nnvm::Graph& g = *p_g;
   const auto& idx = g.indexed_graph();
@@ -744,31 +728,31 @@ inline MemoryPlanVector PlanMemory(
   g.attrs["ref_count"] = std::make_shared<dmlc::any>(ref_count);
   g.attrs["storage"] = std::make_shared<dmlc::any>(std::move(storage));
   g = nnvm::ApplyPass(g, "PlanMemory");
-  if (detect_inplace_addto) g = exec::DetectInplaceAddTo(g);
 
   const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
   const auto& shapes = g.GetAttr<ShapeVector>("shape");
-  const auto& storage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
-  const auto& storage_ids = g.GetAttr<StorageVector>("storage_id");
+  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  auto storage_ids = g.MoveCopyAttr<StorageVector>("storage_id");
+  auto storage_inplace = g.MoveCopyAttr<std::vector<int> >("storage_inplace_index");
   uint32_t entry_start = entry_range.first;
   uint32_t entry_end =
       entry_range.second > entry_start ? entry_range.second : idx.num_node_entries();
   MemoryPlanVector mem_plan(idx.num_node_entries());
-  std::unordered_map<int, uint32_t> sid_to_root;
+  std::unordered_map<int, uint32_t> sid_to_loc;
 
   for (uint32_t i = entry_start; i < entry_end; ++i) {
+    if (stypes[i] != kDefaultStorage) continue;
     if (storage_ids[i] < 0) {
-      mem_plan[i] = {storage_ids[i], i, 0, false};
-    } else if (!sid_to_root.count(storage_ids[i])) {
+      mem_plan[i] = {i, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), false};
+    } else if (!sid_to_loc.count(storage_ids[i])) {
       CHECK_LT(storage_inplace[i], 0);
-      sid_to_root[storage_ids[i]] = i;
-      mem_plan[i] = {storage_ids[i], i,
-                     mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(),
-                     false};
+      sid_to_loc[storage_ids[i]] = i;
+      mem_plan[i].sid = i;
+      mem_plan[i].size = mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size();
     } else {
-      uint32_t root = sid_to_root[storage_ids[i]];
-      mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0};
-      mem_plan[root].size = std::max(mem_plan[root].size,
+      uint32_t loc = sid_to_loc[storage_ids[i]];
+      mem_plan[i] = {loc, 0, storage_inplace[i] >= 0};
+      mem_plan[loc].size = std::max(mem_plan[loc].size,
           mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size());
     }
   }
@@ -777,213 +761,39 @@ inline MemoryPlanVector PlanMemory(
 }
 
 
-inline std::multimap<size_t, NDArray> AllocateMemory(
-    const nnvm::Graph& g,
-    const nnvm::IndexedGraph& idx,
-    const Context& default_ctx,
-    const uint32_t entry_start, const uint32_t entry_end,
-    const MemoryPlanVector& mem_plan,
-    const std::vector<NDArray*>& arrays,
-    std::vector<OpReqType> *array_reqs,
-    std::multimap<size_t, NDArray>&& pool = std::multimap<size_t, NDArray>()) {
+inline void AllocateMemory(const nnvm::Graph& g,
+                    const nnvm::IndexedGraph& idx,
+                    const Context& default_ctx,
+                    const uint32_t entry_start, const uint32_t entry_end,
+                    const MemoryPlanVector& mem_plan,
+                    const std::vector<NDArray*>& arrays,
+                    std::vector<OpReqType> *array_reqs) {
   using namespace nnvm;
   const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
   const auto& shapes = g.GetAttr<ShapeVector>("shape");
   const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
 
-  std::multimap<size_t, NDArray> new_pool;
-
   for (uint32_t i = entry_start; i < entry_end; ++i) {
-    if (mem_plan[i].storage_id == exec::kExternalStorageID) continue;
-    CHECK(arrays[i]->is_none());
-    if (mem_plan[i].storage_id == exec::kDynamicStorageID) {
-      *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
-                           shapes[i], default_ctx, true, dtypes[i]);
-      continue;
-    }
-    CHECK_EQ(stypes[i], kDefaultStorage);
-    if (mem_plan[i].root == i) {
-      CHECK_GT(mem_plan[i].size, 0);
-      auto iter = pool.lower_bound(mem_plan[i].size);
-      if (iter != pool.end()) {
-        *arrays[i] = iter->second.AsArray(shapes[i], dtypes[i]);
-        new_pool.insert(*iter);
-        pool.erase(iter);
-      } else {
+    if (!arrays[i]->is_none()) continue;
+    if (stypes[i] == kDefaultStorage) {
+      if (mem_plan[i].sid == i) {
+        CHECK_GT(mem_plan[i].size, 0);
         NDArray buff(TShape({static_cast<nnvm::dim_t>(mem_plan[i].size)}),
                      default_ctx, true, mshadow::kUint8);
         *arrays[i] = buff.AsArray(shapes[i], dtypes[i]);
-        new_pool.insert({mem_plan[i].size, buff});
-      }
-    } else {
-      CHECK_GE(mem_plan[mem_plan[i].root].storage_id, 0);
-      *arrays[i] = arrays[mem_plan[i].root]->AsArray(shapes[i], dtypes[i]);
-      if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
-        array_reqs->at(i) = kWriteInplace;
-      }
-    }
-  }
-
-  return new_pool;
-}
-
-inline void SetupOpExec(
-    const nnvm::Graph& g,
-    size_t nid,
-    const std::shared_ptr<exec::OpExecutor>& exec,
-    const std::vector<NDArray*> arrays,
-    const std::vector<OpReqType> array_reqs) {
-  const auto& idx = g.indexed_graph();
-  const auto& inode = idx[nid];
-  CHECK_EQ(exec->in_array.size(), 0U);
-  CHECK_EQ(exec->out_array.size(), 0U);
-  for (const auto& e : inode.inputs) {
-    CHECK(!arrays[idx.entry_id(e)]->is_none()) << inode.source->attrs.name;
-    exec->in_array.push_back(*arrays[idx.entry_id(e)]);
-  }
-  for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
-    uint32_t eid = idx.entry_id(nid, index);
-    CHECK(!arrays[eid]->is_none()) << inode.source->attrs.name;
-    exec->out_array.push_back(*arrays[eid]);
-    exec->req.push_back(array_reqs[eid]);
-  }
-
-  exec->Setup();
-}
-
-inline Engine::OprHandle CreateEngineOp(
-    const Context& default_ctx,
-    const std::vector<std::shared_ptr<exec::OpExecutor> >& execs) {
-  CHECK_GT(execs.size(), 0);
-  std::vector<Engine::VarHandle> use_vars, mutate_vars;
-
-  for (const auto& exec : execs) {
-    CHECK_GT(exec->out_array.size(), 0);
-    CHECK(execs.size() == 1 || exec->exec_type() == ExecType::kSync);
-
-    // the variables
-    for (const auto& nd : exec->in_array) {
-      use_vars.push_back(nd.var());
-    }
-    for (auto& r : exec->op_ctx.requested) {
-      mutate_vars.push_back(r.var);
-    }
-    for (auto& nd : exec->out_array) {
-      mutate_vars.push_back(nd.var());
-    }
-    if (exec->var() != nullptr) {
-      mutate_vars.push_back(exec->var());
-    }
-  }
-
-  // dedup vars
-  Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);
-  bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask;
-  bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync;
-
-  auto exec_fun = [execs, is_async, is_gpu] (
-      RunContext ctx, Engine::CallbackOnComplete on_complete) {
-    if (is_async) {
-      execs[0]->op_ctx.async_on_complete = on_complete;
-    }
-    for (const auto& exec : execs) exec->Run(ctx, is_gpu);
-    // call on complete only if it is async op
-    if (!is_async) {
-      if (is_gpu) {
-      #if MXNET_USE_CUDA
-        // Wait GPU kernel to finish.
-        ctx.get_stream<gpu>()->Wait();
-      #else
-        LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
-      #endif
-      }
-      on_complete();
-    }
-  };
-
-  return Engine::Get()->NewOperator(
-      exec_fun, use_vars, mutate_vars, FnProperty::kNormal);
-}
-
-inline void CreateEngineOpSeg(
-    const nnvm::IndexedGraph& idx,
-    const Context default_ctx,
-    const size_t start_nid,
-    const size_t end_nid,
-    const size_t bulk_size,
-    const std::unordered_set<uint32_t>& excludes,
-    const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
-    const std::vector<int> skip_plus_node,
-    std::vector<EngineOprSeg> *opr_segs) {
-  size_t seg_start = start_nid;
-  std::vector<std::shared_ptr<exec::OpExecutor> > seg_execs;
-  for (size_t nid = start_nid; nid < end_nid; ++nid) {
-    const auto& node = idx[nid];
-    if (node.source->is_variable()) continue;
-    if (skip_plus_node.size() && skip_plus_node[nid]) continue;
-    auto& exec = execs[nid];
-    bool is_async = exec->exec_type() != ExecType::kSync;
-    bool valid = exec->out_array.size() > 0;
-
-    // Stop at async nodes and invalid node (due to input/output is not allocated)
-    bool stop = is_async || !valid || seg_execs.size() >= bulk_size;
-    for (size_t i = 0; i < node.inputs.size() && !stop; ++i) {
-      if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true;
-    }
-    auto num_outputs = node.source->num_outputs();
-    for (size_t i = 0; i < num_outputs && !stop; ++i) {
-      if (excludes.count(idx.entry_id(nid, i))) stop = true;
-    }
-
-    // Create opr segment for previous nodes.
-    if (stop && nid > seg_start) {
-      auto& seg = (*opr_segs)[seg_start];
-      if (seg_execs.size()) {
-        seg = EngineOprSeg{false, nid};
-        seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
       } else {
-        seg = EngineOprSeg{true, nid, nullptr};
+        *arrays[i] = arrays[mem_plan[i].sid]->AsArray(shapes[i], dtypes[i]);
+        if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
+          array_reqs->at(i) = kWriteInplace;
+        }
       }
-      seg_start = nid;
-      seg_execs.clear();
-    }
-
-    seg_execs.push_back(exec);
-
-    auto& seg = (*opr_segs)[nid];
-    if (is_async) {
-      seg = EngineOprSeg{false, nid + 1};
-      seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
-      seg_execs.clear();
-      seg_start = nid + 1;
-    } else if (!valid) {
-      seg = EngineOprSeg{false, nid + 1, nullptr};
-      seg_execs.clear();
-      seg_start = nid + 1;
-    }
-  }
-  // The last segment
-  if (end_nid > seg_start) {
-    auto& seg = (*opr_segs)[seg_start];
-    if (seg_execs.size()) {
-      seg = EngineOprSeg{false, end_nid};
-      seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
     } else {
-      seg = EngineOprSeg{true, end_nid, nullptr};
+      *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
+                           shapes[i], default_ctx, true, dtypes[i]);
     }
   }
 }
 
-
-void RunGraph(const bool retain_graph,
-              const nnvm::IndexedGraph& idx,
-              const std::vector<NDArray*> arrays,
-              size_t node_start, size_t node_end,
-              std::vector<OpReqType>&& array_reqs,
-              std::vector<uint32_t>&& ref_count,
-              std::vector<OpStatePtr> *p_states,
-              const DispatchModeVector &dispatch_modes);
-
 }  // namespace imperative
 }  // namespace mxnet
 
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 5701a5d..451fde2 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -22,7 +22,6 @@ from mxnet.test_utils import assert_almost_equal
 from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
 from common import setup_module, with_seed, assertRaises, teardown
 import numpy as np
-from numpy.testing import assert_array_equal
 from nose.tools import raises, assert_raises
 from copy import deepcopy
 import warnings
@@ -1125,6 +1124,7 @@ def test_hybrid_multi_context():
     net.hybridize()
     net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy()
 
+
 @with_seed()
 def test_zero_grad():
     data = mx.nd.random.uniform(shape=(3,3))
@@ -1137,60 +1137,6 @@ def test_zero_grad():
     grad = net.collect_params()['test_zero_grad_weight'].grad()
     assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)
 
-def check_hybrid_static_memory(**kwargs):
-    x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
-    x.attach_grad()
-
-    net1 = gluon.model_zoo.vision.get_resnet(
-        1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
-    net2 = gluon.model_zoo.vision.get_resnet(
-        1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
-    net2.hybridize(**kwargs)
-    net1(x)
-    net2(x)
-
-    def test(net, x):
-        with mx.autograd.record():
-            y = net(x) + net(x)
-            y.backward()
-
-        grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'}
-
-        return y, grads
-
-    y1, grads1 = test(net1, x)
-    y2, grads2 = test(net2, x)
-
-    assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
-    for key in grads1:
-        assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)
-
-def test_hybrid_static_memory():
-    check_hybrid_static_memory()
-    check_hybrid_static_memory(static_alloc=True)
-    check_hybrid_static_memory(static_alloc=True, static_shape=True)
-
-def check_hybrid_static_memory_switching(**kwargs):
-    net = gluon.model_zoo.vision.get_resnet(
-        1, 18, pretrained=True, ctx=mx.context.current_context())
-    net.hybridize(**kwargs)
-
-    x = mx.nd.random.uniform(shape=(4, 3, 32, 32))
-    net(x)
-    with mx.autograd.record():
-        y = net(x)
-        y.backward()
-    x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
-    net(x)
-    with mx.autograd.record():
-        y = net(x)
-        y.backward()
-    mx.nd.waitall()
-
-def test_hybrid_static_memory_switching():
-    check_hybrid_static_memory_switching()
-    check_hybrid_static_memory_switching(static_alloc=True)
-    check_hybrid_static_memory_switching(static_alloc=True, static_shape=True)
 
 @with_seed()
 def test_hook():
@@ -1285,17 +1231,6 @@ def test_legacy_save_params():
     model.load_params('test.params', ctx=mx.cpu())
 
 
-def test_hybrid_static_memory_recording():
-    net = gluon.model_zoo.vision.get_resnet(
-        1, 18, pretrained=True, ctx=mx.context.current_context())
-    net.hybridize(static_alloc=True)
-
-    x = mx.nd.random.uniform(shape=(1, 3, 32, 32))
-    with mx.autograd.record(True):
-        net(x)
-    net(x)
-
-
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
marcoabreu@apache.org.