You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/16 01:22:18 UTC

[GitHub] tqchen closed pull request #11313: Static alloc for hybridblock

tqchen closed pull request #11313: Static alloc for hybridblock
URL: https://github.com/apache/incubator-mxnet/pull/11313
 
 
   

This is a PR merged from a forked repository (a "foreign" pull request).
GitHub hides the original diff of such a pull request once it is merged,
so the diff would not otherwise be visible; it is reproduced in full
below for the sake of provenance:

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 55c26bc980b..4dd858a51c4 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -987,11 +987,6 @@ MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
                                  int num_flags,
                                  const char** keys,
                                  const char** vals,
-                                 int num_inputs,
-                                 const char** input_names,
-                                 int num_params,
-                                 const char** param_names,
-                                 NDArrayHandle* params,
                                  CachedOpHandle *out);
 /*!
  * \brief free cached operator
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 758ce851321..7ea60df3302 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -35,23 +35,6 @@
 #include "./ndarray.h"
 
 namespace mxnet {
-/*! \brief CachedOp Parameters */
-struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
-  uint32_t inline_limit;
-  uint32_t forward_bulk_size;
-  uint32_t backward_bulk_size;
-  DMLC_DECLARE_PARAMETER(CachedOpConfig) {
-    DMLC_DECLARE_FIELD(inline_limit)
-    .set_default(2)
-    .describe("Maximum number of operators that can be inlined.");
-    DMLC_DECLARE_FIELD(forward_bulk_size)
-    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
-    .describe("Segment size of bulk execution during forward pass.");
-    DMLC_DECLARE_FIELD(backward_bulk_size)
-    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
-    .describe("Segment size of bulk execution during backward pass.");
-  }
-};
 /*! \brief runtime functions for NDArray */
 class Imperative {
  public:
@@ -94,67 +77,6 @@ class Imperative {
              && info.out_grads.size() == 1;
     }
   };
-  class CachedOp {
-   public:
-    CachedOp(
-        const nnvm::Symbol& sym,
-        const std::vector<std::pair<std::string, std::string> >& flags,
-        const std::vector<std::string> arg_names,
-        const std::unordered_map<std::string, std::vector<NDArray> >& params);
-    uint32_t num_inputs() {
-      return fwd_graph_.indexed_graph().input_nodes().size();
-    }
-    uint32_t num_outputs() {
-      return fwd_graph_.outputs.size();
-    }
-    uint32_t num_backward_inputs() {
-      return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
-    }
-    std::vector<bool>& save_inputs() {
-      return save_inputs_;
-    }
-    std::vector<bool>& save_outputs() {
-      return save_outputs_;
-    }
-    const std::unordered_set<uint32_t>& mutable_input_nodes() {
-      return fwd_graph_.indexed_graph().mutable_input_nodes();
-    }
-    nnvm::Graph GetForwardGraph(const bool recording,
-                                const std::vector<NDArray*>& inputs);
-    nnvm::Graph GetBackwardGraph(const OpStatePtr& state,
-                                 const std::vector<OpReqType>& reqs,
-                                 const std::vector<NDArray*>& inputs);
-    std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
-                                          const std::vector<nnvm::NodeEntry>& ograds);
-    void Forward(const std::shared_ptr<CachedOp>& op_ptr,
-                 const std::vector<NDArray*>& args,
-                 const std::vector<NDArray*>& outputs);
-    void Backward(const bool retain_graph,
-                  const OpStatePtr& state,
-                  const std::vector<NDArray*>& inputs,
-                  const std::vector<OpReqType>& reqs,
-                  const std::vector<NDArray*>& outputs);
-
-   private:
-    struct CachedOpState {
-      std::vector<NDArray> buff;
-      std::vector<OpStatePtr> states;
-    };
-    std::mutex mutex_;
-    CachedOpConfig config_;
-    nnvm::Graph fwd_graph_;
-    nnvm::Graph grad_graph_;
-    nnvm::Graph full_graph_;
-    std::unordered_map<Context, std::vector<NDArray> > params_;
-    bool inlining_;
-    std::vector<nnvm::NodeEntry> ograd_entries_;
-    std::vector<bool> curr_grad_req_;
-    std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
-    std::vector<uint32_t> fwd_args_idx_;
-    std::vector<uint32_t> fwd_params_idx_;
-    std::vector<uint32_t> bwd_input_eid_;
-    std::vector<bool> save_inputs_, save_outputs_;
-  };
   /*! \brief whether operator recording is on. */
   bool is_training() const {
     return is_train_;
@@ -222,15 +144,6 @@ class Imperative {
       uint32_t num_inputs, uint32_t num_outputs,
       std::vector<bool> *p_save_inputs,
       std::vector<bool> *p_save_outputs);
-  void RunGraph(
-      const bool retain_graph,
-      const nnvm::IndexedGraph& idx,
-      const std::vector<NDArray*> arrays,
-      size_t node_start, size_t node_end,
-      std::vector<OpReqType>&& array_reqs,
-      std::vector<uint32_t>&& ref_count,
-      std::vector<OpStatePtr> *p_states,
-      const DispatchModeVector& dispatch_modes);
   /*! \brief indicate whether is training. */
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local bool is_train_;
@@ -247,7 +160,5 @@ class Imperative {
   int backward_bulk_size_{0};
 };
 
-using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
-
 }  // namespace mxnet
 #endif  // MXNET_IMPERATIVE_H_
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index e243eb71c47..ae96fd87b0d 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -155,6 +155,14 @@ class NDArray {
     return byte_offset_ > 0 || shape() != ptr_->storage_shape;
   }
 
+  /* \brief Check whether the two arrays are the same array */
+  inline bool IsSame(const NDArray& other) {
+    return ptr_ == other.ptr_ &&
+        shape_ == other.shape_ &&
+        byte_offset_ == other.byte_offset_ &&
+        dtype_ == other.dtype_;
+  }
+
   /*!
    * \return the shape of current NDArray.
    */
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 3969d8445be..f4694efad29 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -126,25 +126,36 @@ class OpStatePtr {
   template<typename T, typename... Args>
   static OpStatePtr Create(Args&&... args) {
     OpStatePtr ret;
-    ret.ptr_ = std::make_shared<OpState>();
-    ret.ptr_->var_ = Engine::Get()->NewVariable();
-    ret.ptr_->state_.construct<T>(std::forward<Args>(args)...);
+    auto state = new T(std::forward<Args>(args)...);
+    auto var = Engine::Get()->NewVariable();
+    ret.ptr_.reset(
+      new OpState(var, state),
+      [](OpState* p) {
+        Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
+        delete reinterpret_cast<T*>(p->state);
+        delete p;
+      });
 
     return ret;
   }
   /* \brief Get engine variable associated with this state */
   engine::VarHandle get_var() const {
-    return ptr_->var_;
+    return ptr_->var;
   }
   /* \brief Get state of type T */
   template<typename T>
   T& get_state() const {
-    return dmlc::get<T>(ptr_->state_);
+    return *reinterpret_cast<T*>(ptr_->state);
   }
   /* \brief clear state */
   void reset() {
     ptr_.reset();
   }
+  /* \brief checks whether the managed object is managed only by the current
+            OpStatePtr instance */
+  bool unique() const {
+    return ptr_.unique();
+  }
   /* \brief Whether state is empty */
   explicit operator bool() const {
     return ptr_ ? true : false;
@@ -153,16 +164,12 @@ class OpStatePtr {
  private:
   /* \brief state structure */
   struct OpState {
-    OpState() {}
+    engine::VarHandle var;
+    void* state;
+
+    OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
     OpState(const OpState& other) = delete;
     OpState& operator=(const OpState& other) = delete;
-
-    ~OpState() {
-      Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_);
-    }
-
-    engine::VarHandle var_;
-    dmlc::any state_;
   };
   /* \brief shared pointer to state */
   std::shared_ptr<OpState> ptr_;
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index d2cae0c45aa..f324545a235 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -105,28 +105,14 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
 class CachedOp(object):
     """Cached operator handle."""
     __slots__ = ["handle"]
-    def __init__(self, sym, flags=(), inputs=None, params=None):
+    def __init__(self, sym, flags=()):
         self.handle = CachedOpHandle()
-        param_names = []
-        param_arrays = []
-        if inputs is None:
-            assert params is None, "When inputs is None params must also be None."
-            inputs = sym.list_inputs()
-        elif params is not None:
-            for name, arrs in params.items():
-                param_arrays.extend(arrs)
-                param_names.extend([name] * len(arrs))
 
         check_call(_LIB.MXCreateCachedOpEx(
             sym.handle,
             len(flags),
             c_str_array([key for key, _ in flags]),
             c_str_array([str(val) for _, val in flags]),
-            len(inputs),
-            c_str_array(inputs),
-            len(param_names),
-            c_str_array(param_names),
-            c_handle_array(param_arrays),
             ctypes.byref(self.handle)))
 
     def __del__(self):
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 3b97c0578ca..293fafab487 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -502,8 +502,16 @@ def hybridize(self, active=True, **kwargs):
         ----------
         active : bool, default True
             Whether to turn hybrid on or off.
-        **kwargs : string
-            Additional flags for hybridized operator.
+        static_alloc : bool, default False
+            Statically allocate memory to improve speed. Memory usage may increase.
+        static_shape : bool, default False
+            Optimize for invariant input shapes between iterations. Must also
+            set static_alloc to True. Change of input shapes is still allowed
+            but slower.
+        forward_bulk_size : int, default 15
+            Segment size of bulk execution during forward pass.
+        backward_bulk_size : int, default 15
+            Segment size of bulk execution during backward pass.
         """
         for cld in self._children.values():
             cld.hybridize(active, **kwargs)
@@ -696,7 +704,7 @@ def __init__(self, prefix=None, params=None):
         self._out_format = None
         self._in_format = None
         self._active = False
-        self._flags = {}
+        self._flags = []
 
     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -723,39 +731,43 @@ def _get_graph(self, *args):
         return self._cached_graph
 
     def _build_cache(self, *args):
-        inputs, out = self._get_graph(*args)
-        input_names = [i.name for i in inputs]
-
+        data, out = self._get_graph(*args)
+        data_names = {data.name : i for i, data in enumerate(data)}
         params = self.collect_params()
+        input_names = out.list_inputs()
+
         param_names = set(params.keys())
-        expected_names = set(out.list_inputs())
+        expected_names = set(input_names)
         for name in expected_names:
-            assert name in param_names or name in input_names, \
+            assert name in param_names or name in data_names, \
                 "Unknown input to HybridBlock: %s"%name
 
-        used_input_names = [i for i in input_names if i in expected_names]
-        if len(used_input_names) != len(input_names):
-            unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names)
+        used_data_names = [i for i in data_names if i in expected_names]
+        if len(used_data_names) != len(data_names):
+            unused = ', '.join(['%d-th'%i for name, i in data_names.items()
                                 if name not in expected_names])
             warnings.warn("The %s input to HybridBlock is not used by any "
                           "computation. Is this intended?"%unused, stacklevel=4)
 
-        used_param_names = set(i for i in param_names if i in expected_names)
+        used_param_names = [i for i in param_names if i in expected_names]
         if len(used_param_names) != len(param_names):
-            unused = ', '.join(list(param_names - used_param_names))
+            unused = ', '.join(list(param_names - set(used_param_names)))
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
-        used_params = {k: params[k] for k in used_param_names}
-        try:
-            param_dict = {k: v.list_data() for k, v in used_params.items()}
-        except DeferredInitializationError:
-            self._deferred_infer_shape(*args)
-            for i in used_params.values():
-                i._finish_deferred_init()
-            param_dict = {k: v.list_data() for k, v in used_params.items()}
-
-        self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict)
+        data_indices = []
+        param_indices = []
+        self._cached_op_args = []
+        for i, name in enumerate(input_names):
+            if name in data_names:
+                data_indices.append(i)
+                self._cached_op_args.append((True, data_names[name]))
+            else:
+                param_indices.append(i)
+                self._cached_op_args.append((False, params[name]))
+        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
+                self._flags
+        self._cached_op = ndarray.CachedOp(out, flags)
 
     def _deferred_infer_shape(self, *args):
         try:
@@ -771,7 +783,19 @@ def _call_cached_op(self, *args):
 
         args, fmt = _flatten(args, "input")
         assert fmt == self._in_format, "Invalid input format"
-        out = self._cached_op(*args)
+        try:
+            cargs = [args[i] if is_arg else i.data()
+                     for is_arg, i in self._cached_op_args]
+        except DeferredInitializationError:
+            self._deferred_infer_shape(*args)
+            cargs = []
+            for is_arg, i in self._cached_op_args:
+                if is_arg:
+                    cargs.append(args[i])
+                else:
+                    i._finish_deferred_init()
+                    cargs.append(i.data())
+        out = self._cached_op(*cargs)
         if isinstance(out, NDArray):
             out = [out]
         return _regroup(out, self._out_format)[0]
@@ -792,7 +816,7 @@ def register_child(self, block, name=None):
 
     def hybridize(self, active=True, **kwargs):
         self._active = active
-        self._flags = kwargs.items()
+        self._flags = list(kwargs.items())
         self._clear_cached_op()
         if active and self._forward_hooks or self._forward_pre_hooks:
             warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. '
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 9aabe04656e..34bd4b20aa5 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -36,6 +36,7 @@
 #include "../common/utils.h"
 #include "../common/exec_utils.h"
 #include "../imperative/imperative_utils.h"
+#include "../imperative/cached_op.h"
 
 using namespace mxnet;
 
@@ -160,12 +161,8 @@ int MXCreateCachedOp(SymbolHandle handle,
   std::vector<std::string> input_names;
   input_names.reserve(inputs.size());
   for (const auto& i : inputs) input_names.push_back(i->attrs.name);
-  *out = new std::shared_ptr<Imperative::CachedOp>(
-      new Imperative::CachedOp(
-        *sym,
-        std::vector<std::pair<std::string, std::string> >(),
-        input_names,
-        std::unordered_map<std::string, std::vector<NDArray> >()));
+  *out = new CachedOpPtr(new CachedOp(
+      *sym, std::vector<std::pair<std::string, std::string> >()));
   API_END();
 }
 
@@ -173,11 +170,6 @@ int MXCreateCachedOpEx(SymbolHandle handle,
                        int num_flags,
                        const char** keys,
                        const char** vals,
-                       int num_args,
-                       const char** arg_names,
-                       int num_params,
-                       const char** param_names,
-                       NDArrayHandle* params,
                        CachedOpHandle *out) {
   nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
 
@@ -186,17 +178,7 @@ int MXCreateCachedOpEx(SymbolHandle handle,
   for (int i = 0; i < num_flags; ++i) {
     flags.push_back({keys[i], vals[i]});
   }
-  std::vector<std::string> args;
-  for (int i = 0; i < num_args; ++i) {
-    args.push_back(arg_names[i]);
-  }
-  std::unordered_map<std::string, std::vector<NDArray> > param_dict;
-  for (int i = 0; i < num_params; ++i) {
-    param_dict[param_names[i]].emplace_back(
-        *reinterpret_cast<NDArray*>(params[i]));
-  }
-  *out = new std::shared_ptr<Imperative::CachedOp>(
-      new Imperative::CachedOp(*sym, flags, args, param_dict));
+  *out = new CachedOpPtr(new CachedOp(*sym, flags));
   API_END();
 }
 
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index dc0436e02a8..e70cc197c0c 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -278,6 +278,8 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
 }
 
 void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) {
+  BulkFlush();
+
   ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
   OprBlock* opr_block = OprBlock::New();
   opr_block->opr = threaded_opr;
@@ -323,7 +325,6 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
         << device_count_;
   }
 #endif
-  BulkFlush();
   ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
   opr->temporary = true;
   const bool profiling = profiler_->IsProfiling(profiler::Profiler::kImperative);
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 697e4869a04..72919d90c62 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -134,6 +134,10 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
     return state_.get_var();
   }
 
+  OpStatePtr state() const override {
+    return state_;
+  }
+
   explicit StatefulComputeExecutor(const OpStatePtr& state,
                                    const FStatefulCompute& fcompute,
                                    ExecType exec_type,
@@ -142,7 +146,6 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
         state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
 
  private:
-  friend Graph AttachOpExecs(Graph g);
   OpStatePtr state_;
   FStatefulCompute fcompute_;
   ExecType exec_type_;
@@ -170,13 +173,16 @@ class StatefulComputeExExecutor : public OpExecutor {
     return state_.get_var();
   }
 
+  OpStatePtr state() const override {
+    return state_;
+  }
+
   explicit StatefulComputeExExecutor(const OpStatePtr& state,
                                      const FStatefulComputeEx& fcompute,
                                      ExecType exec_type)
       : state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
 
  private:
-  friend Graph AttachOpExecs(Graph g);
   OpStatePtr state_;
   FStatefulComputeEx fcompute_;
   ExecType exec_type_;
@@ -241,16 +247,15 @@ class FComputeExExecutor : public OpExecutor {
   ExecType exec_type_;
 };
 
-// pass to attach operator executors
-Graph AttachOpExecs(Graph g) {
+void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
   using nnvm::DTypeVector;
   using nnvm::ShapeVector;
   using nnvm::FMutateInputs;
 
-  auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
-  auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
-  auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
-  auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
+  static auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+  static auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
+  static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
+  static auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
 
   const auto& vdtype = g.GetAttr<DTypeVector>("dtype");
   const auto& vshape = g.GetAttr<ShapeVector>("shape");
@@ -259,81 +264,87 @@ Graph AttachOpExecs(Graph g) {
 
   // get the graph
   const auto& idx = g.indexed_graph();
-  std::vector<std::shared_ptr<OpExecutor> > ret(idx.num_nodes());
+  OpExecVector& ret = *p_ret;
 
   // initialize the nodes
-  for (size_t i = 0; i < idx.num_nodes(); ++i) {
-    const auto& inode = idx[i];
-    if (inode.source->is_variable()) continue;
-    const nnvm::Op *op = inode.source->op();
-    ExecType exec_type = ExecType::kSync;
-    std::vector<uint32_t> mutate_index;
-    if (fmutate_inputs.count(op)) {
-      mutate_index = fmutate_inputs[op](inode.source->attrs);
-    }
-    if (fexec_type.count(op)) {
-      exec_type = fexec_type[op](inode.source->attrs);
+  const auto& inode = idx[i];
+  if (inode.source->is_variable()) return;
+  const nnvm::Op *op = inode.source->op();
+  ExecType exec_type = ExecType::kSync;
+  std::vector<uint32_t> mutate_index;
+  if (fmutate_inputs.count(op)) {
+    mutate_index = fmutate_inputs[op](inode.source->attrs);
+  }
+  if (fexec_type.count(op)) {
+    exec_type = fexec_type[op](inode.source->attrs);
+  }
+  CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
+  if (fcreate_op_state.count(op)) {
+    std::vector<TShape> ishape;
+    std::vector<int> itype;
+    for (const auto& e : inode.inputs) {
+      ishape.emplace_back(vshape[idx.entry_id(e)]);
+      itype.emplace_back(vdtype[idx.entry_id(e)]);
     }
-    CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
-    if (fcreate_op_state.count(op)) {
-      std::vector<TShape> ishape;
-      std::vector<int> itype;
-      for (const auto& e : inode.inputs) {
-        ishape.emplace_back(vshape[idx.entry_id(e)]);
-        itype.emplace_back(vdtype[idx.entry_id(e)]);
-      }
 
-      OpStatePtr state = fcreate_op_state[op](
-          inode.source->attrs, vctx[i], ishape, itype);
-      FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
-          op, "FStatefulComputeEx", vctx[i]);
-      // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
-      if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
-        ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
-      } else {
-        FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
-            op, "FStatefulCompute", vctx[i]);
-        CHECK(fcompute != nullptr)
-            << "One of FStatefulCompute and FStatefulComputeEx must be registered "
-            << "for stateful operator " << op->name;
-        ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
-                                                           exec_type, mutate_index);
-      }
-    } else if (is_layer_backward.get(op, false)) {
-      CHECK_GE(inode.control_deps.size(), 1);
-      uint32_t fwd_id = inode.control_deps[0];
-      CHECK(vctx[fwd_id] == vctx[i]);
-      CHECK(ret[fwd_id] != nullptr);
-      FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
-          op, "FStatefulComputeEx", vctx[i]);
-      // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
-      if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
-        ret[i] = std::make_shared<StatefulComputeExExecutor>(
-            dynamic_cast<StatefulComputeExExecutor*>(ret[fwd_id].get())->state_,
-            fcompute_ex, exec_type);
-      } else {
-        FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
-            op, "FStatefulCompute", vctx[i]);
-        CHECK(fcompute != nullptr)
-            << "One of FStatefulCompute and FStatefulComputeEx must be registered "
-            << "for stateful operator " << op->name;
-        ret[i] = std::make_shared<StatefulComputeExecutor>(
-            dynamic_cast<StatefulComputeExecutor*>(ret[fwd_id].get())->state_,
-            fcompute, exec_type, mutate_index);
-      }
+    OpStatePtr state = fcreate_op_state[op](
+        inode.source->attrs, vctx[i], ishape, itype);
+    FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
+        op, "FStatefulComputeEx", vctx[i]);
+    // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
+    if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+      ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
     } else {
-      FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
-      FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
-      if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
-        ret[i] = std::make_shared<FComputeExExecutor>(
-            inode.source->attrs, fcomp_ex, exec_type);
-      } else if (fcompute != nullptr) {
-        ret[i] = std::make_shared<FComputeExecutor>(
-            inode.source->attrs, fcompute, exec_type, mutate_index);
-      } else {
-        LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
-      }
+      FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
+          op, "FStatefulCompute", vctx[i]);
+      CHECK(fcompute != nullptr)
+          << "One of FStatefulCompute and FStatefulComputeEx must be registered "
+          << "for stateful operator " << op->name;
+      ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
+                                                         exec_type, mutate_index);
+    }
+  } else if (is_layer_backward.get(op, false)) {
+    CHECK_GE(inode.control_deps.size(), 1);
+    uint32_t fwd_id = inode.control_deps[0];
+    CHECK(vctx[fwd_id] == vctx[i]);
+    CHECK(ret[fwd_id] != nullptr);
+    FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
+        op, "FStatefulComputeEx", vctx[i]);
+    // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
+    if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+      ret[i] = std::make_shared<StatefulComputeExExecutor>(
+          ret[fwd_id].get()->state(), fcompute_ex, exec_type);
+    } else {
+      FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
+          op, "FStatefulCompute", vctx[i]);
+      CHECK(fcompute != nullptr)
+          << "One of FStatefulCompute and FStatefulComputeEx must be registered "
+          << "for stateful operator " << op->name;
+      ret[i] = std::make_shared<StatefulComputeExecutor>(
+          ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index);
     }
+  } else {
+    FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
+    FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
+    if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+      ret[i] = std::make_shared<FComputeExExecutor>(
+          inode.source->attrs, fcomp_ex, exec_type);
+    } else if (fcompute != nullptr) {
+      ret[i] = std::make_shared<FComputeExecutor>(
+          inode.source->attrs, fcompute, exec_type, mutate_index);
+    } else {
+      LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
+    }
+  }
+}
+
+
+// pass to attach operator executors
+Graph AttachOpExecs(Graph g) {
+  const auto& idx = g.indexed_graph();
+  OpExecVector ret(idx.num_nodes());
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    CreateOpExecs(g, &ret, i);
   }
   g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
   return g;
diff --git a/src/executor/attach_op_resource_pass.cc b/src/executor/attach_op_resource_pass.cc
index 681866296e1..56122cda6ff 100644
--- a/src/executor/attach_op_resource_pass.cc
+++ b/src/executor/attach_op_resource_pass.cc
@@ -30,12 +30,15 @@
 namespace mxnet {
 namespace exec {
 
-Graph AttachOpResources(Graph g) {
+void AttachOpResources(
+    const Graph& g,
+    const OpExecVector& op_execs,
+    size_t start_nid,
+    size_t end_nid) {
   static auto& fresource =
       nnvm::Op::GetAttr<FResourceRequest>("FResourceRequest");
   static auto& fresource_ex =
       nnvm::Op::GetAttr<FResourceRequestEx>("FResourceRequestEx");
-  auto& op_execs = nnvm::get<OpExecVector>(*g.attrs.at("op_execs"));
   const auto& vctx = g.GetAttr<ContextVector>("context");
   const auto& vdispatch = g.GetAttr<DispatchModeVector>("dispatch_mode");
   const auto& dev_masks = g.GetAttr<DevMaskVector>("dev_mask");
@@ -43,7 +46,7 @@ Graph AttachOpResources(Graph g) {
   // Use global resource pool for each executor for now.
   std::map<Context, Resource> cached_temp;
   // Resource allocation
-  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
+  for (uint32_t nid = start_nid; nid < end_nid; ++nid) {
     const auto& inode = idx[nid];
     if (inode.source->is_variable()) continue;
     const Context &ctx = vctx[nid];
@@ -84,7 +87,12 @@ Graph AttachOpResources(Graph g) {
       requested.push_back(ResourceManager::Get()->Request(ctx, ResourceRequest::kTempSpace));
     }
   }
-  return g;
 }
+
+void AttachOpResources(const Graph& g) {
+  const auto& op_execs = g.GetAttr<OpExecVector>("op_execs");
+  AttachOpResources(g, op_execs, 0, g.indexed_graph().num_nodes());
+}
+
 }  // namespace exec
 }  // namespace mxnet
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 99b1b162eae..26a24911894 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -82,6 +82,10 @@ class OpExecutor {
   virtual engine::VarHandle var() const {
     return nullptr;
   }
+  /*! \return return operator state */
+  virtual OpStatePtr state() const {
+    return OpStatePtr();
+  }
 };
 
 /*!
@@ -102,6 +106,14 @@ using ContextVector = std::vector<Context>;
  */
 using DevMaskVector = std::vector<int>;
 
+/*!
+ * \brief create OpExecutor for a node in graph
+ *
+ * \param g input graph
+ * \param p_ret OpExecVector for input and output
+ * \param i the id of the node
+ */
+void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
 /*!
  * \brief Attach OpExecutor to the graph attributes.
  *
@@ -115,12 +127,20 @@ Graph AttachOpExecs(Graph g);
  * \brief Attach Resource to the OpExecVector of the graph.
  *
  * \param g input graph need to contain op_exec attribute.
+ */
+void AttachOpResources(const Graph& g);
+/*!
+ * \brief Attach Resource to the OpExecVector
  *
- * \return graph with new attribute "op_exec" of type OpExecVector
- *  The fields on the OpExecVector are not yet been setup.
+ * \param g input graph
+ * \param op_execs OpExecutor vector
+ * \param start_nid starting node id
+ * \param end_nid end node id
  */
-Graph AttachOpResources(Graph g);
-
+void AttachOpResources(const Graph& g,
+                       const OpExecVector& op_execs,
+                       size_t start_nid,
+                       size_t end_nid);
 /*!
  * \brief Discover chance of inplace addto operators.
  *  i.e. z = plus(z, source_op), and encourage it to become z += source_op.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index e28867d5488..831b5f90023 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -912,7 +912,7 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
   }
 
   g = AttachOpExecs(g);
-  g = AttachOpResources(g);
+  AttachOpResources(g);
   graph_ = std::move(g);
 
   if (shared_exec != nullptr) {
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 140b5a5d81e..b17fae4b3cf 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -19,16 +19,78 @@
 #include <unordered_set>
 #include <iostream>
 #include "./imperative_utils.h"
+#include "./cached_op.h"
+#include "../executor/exec_pass.h"
+#include "../profiler/profiler.h"
+
 
 namespace mxnet {
 
 DMLC_REGISTER_PARAMETER(CachedOpConfig);
 
-Imperative::CachedOp::CachedOp(
+struct CachedOp::GraphInfo {  // graphs plus the backward bookkeeping derived from them
+  nnvm::Graph fwd_graph;  // forward graph; forward memory plans are stored in its attrs
+  nnvm::Graph full_graph;  // forward outputs plus the requested gradient outputs
+  std::vector<OpReqType> bwd_output_reqs;  // reqs that full_graph was last built for
+  std::vector<uint32_t> bwd_input_eid;  // full_graph entry id of each backward input
+};
+// Per-call state created in DynamicForward and read back in DynamicBackward.
+struct CachedOp::DynamicRuntime {
+  GraphInfo info;  // graphs copied from the per-context CachedOpState
+  std::vector<NDArray> buff;  // arrays backing graph entries
+  std::vector<OpStatePtr> op_states;  // per-node operator states handed to RunGraph
+};
+// Reusable state returned by GetCachedOpState; each one is tied to a single context.
+struct CachedOp::CachedOpState {
+  CachedOpState(const Context& context_,
+                const nnvm::Graph& fwd_graph_,
+                const nnvm::Graph& full_graph_) {
+    context = context_;
+    info.fwd_graph = fwd_graph_;
+    info.full_graph = full_graph_;
+    // the full graph's node/entry counts size every per-node/per-entry vector below
+    size_t max_nodes = info.full_graph.indexed_graph().num_nodes();
+    size_t max_entries = info.full_graph.indexed_graph().num_node_entries();
+    info.fwd_graph.attrs["context"] = std::make_shared<dmlc::any>(
+        std::vector<Context>(info.fwd_graph.indexed_graph().num_nodes(), context));
+    info.full_graph.attrs["context"] = std::make_shared<dmlc::any>(
+        std::vector<Context>(max_nodes, context));
+
+    buff.resize(max_entries);
+    arrays.resize(max_entries);
+    array_reqs.resize(max_entries);
+    dynamic_entries.resize(max_entries, false);
+    op_states.resize(max_nodes);
+    execs.resize(max_nodes);
+    opr_segs.resize(max_nodes);
+  }
+
+  std::mutex mutex;  // locked by the forward/backward code while it uses this state
+  Context context;  // device context this state was created for
+  GraphInfo info;  // this state's own copies of the forward and full graphs
+
+  bool recording = false;  // recording flag passed to the last StaticAllocMemory call
+  bool fwd_alloc = false;  // set once StaticAllocMemory has run for the forward graph
+  bool bwd_alloc = false;  // set once StaticAllocMemory has run for the backward part
+  bool fwd_exec_init = false;  // set once StaticInitExec has run for the forward graph
+  bool bwd_exec_init = false;  // set once StaticInitExec has run for the backward part
+
+  std::vector<NDArray> buff;  // per-entry arrays owned by this state
+  std::vector<NDArray*> arrays;  // per-entry array in use: &buff[i] or a caller array
+  std::vector<OpReqType> array_reqs;  // per-entry write request
+
+  std::vector<OpStatePtr> op_states;  // per-node operator states created in StaticRunOps
+  std::vector<std::shared_ptr<exec::OpExecutor> > execs;  // per-node executors
+  std::vector<imperative::EngineOprSeg> opr_segs;  // per-node engine operator segments
+
+  std::vector<bool> dynamic_entries;  // set for graph input/output entries
+  std::multimap<size_t, NDArray> fwd_reuse_pool;  // pool passed to/returned by AllocateMemory
+  std::multimap<size_t, NDArray> bwd_reuse_pool;  // same, for the backward allocation
+};
+
+CachedOp::CachedOp(
     const nnvm::Symbol& sym,
-    const std::vector<std::pair<std::string, std::string> >& flags,
-    const std::vector<std::string> arg_names,
-    const std::unordered_map<std::string, std::vector<NDArray> >& params) {
+    const std::vector<std::pair<std::string, std::string> >& flags) {
   using namespace nnvm;
   using namespace imperative;
   static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
@@ -68,34 +130,22 @@ Imperative::CachedOp::CachedOp(
     fwd_graph_.attrs["forward_ref_count"] =
         std::make_shared<dmlc::any>(std::move(ref_count));
 
-    inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit;
+    inlining_ = !config_.static_alloc &&
+        (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit;
   }
 
   // Set params
   {
     const auto& idx = fwd_graph_.indexed_graph();
-    std::unordered_map<std::string, size_t> arg_name_to_id;
-    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
-      const auto& name = idx[idx.input_nodes()[i]].source->attrs.name;
-      auto iter = params.find(name);
-      if (iter == params.end()) {
-        arg_name_to_id[name] = i;
-        continue;
-      }
-      fwd_params_idx_.push_back(i);
-      for (const auto& param : iter->second) {
-        params_[param.ctx()].emplace_back(param);
+    if (config_.data_indices.ndim() || config_.param_indices.ndim()) {
+      CHECK_EQ(config_.data_indices.ndim() + config_.param_indices.ndim(),
+               idx.input_nodes().size());
+    } else {
+      std::vector<uint32_t> tmp;
+      for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+        tmp.push_back(i);
       }
-    }
-
-    CHECK_EQ(arg_name_to_id.size(), arg_names.size())
-        << "CachedOp expects " << arg_name_to_id.size()
-        << " inputs, given " << arg_names.size();
-
-    for (const auto& name : arg_names) {
-      auto iter = arg_name_to_id.find(name);
-      CHECK(iter != arg_name_to_id.end()) << "Unexpected input name " << name;
-      fwd_args_idx_.push_back(iter->second);
+      config_.data_indices.assign(tmp.begin(), tmp.end());
     }
   }
 
@@ -107,9 +157,14 @@ Imperative::CachedOp::CachedOp(
     }
 
     std::vector<NodeEntry> xs;
-    std::vector<NodePtr> args = sym.ListInputs(Symbol::kReadOnlyArgs);
-    xs.reserve(args.size());
-    for (const auto& i : args) xs.emplace_back(NodeEntry{i, 0, 0});
+    const auto& idx = fwd_graph_.indexed_graph();
+    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+      auto nid = idx.input_nodes()[i];
+      if (idx.mutable_input_nodes().count(nid)) continue;
+      fwd_input_to_grad_output_[i] = xs.size();
+      xs.emplace_back(NodeEntry{idx[nid].weak_ref.lock(), 0, 0});
+    }
+
     CHECK_GT(xs.size(), 0)
         << "There are no inputs in computation graph that require gradients.";
 
@@ -125,7 +180,7 @@ Imperative::CachedOp::CachedOp(
     size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();
 
     full_graph_.outputs = fwd_graph_.outputs;
-    curr_grad_req_ = std::vector<bool>(grad_graph_.outputs.size(), true);
+    bwd_output_reqs_ = std::vector<OpReqType>(grad_graph_.outputs.size(), kWriteTo);
     for (const auto& i : grad_graph_.outputs) full_graph_.outputs.emplace_back(i);
     const auto& idx = full_graph_.indexed_graph();
 
@@ -169,7 +224,10 @@ Imperative::CachedOp::CachedOp(
   }
 }
 
-std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
+CachedOp::~CachedOp() {  // no explicit cleanup is performed here
+}
+
+std::vector<nnvm::NodeEntry> CachedOp::Gradient(
     const nnvm::NodePtr& node,
     const std::vector<nnvm::NodeEntry>& ograds) {
   using namespace nnvm;
@@ -206,13 +264,15 @@ std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
   return ret;
 }
 
-nnvm::Graph Imperative::CachedOp::GetForwardGraph(
-    const bool recording, const std::vector<NDArray*>& inputs) {
+
+bool CachedOp::SetForwardGraph(
+    GraphInfo* info,
+    const bool recording,
+    const std::vector<NDArray*>& inputs) {
   using namespace nnvm;
   using namespace imperative;
-  std::lock_guard<std::mutex> lock(mutex_);
   CHECK_EQ(inputs.size(), num_inputs());
-  nnvm::Graph& g = fwd_graph_;
+  nnvm::Graph& g = info->fwd_graph;
 
   ShapeVector shape_inputs;
   DTypeVector dtype_inputs;
@@ -237,18 +297,22 @@ nnvm::Graph Imperative::CachedOp::GetForwardGraph(
     g.attrs.erase("forward_mem_plan");
     g.attrs.erase("full_mem_plan");
   } else if (g.attrs.count(recording ? "full_mem_plan" : "forward_mem_plan")) {
-    return g;
+    return true;
   }
 
   const auto& idx = g.indexed_graph();
 
   StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
-  for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
   const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
   CHECK_EQ(stypes.size(), storage.size());
   for (size_t i = 0; i < stypes.size(); i++) {
-    if (stypes[i] != kDefaultStorage)
-      storage[i] = exec::kDynamicStorageID;
+    if (stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID;
+  }
+  for (const auto i : idx.input_nodes()) {
+    storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
+  }
+  for (size_t i = 0; i < idx.outputs().size(); ++i) {
+    storage[idx.entry_id(idx.outputs()[i])] = exec::kExternalStorageID;
   }
 
   auto mem_plan = PlanMemory(
@@ -257,51 +321,50 @@ nnvm::Graph Imperative::CachedOp::GetForwardGraph(
   g.attrs[recording ? "full_mem_plan" : "forward_mem_plan"] =
       std::make_shared<dmlc::any>(std::move(mem_plan));
 
-  return g;
+  return false;
 }
 
-nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
-    const OpStatePtr& op_state,
+bool CachedOp::SetBackwardGraph(
+    GraphInfo* info,
     const std::vector<OpReqType>& reqs,
-    const std::vector<NDArray*>& inputs) {
+    const std::vector<NDArray*>& inputs,
+    bool detect_inplace_addto) {
   using namespace nnvm;
   using namespace imperative;
   std::lock_guard<std::mutex> lock(mutex_);
-  nnvm::Graph& g = full_graph_;
-  auto& state = op_state.get_state<CachedOpState>();
-  bool req_match = true;
-  for (size_t i = 0; i < reqs.size(); ++i) {
-    if (curr_grad_req_[i] != (reqs[i] != kNullOp)) {
-      curr_grad_req_[i] = reqs[i] != kNullOp;
-      req_match = false;
-    }
-  }
-  if (!req_match) {
+  Context default_ctx = inputs[0]->ctx();
+  nnvm::Graph& g = info->full_graph;
+
+  if (info->bwd_output_reqs != reqs) {
+    info->bwd_output_reqs = reqs;
+    info->bwd_input_eid.clear();
     g = nnvm::Graph();
     g.outputs = fwd_graph_.outputs;
     for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
-      if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]);
+      if (info->bwd_output_reqs[i] == kNullOp) continue;
+      g.outputs.emplace_back(grad_graph_.outputs[i]);
     }
-    bwd_input_eid_.clear();
+    g.attrs["context"] = std::make_shared<dmlc::any>(
+        std::vector<Context>(g.indexed_graph().num_nodes(), default_ctx));
   }
 
   const auto& idx = g.indexed_graph();
 
-  if (bwd_input_eid_.size() != inputs.size()) {
-    bwd_input_eid_.clear();
+  if (info->bwd_input_eid.size() != inputs.size()) {
+    info->bwd_input_eid.clear();
     for (const auto& i : bwd_ograd_dep_) {
       auto eid = idx.entry_id(ograd_entries_[i]);
-      bwd_input_eid_.push_back(eid);
+      info->bwd_input_eid.push_back(eid);
     }
     for (const auto& i : bwd_in_dep_) {
       auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      bwd_input_eid_.push_back(eid);
+      info->bwd_input_eid.push_back(eid);
     }
     for (const auto& i : bwd_out_dep_) {
       auto eid = idx.entry_id(idx.outputs()[i]);
-      bwd_input_eid_.push_back(eid);
+      info->bwd_input_eid.push_back(eid);
     }
-    CHECK_EQ(inputs.size(), bwd_input_eid_.size());
+    CHECK_EQ(inputs.size(), info->bwd_input_eid.size());
   }
 
   size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
@@ -312,25 +375,22 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
     for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
       for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)];
     }
-    for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[bwd_input_eid_[i]];
+    for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[info->bwd_input_eid[i]];
     for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)];
     g.attrs["backward_ref_count"] = std::make_shared<dmlc::any>(std::move(ref_count));
   }
 
-  ShapeVector shapes(idx.num_node_entries(), TShape());
-  DTypeVector dtypes(idx.num_node_entries(), -1);
-  StorageTypeVector stypes(idx.num_node_entries(), -1);
-
-  for (size_t i = 0; i < num_forward_entries; ++i) {
-    shapes[i] = state.buff[i].shape();
-    dtypes[i] = state.buff[i].dtype();
-    stypes[i] = state.buff[i].storage_type();
-  }
+  auto shapes = info->fwd_graph.GetAttr<ShapeVector>("shape");
+  shapes.resize(idx.num_node_entries(), TShape());
+  auto dtypes = info->fwd_graph.GetAttr<DTypeVector>("dtype");
+  dtypes.resize(idx.num_node_entries(), -1);
+  auto stypes = info->fwd_graph.GetAttr<StorageTypeVector>("storage_type");
+  stypes.resize(idx.num_node_entries(), -1);
 
   for (size_t i = 0; i < inputs.size(); ++i) {
-    shapes[bwd_input_eid_[i]] = inputs[i]->shape();
-    dtypes[bwd_input_eid_[i]] = inputs[i]->dtype();
-    stypes[bwd_input_eid_[i]] = inputs[i]->storage_type();
+    shapes[info->bwd_input_eid[i]] = inputs[i]->shape();
+    dtypes[info->bwd_input_eid[i]] = inputs[i]->dtype();
+    stypes[info->bwd_input_eid[i]] = inputs[i]->storage_type();
   }
 
   std::pair<uint32_t, uint32_t> node_range, entry_range;
@@ -342,79 +402,353 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
                               node_range, entry_range);
   match &= CheckAndInferType(&g, std::move(dtypes), false,
                              node_range, entry_range);
-  exec::DevMaskVector dev_mask(idx.num_nodes(), inputs[0]->ctx().dev_mask());
+  exec::DevMaskVector dev_mask(idx.num_nodes(), default_ctx.dev_mask());
   match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(stypes),
                                     false, node_range, entry_range);
 
   if (!match) {
     g.attrs.erase("backward_mem_plan");
   } else if (g.attrs.count("backward_mem_plan")) {
-    return g;
+    return true;
   }
 
   StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
+  const auto& bwd_stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  for (size_t i = 0; i < bwd_stypes.size(); i++) {
+    if (bwd_stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID;
+  }
   for (size_t i = 0; i < num_forward_entries; ++i) storage[i] = exec::kExternalStorageID;
   for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
   for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID;
-  for (size_t i = 0; i < stypes.size(); i++) {
-    if (stypes[i] != kDefaultStorage)
-      storage[i] = exec::kDynamicStorageID;
-  }
 
   auto mem_plan = PlanMemory(
       &g, std::move(storage), g.GetAttr<std::vector<uint32_t> >("backward_ref_count"),
-      {num_forward_nodes, idx.num_nodes()}, {num_forward_entries, idx.num_node_entries()});
+      {num_forward_nodes, idx.num_nodes()},
+      {num_forward_entries, idx.num_node_entries()},
+      detect_inplace_addto);
   g.attrs["backward_mem_plan"] = std::make_shared<dmlc::any>(std::move(mem_plan));
 
-  return g;
+  return false;
 }
 
-void Imperative::CachedOp::Forward(
-    const std::shared_ptr<CachedOp>& op_ptr,
-    const std::vector<NDArray*>& args,
-    const std::vector<NDArray*>& outputs) {
+OpStatePtr CachedOp::GetCachedOpState(
+    const Context& ctx) {  // returns a CachedOpState for ctx, creating one if needed
+  std::lock_guard<std::mutex> lock(mutex_);  // serializes access to cached_op_states_
+  for (const auto& i : cached_op_states_[ctx]) {
+    // non-static: always share the first state of ctx; static: reuse only a unique() state
+    if (!config_.static_alloc || i.unique()) {
+      return i;
+    }
+  }
+  auto state_ptr = OpStatePtr::Create<CachedOpState>(ctx, fwd_graph_, full_graph_);
+  // no reusable state for ctx: the newly created one is remembered for later calls
+  cached_op_states_[ctx].push_back(state_ptr);
+  return state_ptr;
+}
+// (Re)allocates a state's entry arrays for the forward (keep_fwd=false) or backward graph.
+void CachedOp::StaticAllocMemory(
+    const OpStatePtr& state_ptr,
+    bool recording,
+    bool keep_fwd) {
   using namespace nnvm;
   using namespace imperative;
-  static const auto cached_op = nnvm::Op::Get("_CachedOp");
 
-  CHECK_EQ(args.size(), fwd_args_idx_.size())
-      << "CachedOp requires " << fwd_args_idx_.size()
-      << " inputs but got " << args.size();
+  auto& state = state_ptr.get_state<CachedOpState>();
+  const auto& default_ctx = state.context;
+  nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
+  const auto& idx = g.indexed_graph();
+  const auto& vstorage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
+  const auto& mem_plan = g.GetAttr<MemoryPlanVector>(
+      keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : "forward_mem_plan"));
+  std::vector<int> addto_entry;
+  if (g.attrs.count("addto_entry")) {
+    addto_entry = g.GetAttr<std::vector<int> >("addto_entry");
+  }
+  size_t start_eid =
+      keep_fwd ? state.info.fwd_graph.indexed_graph().num_node_entries() : 0;
+  size_t end_eid = idx.num_node_entries();
+  // entries before start_eid (the forward ones when keep_fwd) are left untouched
+  if (!keep_fwd) state.fwd_alloc = false;
+  state.bwd_alloc = false;
+  for (size_t i = start_eid; i < state.buff.size(); ++i) {
+    state.buff[i] = NDArray();
+    state.arrays[i] = &state.buff[i];
+    state.array_reqs[i] = kNullOp;
+    state.dynamic_entries[i] = false;
+  }
+  // graph inputs and outputs in range are flagged as dynamic entries
+  for (auto i : idx.input_nodes()) {
+    auto eid = idx.entry_id(i, 0);
+    if (eid >= start_eid) state.dynamic_entries[eid] = true;
+  }
+  for (auto i : idx.outputs()) {
+    auto eid = idx.entry_id(i);
+    if (eid >= start_eid) state.dynamic_entries[eid] = true;
+  }
+  // derive each entry's write request from the addto / inplace attributes of the graph
+  for (size_t i = start_eid; i < end_eid; ++i) {
+    if (addto_entry.size() && addto_entry[i]) {
+      state.array_reqs[i] = kAddTo;
+    } else if (vstorage_inplace[i] >= 0) {
+      state.array_reqs[i] = kWriteInplace;
+    } else if (vstorage_inplace[i] == -2) {
+      // -2 indicates that the entry is never referenced.
+      state.array_reqs[i] = kNullOp;
+    } else {
+      state.array_reqs[i] = kWriteTo;
+    }
+  }
+  // the previous pool is moved into AllocateMemory and replaced by the pool it returns
+  auto& reuse_pool = keep_fwd ? state.bwd_reuse_pool : state.fwd_reuse_pool;
+  reuse_pool = imperative::AllocateMemory(
+      g, idx, default_ctx, start_eid, end_eid, mem_plan,
+      state.arrays, &state.array_reqs, std::move(reuse_pool));
 
-  Context default_ctx = args[0]->ctx();
+  state.recording = recording;
+  if (keep_fwd) {
+    state.bwd_alloc = true;
+  } else {
+    state.fwd_alloc = true;
+  }
+}
 
+void CachedOp::StaticInitExec(
+    const OpStatePtr& state_ptr,
+    bool recording,
+    bool keep_fwd) {  // (re)builds state.execs and state.opr_segs for the selected graph
+  using namespace nnvm;
+  using namespace imperative;
 
-  std::vector<NDArray*> inputs(num_inputs());
-  for (index_t i = 0; i < fwd_args_idx_.size(); ++i) {
-    inputs[fwd_args_idx_[i]] = args[i];
+  auto& state = state_ptr.get_state<CachedOpState>();
+  const auto& default_ctx = state.context;
+  nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
+  const auto& idx = g.indexed_graph();
+  std::vector<int> skip_plus_node;
+  if (g.attrs.count("skip_plus_node")) {
+    skip_plus_node = g.GetAttr<std::vector<int> >("skip_plus_node");
   }
-  if (fwd_params_idx_.size()) {
-    CHECK(params_.find(default_ctx) != params_.end())
-        << "CachedOp is not initialized on context " << default_ctx;
+  size_t start_nid =
+      keep_fwd ? state.info.fwd_graph.indexed_graph().num_nodes() : 0;
+  size_t end_nid = idx.num_nodes();
 
-    for (size_t i = 0; i < fwd_params_idx_.size(); ++i) {
-      inputs[fwd_params_idx_[i]] = &params_[default_ctx][i];
+  if (!keep_fwd) state.fwd_exec_init = false;
+  state.bwd_exec_init = false;
+  // drop executors and segments from start_nid on before rebuilding them
+  for (size_t i = start_nid; i < state.execs.size(); ++i) {
+    state.execs[i].reset();
+    state.opr_segs[i] = EngineOprSeg();
+  }
+
+  if (!config_.static_shape) {  // no cached executors: one node per segment
+    for (size_t i = start_nid; i < end_nid; ++i) {
+      state.opr_segs[i].next_nid = i + 1;
+      state.opr_segs[i].skip = skip_plus_node.size() && skip_plus_node[i];
+    }
+  } else {
+    for (size_t i = start_nid; i < end_nid; ++i) {
+      exec::CreateOpExecs(g, &state.execs, i);
+    }
+    exec::AttachOpResources(g, state.execs, start_nid, end_nid);
+    // only nodes that are not variables and touch no dynamic entry are set up here
+    for (size_t i = start_nid; i < end_nid; ++i) {
+      bool skip = idx[i].source->is_variable();
+      for (size_t j = 0; !skip && j < idx[i].inputs.size(); ++j) {
+        skip = state.dynamic_entries[idx.entry_id(idx[i].inputs[j])];
+      }
+      for (size_t j = 0; !skip && j < idx[i].source->num_outputs(); ++j) {
+        skip = state.dynamic_entries[idx.entry_id(i, j)];
+      }
+      if (skip) continue;
+      SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs);
     }
+    // recording/backward: use the configured bulk size and exclude graph input/output entries
+    size_t bulk_size = idx.num_nodes();
+    std::unordered_set<uint32_t> excludes;
+    if (recording || keep_fwd) {
+      bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size;
+      for (const auto& i : idx.outputs()) excludes.insert(idx.entry_id(i));
+      for (const auto& i : idx.input_nodes()) excludes.insert(idx.entry_id(i, 0));
+    }
+
+    CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, excludes,
+                      state.execs, skip_plus_node, &state.opr_segs);
   }
 
-  // Initialize
+  if (keep_fwd) {
+    state.bwd_exec_init = true;
+  } else {
+    state.fwd_exec_init = true;
+  }
+}
+// Runs nodes [start_nid, end_nid): pushes cached engine segments, or invokes ops directly.
+void CachedOp::StaticRunOps(
+    const Context& default_ctx,
+    const nnvm::Graph& g,
+    const OpStatePtr& state_ptr,
+    size_t start_nid,
+    size_t end_nid) {
+  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
+
+  bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
+  bool is_training = Imperative::Get()->is_training();
+  auto& state = state_ptr.get_state<CachedOpState>();
+  const auto& idx = g.indexed_graph();
+  const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
+  const auto& op_execs = state.execs;
+
+  std::vector<NDArray*> ndinputs, ndoutputs;
+  nnvm::ShapeVector arg_shapes;
+  nnvm::DTypeVector arg_dtypes;
+  std::vector<OpReqType> req;
+  // existing executors (static_shape only) are told the current training mode
+  for (size_t i = start_nid; config_.static_shape && i < end_nid; ++i) {
+    if (op_execs[i]) op_execs[i]->op_ctx.is_train = is_training;
+  }
+  // each iteration handles one segment; next_nid is where the following one starts
+  for (size_t i = start_nid; i < end_nid; i = state.opr_segs[i].next_nid) {
+    const auto& opr_seg = state.opr_segs[i];
+    if (opr_seg.skip) continue;
+    if (opr_seg.opr != nullptr) {
+      Engine::Get()->Push(opr_seg.opr.get(), default_ctx, 0, profiling);
+    } else {  // no cached engine operator: invoke the node through Imperative
+      const nnvm::IndexedGraph::Node& node = idx[i];
+      if (node.source->is_variable()) continue;
+      auto num_outputs = node.source->num_outputs();
+      ndinputs.clear();
+      ndinputs.reserve(node.inputs.size());
+      for (const auto& j : node.inputs) {
+        ndinputs.emplace_back(state.arrays[idx.entry_id(j)]);
+        CHECK(!ndinputs.back()->is_none());
+      }
+      ndoutputs.clear();
+      ndoutputs.reserve(num_outputs);
+      req.clear();
+      req.reserve(num_outputs);
+      for (size_t j = 0; j < num_outputs; ++j) {
+        size_t eid = idx.entry_id(i, j);
+        ndoutputs.emplace_back(state.arrays[eid]);
+        req.push_back(state.array_reqs[eid]);
+        CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
+      }
+      const DispatchMode dispatch_mode = dispatch_modes[i];
+      if (createop.count(node.source->op())) {  // op registers FCreateOpState
+        arg_shapes.clear();
+        arg_dtypes.clear();
+        arg_shapes.reserve(ndinputs.size());
+        arg_dtypes.reserve(ndinputs.size());
+        for (size_t i = 0; i < ndinputs.size(); ++i) {  // NOTE(review): shadows the node index i
+          arg_shapes.emplace_back(ndinputs[i]->shape());
+          arg_dtypes.emplace_back(ndinputs[i]->dtype());
+        }
+        state.op_states[i] = createop[node.source->op()](
+            node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
+        Imperative::Get()->InvokeOp(
+            default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
+            dispatch_mode, state.op_states[i]);
+      } else if (is_layer_backward.get(node.source->op(), false)) {
+        nnvm::Node* fwd_node = node.source->control_deps[0].get();
+        auto fwd_node_id = idx.node_id(fwd_node);  // that node's op state is passed below
+        Imperative::Get()->InvokeOp(
+            default_ctx, node.source->attrs, ndinputs, ndoutputs,
+            req, dispatch_mode, state.op_states[fwd_node_id]);
+      } else {
+        Imperative::Get()->InvokeOp(
+            default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
+            dispatch_mode);
+      }
+    }
+  }
+}
+// Forward pass used when static_alloc is set: works on a state from GetCachedOpState.
+OpStatePtr CachedOp::StaticForward(
+    const Context& default_ctx,
+    const std::vector<NDArray*>& inputs,
+    const std::vector<NDArray*>& outputs) {
+  using namespace nnvm;
+  using namespace imperative;
+
   bool recording = Imperative::Get()->is_recording();
-  nnvm::Graph g = GetForwardGraph(recording, inputs);
+  auto state_ptr = GetCachedOpState(default_ctx);
+  auto& state = state_ptr.get_state<CachedOpState>();
+  std::lock_guard<std::mutex> lock(state.mutex);
+  // match: SetForwardGraph kept its memory plan and the recording mode is unchanged
+  bool match = SetForwardGraph(&state.info, recording, inputs);
+  match = match && state.recording == recording;
+
+  nnvm::Graph& g = state.info.fwd_graph;
   const auto& idx = g.indexed_graph();
-  size_t num_inputs = idx.input_nodes().size();
+  if (!state.fwd_alloc || !match)  {
+    StaticAllocMemory(state_ptr, recording, false);
+  }
 
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    CHECK_EQ(inputs[i]->ctx(), default_ctx)
-        << "CachedOp requires all inputs to live on the same context. But "
-        << idx[idx.input_nodes()[0]].source->attrs.name << " is on " << default_ctx
-        << " while " << idx[idx.input_nodes()[i]].source->attrs.name << " is on "
-        << inputs[i]->ctx();
+  if (config_.static_shape) {
+    for (auto i : config_.param_indices) {
+      auto nid = idx.input_nodes()[i];
+      if (!state.arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
+        match = false;  // a parameter array changed, so StaticInitExec runs again below
+        auto ptr = &state.buff[idx.entry_id(nid, 0)];
+        CHECK_EQ(state.arrays[idx.entry_id(nid, 0)], ptr);
+        *state.arrays[idx.entry_id(nid, 0)] = *inputs[i];  // assigns into state.buff
+        state.dynamic_entries[idx.entry_id(nid, 0)] = false;
+      }
+    }
+    for (auto i : config_.data_indices) {
+      auto eid = idx.entry_id(idx.input_nodes()[i], 0);
+      state.arrays[eid] = inputs[i];  // data inputs are rebound on every call
+    }
+  } else {
+    for (size_t i = 0; i < num_inputs(); ++i) {
+      auto nid = idx.input_nodes()[i];
+      state.arrays[idx.entry_id(nid, 0)] = inputs[i];
+    }
   }
 
-  auto op_state_ptr = OpStatePtr::Create<CachedOpState>();
-  auto& cached_op_state = op_state_ptr.get_state<CachedOpState>();
-  auto& buff = cached_op_state.buff;
-  auto& states = cached_op_state.states;
+  if (!state.fwd_exec_init || !match) {
+    StaticInitExec(state_ptr, recording, false);
+  }
+
+  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
+  const auto& shapes = g.GetAttr<ShapeVector>("shape");
+  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  // bind the caller's output arrays, creating those that are still none
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    auto eid = idx.entry_id(idx.outputs()[i]);
+    state.arrays[eid] = outputs[i];
+    if (!outputs[i]->is_none()) continue;
+    *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
+                          shapes[eid], default_ctx, true, dtypes[eid]);
+  }
+
+  StaticRunOps(default_ctx, g, state_ptr, 0, idx.num_nodes());
+
+  return recording ? state_ptr : OpStatePtr();  // the state is handed out only when recording
+}
+
+
+OpStatePtr CachedOp::DynamicForward(
+    const Context& default_ctx,
+    const std::vector<NDArray*>& inputs,
+    const std::vector<NDArray*>& outputs) {
+  using namespace nnvm;
+  using namespace imperative;
+
+  // Initialize
+  bool recording = Imperative::Get()->is_recording();
+  auto op_state = OpStatePtr::Create<DynamicRuntime>();
+  auto& runtime = op_state.get_state<DynamicRuntime>();
+  {
+    auto state_ptr = GetCachedOpState(default_ctx);
+    auto& state = state_ptr.get_state<CachedOpState>();
+    std::lock_guard<std::mutex> lock(state.mutex);
+    SetForwardGraph(&state.info, recording, inputs);
+    runtime.info.fwd_graph = state.info.fwd_graph;
+  }
+  nnvm::Graph& g = runtime.info.fwd_graph;
+  const auto& idx = g.indexed_graph();
+  size_t num_inputs = idx.input_nodes().size();
+  auto& buff = runtime.buff;
+  auto& states = runtime.op_states;
 
   // Allocate entries
   states.resize(idx.num_nodes());
@@ -446,57 +780,98 @@ void Imperative::CachedOp::Forward(
   AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
                  mem_plan, arrays, &array_reqs);
 
+  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
+  const auto& shapes = g.GetAttr<ShapeVector>("shape");
+  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    auto eid = idx.entry_id(idx.outputs()[i]);
+    arrays[eid] = outputs[i];
+    if (!outputs[i]->is_none()) continue;
+    *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
+                          shapes[eid], default_ctx, true, dtypes[eid]);
+  }
+
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
   if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
-  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
 
-  Imperative::Get()->RunGraph(
-      false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
-      std::move(ref_count), &states, dispatch_modes);
+  RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
+           std::move(ref_count), &states, dispatch_modes);
 
-  Engine::Get()->set_bulk_size(prev_bulk_size);
   Imperative::Get()->set_is_recording(recording);
 
-  for (size_t i = 0; i < idx.num_node_entries(); ++i) {
-    if (arrays[i] == &buff[i]) continue;
-    buff[i].shape_ = arrays[i]->shape_;
-    buff[i].dtype_ = arrays[i]->dtype_;
-    buff[i].storage_type_ = arrays[i]->storage_type_;
+  return op_state;
+}
+// Entry point: validates inputs, runs the static or dynamic forward path, records the op.
+void CachedOp::Forward(
+    const std::shared_ptr<CachedOp>& op_ptr,
+    const std::vector<NDArray*>& inputs,
+    const std::vector<NDArray*>& outputs) {
+  static const auto cached_op = nnvm::Op::Get("_CachedOp");
+
+  CHECK_EQ(inputs.size(), num_inputs());
+  // NOTE(review): inputs[0] is read unconditionally, so at least one input is assumed
+  Context default_ctx = inputs[0]->ctx();
+
+  const auto& idx = fwd_graph_.indexed_graph();
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    CHECK_EQ(inputs[i]->ctx(), default_ctx)
+        << "CachedOp requires all inputs to live on the same context. But "
+        << idx[idx.input_nodes()[0]].source->attrs.name
+        << " is on " << default_ctx << " while "
+        << idx[idx.input_nodes()[i]].source->attrs.name
+        << " is on " << inputs[i]->ctx();
   }
 
-  if (recording && !inlining_) {
+  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
+  // NOTE(review): a throw from the forward call below would skip restoring the bulk size
+  OpStatePtr op_state;
+  if (config_.static_alloc) {
+    op_state = StaticForward(default_ctx, inputs, outputs);
+  } else {
+    op_state = DynamicForward(default_ctx, inputs, outputs);
+  }
+
+  Engine::Get()->set_bulk_size(prev_bulk_size);
+
+  if (Imperative::Get()->is_recording() && !inlining_) {
     nnvm::NodeAttrs attrs;
     attrs.op = cached_op;
     attrs.name = "_cachedop";
     attrs.parsed = op_ptr;
     Imperative::Get()->RecordOp(
-        std::move(attrs), inputs, outputs, op_state_ptr,
+        std::move(attrs), inputs, outputs, op_state,
         &save_inputs(), &save_outputs());
   }
 }
 
 
-void Imperative::CachedOp::Backward(
+void CachedOp::DynamicBackward(
     const bool retain_graph,
-    const OpStatePtr& state,
+    const OpStatePtr& op_state,
     const std::vector<NDArray*>& inputs,
     const std::vector<OpReqType>& reqs,
     const std::vector<NDArray*>& outputs) {
   using namespace nnvm;
   using namespace imperative;
-  CHECK(!Imperative::Get()->is_recording())
-      << "CachedOp does not support higher order gradients. "
-      << "If you want to do backward with create_graph=True please "
-      << "do not use hybridize.";
 
   // Initialize
-  nnvm::Graph g = GetBackwardGraph(state, reqs, inputs);
+  Context default_ctx = outputs[0]->ctx();
+  auto& runtime = op_state.get_state<DynamicRuntime>();
+  {
+    auto state_ptr = GetCachedOpState(default_ctx);
+    auto& state = state_ptr.get_state<CachedOpState>();
+    std::lock_guard<std::mutex> lock(state.mutex);
+    state.info.fwd_graph = runtime.info.fwd_graph;
+    SetBackwardGraph(&state.info, reqs, inputs);
+    runtime.info.full_graph = state.info.full_graph;
+    runtime.info.bwd_input_eid = state.info.bwd_input_eid;
+  }
+  nnvm::Graph& g = runtime.info.full_graph;
   const auto& idx = g.indexed_graph();
-
-  auto& cached_op_state = state.get_state<CachedOpState>();
-  auto& buff = cached_op_state.buff;
-  auto& states = cached_op_state.states;
+  auto& buff = runtime.buff;
+  auto& states = runtime.op_states;
 
   size_t num_forward_outputs = fwd_graph_.outputs.size();
   size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
@@ -506,7 +881,7 @@ void Imperative::CachedOp::Backward(
   arrays.reserve(buff.size());
   for (size_t i = 0; i < buff.size(); ++i) arrays.push_back(&buff[i]);
   for (size_t i = 0; i < inputs.size(); ++i) {
-    arrays[bwd_input_eid_[i]] = inputs[i];
+    arrays[runtime.info.bwd_input_eid[i]] = inputs[i];
   }
   for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) {
     if (reqs[i] == kNullOp) continue;
@@ -530,20 +905,14 @@ void Imperative::CachedOp::Backward(
     if (ref_count[i] == 0) array_reqs[i] = kNullOp;
   }
 
-  Context default_ctx = outputs[0]->ctx();
   const auto& mem_plan = g.GetAttr<MemoryPlanVector >("backward_mem_plan");
   AllocateMemory(g, idx, default_ctx, num_forward_entries, idx.num_node_entries(),
                  mem_plan, arrays, &array_reqs);
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
-  int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
-
-  Imperative::Get()->RunGraph(
-      retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
-      std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
-
-  Engine::Get()->set_bulk_size(prev_bulk_size);
+  RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
+           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
 
   if (retain_graph) {
     buff.resize(num_forward_entries);
@@ -553,6 +922,99 @@ void Imperative::CachedOp::Backward(
   }
 }
 
+void CachedOp::StaticBackward(  // Backward pass that reuses the cached per-state arrays.
+    const bool retain_graph,  // NOTE(review): never read in this body
+    const OpStatePtr& state_ptr,
+    const std::vector<NDArray*>& inputs,
+    const std::vector<OpReqType>& reqs,
+    const std::vector<NDArray*>& outputs) {
+  using namespace nnvm;
+  using namespace imperative;
+  // The context of the first output is what gets handed to StaticRunOps below.
+  Context default_ctx = outputs[0]->ctx();
+  // Everything below reads and mutates the CachedOpState while holding its mutex.
+  auto& state = state_ptr.get_state<CachedOpState>();
+  std::lock_guard<std::mutex> lock(state.mutex);
+  // Last argument is detect_inplace_addto; a false `match` forces the re-setup steps below.
+  bool match = SetBackwardGraph(&state.info, reqs, inputs, true);
+
+  nnvm::Graph& g = state.info.full_graph;
+  const auto& idx = g.indexed_graph();
+  auto num_forward_nodes = state.info.fwd_graph.indexed_graph().num_nodes();
+
+  if (!state.bwd_alloc || !match) {
+    StaticAllocMemory(state_ptr, true, true);  // arguments are (recording, keep_fwd)
+  }
+  // Bind the caller's gradient outputs and their write requests to the graph entries.
+  if (config_.static_shape) {
+    for (auto i : config_.param_indices) {
+      const auto iter = fwd_input_to_grad_output_.find(i);
+      if (iter == fwd_input_to_grad_output_.end()) continue;  // no gradient output mapped
+      auto entry = grad_graph_.outputs[iter->second];
+      if (!idx.exist(entry.node.get())) continue;  // node is not in the indexed full graph
+      auto eid = idx.entry_id(entry);
+      if (!state.arrays[eid]->IsSame(*outputs[iter->second]) ||
+          !(state.array_reqs[eid] == reqs[iter->second])) {
+        match = false;  // makes the StaticInitExec call below run again
+        state.array_reqs[eid] = reqs[iter->second];
+        *state.arrays[eid] = *outputs[iter->second];  // assigns the NDArray, not the pointer
+        state.dynamic_entries[eid] = false;
+      }
+    }
+    for (auto i : config_.data_indices) {
+      const auto iter = fwd_input_to_grad_output_.find(i);
+      if (iter == fwd_input_to_grad_output_.end()) continue;
+      auto entry = grad_graph_.outputs[iter->second];
+      if (!idx.exist(entry.node.get())) continue;
+      auto eid = idx.entry_id(entry);
+      state.array_reqs[eid] = reqs[iter->second];
+      state.arrays[eid] = outputs[iter->second];  // rebinds the pointer on every call
+    }
+  } else {  // without static_shape every gradient output is rebound by pointer
+    for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
+      auto entry = grad_graph_.outputs[i];
+      if (!idx.exist(entry.node.get())) continue;
+      auto eid = idx.entry_id(entry);
+      state.array_reqs[eid] = reqs[i];
+      state.arrays[eid] = outputs[i];
+    }
+  }
+
+  if (!state.bwd_exec_init || !match) {
+    StaticInitExec(state_ptr, true, true);  // arguments are (recording, keep_fwd)
+  }
+  // Entries flagged dynamic take the caller's input pointer on every call.
+  for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
+    auto eid = state.info.bwd_input_eid[i];
+    if (state.dynamic_entries[eid]) state.arrays[eid] = inputs[i];
+  }
+  // Only nodes in [num_forward_nodes, idx.num_nodes()) are passed as the range to run.
+  StaticRunOps(default_ctx, g, state_ptr, num_forward_nodes, idx.num_nodes());
+}
+
+void CachedOp::Backward(
+    const bool retain_graph, const OpStatePtr& state, const std::vector<NDArray*>& inputs,
+    const std::vector<OpReqType>& reqs, const std::vector<NDArray*>& outputs) {
+  // Rejects recording mode, then runs the static or dynamic path under the backward bulk size.
+  using namespace imperative;
+  CHECK(!Imperative::Get()->is_recording())
+      << "CachedOp does not support higher order gradients. "
+      << "If you want to do backward with create_graph=True please "
+      << "do not use hybridize.";
+  const int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size);
+  try {
+    if (config_.static_alloc) {
+      StaticBackward(retain_graph, state, inputs, reqs, outputs);
+    } else {
+      DynamicBackward(retain_graph, state, inputs, reqs, outputs);
+    }
+  } catch (...) {  // restore the previous bulk size on failure too, then propagate unchanged
+    Engine::Get()->set_bulk_size(prev_bulk_size);
+    throw;
+  }
+  Engine::Get()->set_bulk_size(prev_bulk_size);
+}
+
 
 NNVM_REGISTER_OP(_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs) {
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
new file mode 100644
index 00000000000..60a40c5e4a5
--- /dev/null
+++ b/src/imperative/cached_op.h
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_IMPERATIVE_CACHED_OP_H_
+#define MXNET_IMPERATIVE_CACHED_OP_H_
+
+#include <mxnet/imperative.h>
+#include <vector>
+#include <atomic>
+#include <utility>
+#include <string>
+#include <unordered_map>
+
+namespace mxnet {
+/*! \brief CachedOp Parameters */
+struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
+  uint32_t inline_limit;  // declared below with default 2
+  uint32_t forward_bulk_size;  // default: env MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN, else 15
+  uint32_t backward_bulk_size;  // same env variable and fallback as forward_bulk_size
+  bool static_alloc;  // default false
+  bool static_shape;  // default false
+  nnvm::Tuple<uint32_t> data_indices;  // default: empty tuple
+  nnvm::Tuple<uint32_t> param_indices;  // default: empty tuple
+  DMLC_DECLARE_PARAMETER(CachedOpConfig) {
+    DMLC_DECLARE_FIELD(static_alloc)
+    .set_default(false)
+    .describe("Statically allocate memory to improve speed. "
+              "Memory usage may increase.");
+    DMLC_DECLARE_FIELD(static_shape)  // NOTE(review): no check here that static_alloc is set
+    .set_default(false)
+    .describe("Optimize for invariant input shapes between iterations. "
+              "Must also set static_alloc to True. "
+              "Change of input shapes is still allowed but slower.");
+    DMLC_DECLARE_FIELD(inline_limit)
+    .set_default(2)
+    .describe("Maximum number of operators that can be inlined.");
+    DMLC_DECLARE_FIELD(forward_bulk_size)
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during forward pass.");
+    DMLC_DECLARE_FIELD(backward_bulk_size)  // shares the forward pass's env variable
+    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .describe("Segment size of bulk execution during backward pass.");
+    DMLC_DECLARE_FIELD(data_indices)
+    .set_default(nnvm::Tuple<uint32_t>())
+    .describe("Position of argument variables.");
+    DMLC_DECLARE_FIELD(param_indices)
+    .set_default(nnvm::Tuple<uint32_t>())
+    .describe("Position of parameters.");
+  }
+};
+
+class CachedOp {  // Owns forward/gradient/full nnvm graphs and per-Context op states.
+ public:
+  CachedOp(
+      const nnvm::Symbol& sym,
+      const std::vector<std::pair<std::string, std::string> >& flags);  // (key, value) strings
+  ~CachedOp();
+  uint32_t num_inputs() {  // number of input nodes of the forward graph
+    return fwd_graph_.indexed_graph().input_nodes().size();
+  }
+  uint32_t num_outputs() {  // number of output entries of the forward graph
+    return fwd_graph_.outputs.size();
+  }
+  uint32_t num_backward_inputs() {  // combined size of the three backward dependency lists
+    return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
+  }
+  std::vector<bool>& save_inputs() {  // mutable reference to the member, not a copy
+    return save_inputs_;
+  }
+  std::vector<bool>& save_outputs() {  // mutable reference to the member, not a copy
+    return save_outputs_;
+  }
+  // NOTE(review): <unordered_set> is not among the headers this file includes directly.
+  const std::unordered_set<uint32_t>& mutable_input_nodes() {
+    return fwd_graph_.indexed_graph().mutable_input_nodes();
+  }
+  std::vector<nnvm::NodeEntry> Gradient(
+      const nnvm::NodePtr& node,
+      const std::vector<nnvm::NodeEntry>& ograds);
+  void Forward(
+      const std::shared_ptr<CachedOp>& op_ptr,
+      const std::vector<NDArray*>& inputs,
+      const std::vector<NDArray*>& outputs);
+  void Backward(
+      const bool retain_graph,
+      const OpStatePtr& state,
+      const std::vector<NDArray*>& inputs,
+      const std::vector<OpReqType>& reqs,
+      const std::vector<NDArray*>& outputs);
+
+ private:
+  struct GraphInfo;  // only forward-declared in this header
+  struct DynamicRuntime;  // only forward-declared in this header
+  struct CachedOpState;  // only forward-declared in this header
+
+  OpStatePtr GetCachedOpState(const Context& ctx);
+  bool SetForwardGraph(
+      GraphInfo* info,
+      const bool recording,
+      const std::vector<NDArray*>& inputs);
+  bool SetBackwardGraph(
+      GraphInfo* info,
+      const std::vector<OpReqType>& reqs,
+      const std::vector<NDArray*>& inputs,
+      bool detect_inplace_addto = false);  // the only defaulted parameter in this class
+  OpStatePtr DynamicForward(
+      const Context& default_ctx,
+      const std::vector<NDArray*>& inputs,
+      const std::vector<NDArray*>& outputs);
+  void DynamicBackward(
+      const bool retain_graph,
+      const OpStatePtr& op_state,
+      const std::vector<NDArray*>& inputs,
+      const std::vector<OpReqType>& reqs,
+      const std::vector<NDArray*>& outputs);
+  void StaticAllocMemory(
+      const OpStatePtr& state_ptr,
+      bool recording,
+      bool keep_fwd);
+  void StaticInitExec(
+      const OpStatePtr& state_ptr,
+      bool recording,
+      bool keep_fwd);
+  void StaticRunOps(
+      const Context& default_ctx,
+      const nnvm::Graph& g,
+      const OpStatePtr& state_ptr,
+      size_t start_nid,
+      size_t end_nid);
+  OpStatePtr StaticForward(
+      const Context& default_ctx,
+      const std::vector<NDArray*>& inputs,
+      const std::vector<NDArray*>& outputs);
+  void StaticBackward(
+      const bool retain_graph,
+      const OpStatePtr& state_ptr,
+      const std::vector<NDArray*>& inputs,
+      const std::vector<OpReqType>& reqs,
+      const std::vector<NDArray*>& outputs);
+
+  CachedOpConfig config_;
+  nnvm::Graph fwd_graph_;
+  nnvm::Graph grad_graph_;
+  nnvm::Graph full_graph_;
+  bool inlining_;
+  std::vector<nnvm::NodeEntry> ograd_entries_;
+  std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
+  std::unordered_map<uint32_t, uint32_t> fwd_input_to_grad_output_;  // input idx -> grad out idx
+  std::vector<bool> save_inputs_, save_outputs_;
+  std::vector<OpReqType> bwd_output_reqs_;
+
+  std::mutex mutex_;  // NOTE(review): <mutex> is not included directly by this header
+  std::unordered_map<Context, std::vector<OpStatePtr> > cached_op_states_;  // keyed by Context
+};
+
+using CachedOpPtr = std::shared_ptr<CachedOp>;
+
+}  // namespace mxnet
+#endif  // MXNET_IMPERATIVE_CACHED_OP_H_
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 7caf305eac7..e1654259a2f 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -19,6 +19,7 @@
 #include <unordered_set>
 #include <iostream>
 #include "./imperative_utils.h"
+#include "./cached_op.h"
 
 namespace mxnet {
 #if DMLC_CXX11_THREAD_LOCAL
@@ -266,95 +267,6 @@ void Imperative::RecordOp(
   }
 }
 
-void Imperative::RunGraph(
-    const bool retain_graph,
-    const nnvm::IndexedGraph& idx,
-    const std::vector<NDArray*> arrays,
-    size_t node_start, size_t node_end,
-    std::vector<OpReqType>&& array_reqs,
-    std::vector<uint32_t>&& ref_count,
-    std::vector<OpStatePtr> *p_states,
-    const DispatchModeVector &dispatch_modes) {
-  using namespace nnvm;
-  using namespace imperative;
-  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
-  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
-  static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
-
-  std::vector<OpStatePtr>& states = *p_states;
-  bool recording = is_recording();
-
-  std::vector<NDArray*> ndinputs, ndoutputs;
-  ShapeVector arg_shapes;
-  DTypeVector arg_dtypes;
-  std::vector<OpReqType> req;
-
-  for (size_t i = node_start; i < node_end; ++i) {
-    const nnvm::IndexedGraph::Node& node = idx[i];
-    if (node.source->op() == nullptr) continue;
-    auto num_outputs = node.source->num_outputs();
-    ndinputs.clear();
-    ndinputs.reserve(node.inputs.size());
-    for (const auto& j : node.inputs) {
-      ndinputs.emplace_back(arrays[idx.entry_id(j)]);
-      CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
-    }
-    ndoutputs.clear();
-    ndoutputs.reserve(num_outputs);
-    req.clear();
-    req.reserve(num_outputs);
-    for (size_t j = 0; j < num_outputs; ++j) {
-      size_t eid = idx.entry_id(i, j);
-      ndoutputs.emplace_back(arrays[eid]);
-      req.push_back(array_reqs[eid]);
-      CHECK(!ndoutputs.back()->is_none());
-    }
-    const Context& ctx = ndoutputs[0]->ctx();
-    const DispatchMode dispatch_mode = dispatch_modes[i];
-    if (node.source->op() == bwd_cached_op) {
-      const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
-      nnvm::Node* fwd_node = node.source->control_deps[0].get();
-      auto fwd_node_id = idx.node_id(fwd_node);
-      cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
-    } else if (createop.count(node.source->op())) {
-      arg_shapes.clear();
-      arg_dtypes.clear();
-      arg_shapes.reserve(ndinputs.size());
-      arg_dtypes.reserve(ndinputs.size());
-      for (size_t i = 0; i < ndinputs.size(); ++i) {
-        arg_shapes.emplace_back(ndinputs[i]->shape());
-        arg_dtypes.emplace_back(ndinputs[i]->dtype());
-      }
-      states[i] = createop[node.source->op()](
-          node.source->attrs, ctx, arg_shapes, arg_dtypes);
-      InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
-      if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
-    } else if (is_layer_backward.get(node.source->op(), false)) {
-      nnvm::Node* fwd_node = node.source->control_deps[0].get();
-      auto fwd_node_id = idx.node_id(fwd_node);
-      InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
-               req, dispatch_mode, states[fwd_node_id]);
-      if (recording) {
-        RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
-      }
-    } else {
-      InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
-      if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
-    }
-
-    for (const auto& j : node.inputs) {
-      size_t eid = idx.entry_id(j);
-      --ref_count[eid];
-      if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
-    }
-    for (size_t j = 0; j < ndoutputs.size(); ++j) {
-      size_t eid = idx.entry_id(i, j);
-      if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
-    }
-  }
-}
-
-
 std::vector<NDArray*> Imperative::Backward(
     const std::vector<NDArray*>& outputs,
     const std::vector<NDArray*>& ograds,
diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc
new file mode 100644
index 00000000000..464aefc220d
--- /dev/null
+++ b/src/imperative/imperative_utils.cc
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./imperative_utils.h"
+#include "./cached_op.h"
+
+namespace mxnet {
+namespace imperative {
+void RunGraph(  // Executes nodes [node_start, node_end) of idx in index order.
+    const bool retain_graph,  // only forwarded to CachedOp::Backward
+    const nnvm::IndexedGraph& idx,
+    const std::vector<NDArray*> arrays,  // NOTE(review): by-value copy, mirrors the declaration
+    size_t node_start, size_t node_end,
+    std::vector<OpReqType>&& array_reqs,
+    std::vector<uint32_t>&& ref_count,
+    std::vector<OpStatePtr> *p_states,
+    const DispatchModeVector &dispatch_modes) {
+  using namespace nnvm;
+  using namespace imperative;
+  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
+  static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
+
+  const auto imp = Imperative::Get();
+
+  std::vector<OpStatePtr>& states = *p_states;
+  bool recording = imp->is_recording();
+  // Scratch containers shared by all iterations; each is cleared before it is refilled.
+  std::vector<NDArray*> ndinputs, ndoutputs;
+  ShapeVector arg_shapes;
+  DTypeVector arg_dtypes;
+  std::vector<OpReqType> req;
+
+  for (size_t i = node_start; i < node_end; ++i) {
+    const nnvm::IndexedGraph::Node& node = idx[i];
+    if (node.source->op() == nullptr) continue;  // nodes without an operator are skipped
+    auto num_outputs = node.source->num_outputs();
+    ndinputs.clear();
+    ndinputs.reserve(node.inputs.size());
+    for (const auto& j : node.inputs) {
+      ndinputs.emplace_back(arrays[idx.entry_id(j)]);
+      CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name << " " << j.index;
+    }
+    ndoutputs.clear();
+    ndoutputs.reserve(num_outputs);
+    req.clear();
+    req.reserve(num_outputs);
+    for (size_t j = 0; j < num_outputs; ++j) {
+      size_t eid = idx.entry_id(i, j);
+      ndoutputs.emplace_back(arrays[eid]);
+      req.push_back(array_reqs[eid]);
+      CHECK(array_reqs[eid] == kNullOp || !ndoutputs.back()->is_none());
+    }
+    const Context& ctx = ndoutputs[0]->ctx();
+    const DispatchMode dispatch_mode = dispatch_modes[i];
+    if (node.source->op() == bwd_cached_op) {
+      const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
+      nnvm::Node* fwd_node = node.source->control_deps[0].get();
+      auto fwd_node_id = idx.node_id(fwd_node);
+      cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
+    } else if (createop.count(node.source->op())) {
+      arg_shapes.clear();
+      arg_dtypes.clear();
+      arg_shapes.reserve(ndinputs.size());
+      arg_dtypes.reserve(ndinputs.size());
+      for (size_t k = 0; k < ndinputs.size(); ++k) {  // k, so the node index i is not shadowed
+        arg_shapes.emplace_back(ndinputs[k]->shape());
+        arg_dtypes.emplace_back(ndinputs[k]->dtype());
+      }
+      states[i] = createop[node.source->op()](  // i is the node id: state is stored per node
+          node.source->attrs, ctx, arg_shapes, arg_dtypes);
+      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, states[i]);
+      if (recording) {
+        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[i]);
+      }
+    } else if (is_layer_backward.get(node.source->op(), false)) {
+      nnvm::Node* fwd_node = node.source->control_deps[0].get();
+      auto fwd_node_id = idx.node_id(fwd_node);  // reuse the state stored for the forward node
+      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
+               req, dispatch_mode, states[fwd_node_id]);
+      if (recording) {
+        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, states[fwd_node_id]);
+      }
+    } else {
+      imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode);
+      if (recording) {
+        imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
+      }
+    }
+    // Reset every array whose reference count is zero after this node has run.
+    for (const auto& j : node.inputs) {
+      size_t eid = idx.entry_id(j);
+      --ref_count[eid];
+      if (ref_count[eid] == 0) *arrays[eid] = NDArray();
+    }
+    for (size_t j = 0; j < ndoutputs.size(); ++j) {
+      size_t eid = idx.entry_id(i, j);
+      if (ref_count[eid] == 0) *arrays[eid] = NDArray();
+    }
+  }
+}
+
+}  // namespace imperative
+}  // namespace mxnet
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 06b7e058dd1..726531d0299 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -23,6 +23,7 @@
 #include <utility>
 #include <algorithm>
 #include <vector>
+#include <map>
 #include <string>
 #include "../executor/graph_executor.h"
 #include "../executor/exec_pass.h"
@@ -38,11 +39,24 @@ namespace mxnet {
 namespace imperative {
 
 struct MemoryPlanInfo {
-  uint32_t sid;
+  int storage_id;  // this entry's value from the graph's "storage_id" attribute
+  uint32_t root;  // entry id whose buffer backs this entry; the entry's own id for a root
   size_t size;
   bool inplace;
 };
 
+struct EngineOprDeleter {  // deleter type used by EngineOprSeg::opr
+  void operator()(engine::Opr* handle) {
+    Engine::Get()->DeleteOperator(handle);  // deletion is delegated to the engine
+  }
+};
+
+struct EngineOprSeg {  // one slot per node id, filled in by CreateEngineOpSeg
+  bool skip;  // CreateEngineOpSeg sets this to true for a segment that holds no executors
+  size_t next_nid;  // node id stored as the (exclusive) end of the segment
+  std::unique_ptr<engine::Opr, EngineOprDeleter> opr;  // from CreateEngineOp, or nullptr
+};
+
 using MemoryPlanVector = std::vector<MemoryPlanInfo>;
 
 inline Context GetContext(const nnvm::NodeAttrs& attrs,
@@ -715,10 +729,12 @@ inline std::vector<Context> PlaceDevice(const nnvm::IndexedGraph& idx) {
 
 
 inline MemoryPlanVector PlanMemory(
-    nnvm::Graph* p_g, nnvm::StorageVector&& storage,
+    nnvm::Graph* p_g,
+    nnvm::StorageVector&& storage,
     const std::vector<uint32_t>& ref_count,
     const std::pair<uint32_t, uint32_t>& node_range = {0, 0},
-    const std::pair<uint32_t, uint32_t>& entry_range = {0, 0}) {
+    const std::pair<uint32_t, uint32_t>& entry_range = {0, 0},
+    bool detect_inplace_addto = false) {
   using namespace nnvm;
   nnvm::Graph& g = *p_g;
   const auto& idx = g.indexed_graph();
@@ -728,31 +744,31 @@ inline MemoryPlanVector PlanMemory(
   g.attrs["ref_count"] = std::make_shared<dmlc::any>(ref_count);
   g.attrs["storage"] = std::make_shared<dmlc::any>(std::move(storage));
   g = nnvm::ApplyPass(g, "PlanMemory");
+  if (detect_inplace_addto) g = exec::DetectInplaceAddTo(g);
 
   const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
   const auto& shapes = g.GetAttr<ShapeVector>("shape");
-  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
-  auto storage_ids = g.MoveCopyAttr<StorageVector>("storage_id");
-  auto storage_inplace = g.MoveCopyAttr<std::vector<int> >("storage_inplace_index");
+  const auto& storage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
+  const auto& storage_ids = g.GetAttr<StorageVector>("storage_id");
   uint32_t entry_start = entry_range.first;
   uint32_t entry_end =
       entry_range.second > entry_start ? entry_range.second : idx.num_node_entries();
   MemoryPlanVector mem_plan(idx.num_node_entries());
-  std::unordered_map<int, uint32_t> sid_to_loc;
+  std::unordered_map<int, uint32_t> sid_to_root;
 
   for (uint32_t i = entry_start; i < entry_end; ++i) {
-    if (stypes[i] != kDefaultStorage) continue;
     if (storage_ids[i] < 0) {
-      mem_plan[i] = {i, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), false};
-    } else if (!sid_to_loc.count(storage_ids[i])) {
+      mem_plan[i] = {storage_ids[i], i, 0, false};
+    } else if (!sid_to_root.count(storage_ids[i])) {
       CHECK_LT(storage_inplace[i], 0);
-      sid_to_loc[storage_ids[i]] = i;
-      mem_plan[i].sid = i;
-      mem_plan[i].size = mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size();
+      sid_to_root[storage_ids[i]] = i;
+      mem_plan[i] = {storage_ids[i], i,
+                     mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(),
+                     false};
     } else {
-      uint32_t loc = sid_to_loc[storage_ids[i]];
-      mem_plan[i] = {loc, 0, storage_inplace[i] >= 0};
-      mem_plan[loc].size = std::max(mem_plan[loc].size,
+      uint32_t root = sid_to_root[storage_ids[i]];
+      mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0};
+      mem_plan[root].size = std::max(mem_plan[root].size,
           mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size());
     }
   }
@@ -761,39 +777,213 @@ inline MemoryPlanVector PlanMemory(
 }
 
 
-inline void AllocateMemory(const nnvm::Graph& g,
-                    const nnvm::IndexedGraph& idx,
-                    const Context& default_ctx,
-                    const uint32_t entry_start, const uint32_t entry_end,
-                    const MemoryPlanVector& mem_plan,
-                    const std::vector<NDArray*>& arrays,
-                    std::vector<OpReqType> *array_reqs) {
+inline std::multimap<size_t, NDArray> AllocateMemory(  // returns the buffers this call used
+    const nnvm::Graph& g,
+    const nnvm::IndexedGraph& idx,
+    const Context& default_ctx,
+    const uint32_t entry_start, const uint32_t entry_end,
+    const MemoryPlanVector& mem_plan,
+    const std::vector<NDArray*>& arrays,
+    std::vector<OpReqType> *array_reqs,
+    std::multimap<size_t, NDArray>&& pool = std::multimap<size_t, NDArray>()) {
   using namespace nnvm;
   const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
   const auto& shapes = g.GetAttr<ShapeVector>("shape");
   const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
 
+  std::multimap<size_t, NDArray> new_pool;  // buffers taken from `pool` plus new ones
+  // Buffers still left in `pool` when the loop ends are not copied into new_pool.
   for (uint32_t i = entry_start; i < entry_end; ++i) {
-    if (!arrays[i]->is_none()) continue;
-    if (stypes[i] == kDefaultStorage) {
-      if (mem_plan[i].sid == i) {
-        CHECK_GT(mem_plan[i].size, 0);
+    if (mem_plan[i].storage_id == exec::kExternalStorageID) continue;  // entry left untouched
+    CHECK(arrays[i]->is_none());  // every remaining entry must still be unallocated
+    if (mem_plan[i].storage_id == exec::kDynamicStorageID) {
+      *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
+                           shapes[i], default_ctx, true, dtypes[i]);
+      continue;
+    }
+    CHECK_EQ(stypes[i], kDefaultStorage);
+    if (mem_plan[i].root == i) {  // root entry: gets its own buffer, reused or new
+      CHECK_GT(mem_plan[i].size, 0);
+      auto iter = pool.lower_bound(mem_plan[i].size);  // first pooled key >= required size
+      if (iter != pool.end()) {
+        *arrays[i] = iter->second.AsArray(shapes[i], dtypes[i]);
+        new_pool.insert(*iter);
+        pool.erase(iter);  // each pooled buffer is handed out at most once
+      } else {
         NDArray buff(TShape({static_cast<nnvm::dim_t>(mem_plan[i].size)}),
                      default_ctx, true, mshadow::kUint8);
         *arrays[i] = buff.AsArray(shapes[i], dtypes[i]);
+        new_pool.insert({mem_plan[i].size, buff});
+      }
+    } else {  // non-root entry: view onto the array already set up for its root
+      CHECK_GE(mem_plan[mem_plan[i].root].storage_id, 0);
+      *arrays[i] = arrays[mem_plan[i].root]->AsArray(shapes[i], dtypes[i]);
+      if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
+        array_reqs->at(i) = kWriteInplace;
+      }
+    }
+  }
+
+  return new_pool;
+}
+
+inline void SetupOpExec(  // Binds node nid's arrays and reqs to exec, then calls Setup().
+    const nnvm::Graph& g,
+    size_t nid,
+    const std::shared_ptr<exec::OpExecutor>& exec,
+    const std::vector<NDArray*>& arrays,  // by reference: only read here, no copy needed
+    const std::vector<OpReqType>& array_reqs) {  // by reference for the same reason
+  const auto& idx = g.indexed_graph();
+  const auto& inode = idx[nid];
+  CHECK_EQ(exec->in_array.size(), 0U);  // the executor must not have been populated yet
+  CHECK_EQ(exec->out_array.size(), 0U);
+  for (const auto& e : inode.inputs) {
+    CHECK(!arrays[idx.entry_id(e)]->is_none()) << inode.source->attrs.name;
+    exec->in_array.push_back(*arrays[idx.entry_id(e)]);
+  }
+  for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
+    uint32_t eid = idx.entry_id(nid, index);
+    CHECK(!arrays[eid]->is_none()) << inode.source->attrs.name;
+    exec->out_array.push_back(*arrays[eid]);
+    exec->req.push_back(array_reqs[eid]);
+  }
+
+  exec->Setup();
+}
+
+inline Engine::OprHandle CreateEngineOp(  // Wraps the given executors in one engine operator.
+    const Context& default_ctx,
+    const std::vector<std::shared_ptr<exec::OpExecutor> >& execs) {
+  CHECK_GT(execs.size(), 0);
+  std::vector<Engine::VarHandle> use_vars, mutate_vars;
+  // Inputs become use_vars; requested resources, outputs and executor vars become mutate_vars.
+  for (const auto& exec : execs) {
+    CHECK_GT(exec->out_array.size(), 0);
+    CHECK(execs.size() == 1 || exec->exec_type() == ExecType::kSync);  // bulking needs kSync
+
+    // the variables
+    for (const auto& nd : exec->in_array) {
+      use_vars.push_back(nd.var());
+    }
+    for (auto& r : exec->op_ctx.requested) {
+      mutate_vars.push_back(r.var);
+    }
+    for (auto& nd : exec->out_array) {
+      mutate_vars.push_back(nd.var());
+    }
+    if (exec->var() != nullptr) {
+      mutate_vars.push_back(exec->var());
+    }
+  }
+
+  // dedup vars
+  Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);
+  bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask;
+  bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync;
+  // execs is captured by value, so the closure keeps its own shared_ptr copies.
+  auto exec_fun = [execs, is_async, is_gpu] (
+      RunContext ctx, Engine::CallbackOnComplete on_complete) {
+    if (is_async) {
+      execs[0]->op_ctx.async_on_complete = on_complete;  // the async op gets the callback
+    }
+    for (const auto& exec : execs) exec->Run(ctx, is_gpu);
+    // call on complete only if it is async op
+    if (!is_async) {
+      if (is_gpu) {
+      #if MXNET_USE_CUDA
+        // Wait GPU kernel to finish.
+        ctx.get_stream<gpu>()->Wait();
+      #else
+        LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+      #endif
+      }
+      on_complete();
+    }
+  };
+
+  return Engine::Get()->NewOperator(
+      exec_fun, use_vars, mutate_vars, FnProperty::kNormal);
+}
+
+inline void CreateEngineOpSeg(  // Splits [start_nid, end_nid) into engine operator segments.
+    const nnvm::IndexedGraph& idx,
+    const Context& default_ctx,  // by reference: only passed on to CreateEngineOp
+    const size_t start_nid,
+    const size_t end_nid,
+    const size_t bulk_size,
+    const std::unordered_set<uint32_t>& excludes,
+    const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
+    const std::vector<int>& skip_plus_node,  // by reference: only read, no per-call copy
+    std::vector<EngineOprSeg> *opr_segs) {
+  size_t seg_start = start_nid;
+  std::vector<std::shared_ptr<exec::OpExecutor> > seg_execs;
+  for (size_t nid = start_nid; nid < end_nid; ++nid) {
+    const auto& node = idx[nid];
+    if (node.source->is_variable()) continue;
+    if (skip_plus_node.size() && skip_plus_node[nid]) continue;  // empty vector skips nothing
+    auto& exec = execs[nid];
+    bool is_async = exec->exec_type() != ExecType::kSync;
+    bool valid = exec->out_array.size() > 0;
+
+    // Stop at async nodes and invalid node (due to input/output is not allocated)
+    bool stop = is_async || !valid || seg_execs.size() >= bulk_size;
+    for (size_t i = 0; i < node.inputs.size() && !stop; ++i) {
+      if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true;
+    }
+    auto num_outputs = node.source->num_outputs();
+    for (size_t i = 0; i < num_outputs && !stop; ++i) {
+      if (excludes.count(idx.entry_id(nid, i))) stop = true;
+    }
+
+    // Create opr segment for previous nodes.
+    if (stop && nid > seg_start) {
+      auto& seg = (*opr_segs)[seg_start];  // a segment is recorded at its first node id
+      if (seg_execs.size()) {
+        seg = EngineOprSeg{false, nid};
+        seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
       } else {
-        *arrays[i] = arrays[mem_plan[i].sid]->AsArray(shapes[i], dtypes[i]);
-        if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
-          array_reqs->at(i) = kWriteInplace;
-        }
+        seg = EngineOprSeg{true, nid, nullptr};  // nothing to run: marked as skip
       }
+      seg_start = nid;
+      seg_execs.clear();
+    }
+
+    seg_execs.push_back(exec);
+    // Async and invalid nodes always end up as a single-node segment of their own.
+    auto& seg = (*opr_segs)[nid];
+    if (is_async) {
+      seg = EngineOprSeg{false, nid + 1};
+      seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
+      seg_execs.clear();
+      seg_start = nid + 1;
+    } else if (!valid) {
+      seg = EngineOprSeg{false, nid + 1, nullptr};  // not skipped, but has no engine operator
+      seg_execs.clear();
+      seg_start = nid + 1;
+    }
+  }
+  // The last segment
+  if (end_nid > seg_start) {
+    auto& seg = (*opr_segs)[seg_start];
+    if (seg_execs.size()) {
+      seg = EngineOprSeg{false, end_nid};
+      seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
     } else {
-      *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
-                           shapes[i], default_ctx, true, dtypes[i]);
+      seg = EngineOprSeg{true, end_nid, nullptr};
     }
   }
 }
 
+
+void RunGraph(const bool retain_graph,  // declaration only; this header has no definition
+              const nnvm::IndexedGraph& idx,
+              const std::vector<NDArray*> arrays,  // NOTE(review): vector is passed by value
+              size_t node_start, size_t node_end,  // range of node ids handed to the function
+              std::vector<OpReqType>&& array_reqs,
+              std::vector<uint32_t>&& ref_count,
+              std::vector<OpStatePtr> *p_states,
+              const DispatchModeVector &dispatch_modes);
+
 }  // namespace imperative
 }  // namespace mxnet
 
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 451fde2eb86..5701a5df5a0 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -22,6 +22,7 @@
 from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
 from common import setup_module, with_seed, assertRaises, teardown
 import numpy as np
+from numpy.testing import assert_array_equal
 from nose.tools import raises, assert_raises
 from copy import deepcopy
 import warnings
@@ -1124,7 +1125,6 @@ def test_hybrid_multi_context():
     net.hybridize()
     net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy()
 
-
 @with_seed()
 def test_zero_grad():
     data = mx.nd.random.uniform(shape=(3,3))
@@ -1137,6 +1137,60 @@ def test_zero_grad():
     grad = net.collect_params()['test_zero_grad_weight'].grad()
     assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)
 
+def check_hybrid_static_memory(**kwargs):
+    x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
+    x.attach_grad()
+
+    net1 = gluon.model_zoo.vision.get_resnet(
+        1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
+    net2 = gluon.model_zoo.vision.get_resnet(
+        1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
+    net2.hybridize(**kwargs)
+    net1(x)
+    net2(x)
+
+    def test(net, x):
+        with mx.autograd.record():
+            y = net(x) + net(x)
+            y.backward()
+
+        grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'}
+
+        return y, grads
+
+    y1, grads1 = test(net1, x)
+    y2, grads2 = test(net2, x)
+
+    assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
+    for key in grads1:
+        assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)
+
+def test_hybrid_static_memory():
+    check_hybrid_static_memory()
+    check_hybrid_static_memory(static_alloc=True)
+    check_hybrid_static_memory(static_alloc=True, static_shape=True)
+
+def check_hybrid_static_memory_switching(**kwargs):
+    net = gluon.model_zoo.vision.get_resnet(
+        1, 18, pretrained=True, ctx=mx.context.current_context())
+    net.hybridize(**kwargs)
+
+    x = mx.nd.random.uniform(shape=(4, 3, 32, 32))
+    net(x)
+    with mx.autograd.record():
+        y = net(x)
+        y.backward()
+    x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
+    net(x)
+    with mx.autograd.record():
+        y = net(x)
+        y.backward()
+    mx.nd.waitall()
+
+def test_hybrid_static_memory_switching():
+    check_hybrid_static_memory_switching()
+    check_hybrid_static_memory_switching(static_alloc=True)
+    check_hybrid_static_memory_switching(static_alloc=True, static_shape=True)
 
 @with_seed()
 def test_hook():
@@ -1231,6 +1285,17 @@ def test_legacy_save_params():
     model.load_params('test.params', ctx=mx.cpu())
 
 
+def test_hybrid_static_memory_recording():
+    net = gluon.model_zoo.vision.get_resnet(
+        1, 18, pretrained=True, ctx=mx.context.current_context())
+    net.hybridize(static_alloc=True)
+
+    x = mx.nd.random.uniform(shape=(1, 3, 32, 32))
+    with mx.autograd.record(True):
+        net(x)
+    net(x)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services