You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2020/03/23 18:22:07 UTC

[incubator-mxnet] branch master updated: Add simplified HybridBlock.forward without F (#17530)

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 83b5170  Add simplified HybridBlock.forward without F (#17530)
83b5170 is described below

commit 83b51703ed354f41024423f140de38df2ba22d50
Author: Leonard Lausen <la...@amazon.com>
AuthorDate: Mon Mar 23 11:21:23 2020 -0700

    Add simplified HybridBlock.forward without F (#17530)
    
    Users can now implement HybridBlock.forward instead of HybridBlock.hybrid_forward.
    HybridBlock.forward has the same signature as Block.forward. For example:
    
      class MyBlock(mx.gluon.HybridBlock):
          def __init__(self, *, prefix=None, params=None):
              super().__init__(prefix, params)
              with self.name_scope():
                  self.dense = mx.gluon.nn.Dense(units=10)
                  self.weight = self.params.get('weight', allow_deferred_init=True)
          def infer_shape(self, x):
              # weight is added to x, so its shape follows x's feature dimension
              self.weight.shape = (x.shape[1], )
          def forward(self, x):
              return self.dense(x + self.weight.data(x.context))
    
    Hybridization of HybridBlock.forward is based on a deferred computation mode in
    the MXNet backend, which enables recording computation via tracing in the
    mxnet.nd and mxnet.np interfaces. The recorded computation can be exported to a
    symbolic representation and is used for optimized execution with the CachedOp.
    
    As tracing is based on the imperative APIs, users can access shape information
    of the arrays. Because x.shape for an array x is a Python tuple, any use of that
    shape is recorded as a constant in the graph and may restrict the recorded graph
    to inputs of the same shape.
    
    As part of the change from hybrid_forward to forward, we also disable support
    for parameter shape inference in the MXNet backend in the case of deferred
    parameter initialization. Shape inference in the backend was limited and, by
    its very nature, did not support dynamic-shape operators. Instead, users should
    now always implement HybridBlock.infer_shape to set the parameter shapes if the
    parameter shapes were not set in HybridBlock.__init__. See the example above.
    
    An example of the internal deferred compute APIs is:
    
      # dc is mxnet._deferred_compute
      a = mx.np.arange(10)
      dc.set_variable(a, mx.sym.var('a').as_np_ndarray())  # a becomes the graph input 'a'
      with dc.context():  # operations in this context are recorded for deferred compute
          b = a ** 2
      symbol = dc.get_symbol(b)  # Symbol that computes b from the variable 'a'
---
 include/mxnet/c_api.h                          |  42 ++
 include/mxnet/imperative.h                     |  90 ++++-
 include/mxnet/ndarray.h                        |  61 ++-
 python/mxnet/__init__.py                       |   2 +
 python/mxnet/_deferred_compute.py              | 106 +++++
 python/mxnet/gluon/block.py                    | 103 ++++-
 python/mxnet/gluon/parameter.py                |   8 +-
 python/mxnet/ndarray/ndarray.py                |  11 +-
 python/mxnet/ndarray/sparse.py                 |   1 +
 python/mxnet/numpy/multiarray.py               |  16 +-
 src/api/operator/utils.cc                      |  21 +-
 src/c_api/c_api.cc                             |  17 +-
 src/c_api/c_api_ndarray.cc                     |  55 ++-
 src/imperative/cached_op.h                     |   2 +-
 src/imperative/imperative.cc                   | 199 ++++++++-
 src/imperative/imperative_utils.h              |  47 ++-
 src/ndarray/ndarray.cc                         | 121 +++++-
 tests/python/gpu/test_deferred_compute_gpu.py  |  33 ++
 tests/python/unittest/test_deferred_compute.py | 536 +++++++++++++++++++++++++
 19 files changed, 1357 insertions(+), 114 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 637b31d..638385b 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1423,6 +1423,44 @@ MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle,
                                        CachedOpMonitorCallback callback,
                                        bool monitor_all);
 
+/*!
+ * \brief Get current status of deferred compute mode
+ * \param curr returns the current status.
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayIsDeferredCompute(int *curr);
+
+/*!
+ * \brief set whether to enable deferred compute mode
+ * \param deferred_compute_enabled 1 to enable, 0 to disable.
+ * \param prev returns the previous status before this set.
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArraySetIsDeferredCompute(int deferred_compute_enabled, int *prev);
+
+/*!
+ * \brief Associate variables with deferred compute arrays
+ * \param arrays ndarray handles to be matched with variables
+ * \param variables symbol handles of variables to be matched with ndarrays
+ * \param num number of arrays and variables respectively
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArraySetDeferredComputeVariable(NDArrayHandle *arrays,
+                                                  SymbolHandle *variables,
+                                                  int num);
+
+/*!
+ * \brief Convert the graph constructed during deferred computation mode to a Symbol.
+ *
+ * \param output_handles ndarray handles of the outputs the returned Symbol will compute
+ * \param num_outputs number of handles in output_handles
+ * \param out grouped output symbol handle
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle *output_handles,
+                                                int num_outputs,
+                                                SymbolHandle *out);
+
 //--------------------------------------------
 // Part 3: symbolic configuration generation
 //--------------------------------------------
@@ -1501,6 +1539,10 @@ MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
                                           const char **return_type DEFAULT(NULL));
 /*!
  * \brief Create an AtomicSymbol.
+ *
+ * A Symbol is said to be atomic if it is not composed of other Symbols. Atomic
+ * Symbols can be composed.
+ *
  * \param creator the AtomicSymbolCreator
  * \param num_param the number of parameters
  * \param keys the keys to the params
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 6a367b3..ca6f935 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -56,7 +56,8 @@ class Imperative {
     OpReqType grad_req;
     OpStatePtr state;
     std::vector<NDArray> outputs;
-    std::vector<NDArray> out_grads;
+    std::vector<NDArray> out_grads;  // used to hold gradient arrays the user is
+                                     // interested in (marked variables)
     bool fresh_out_grad;
 
     AGInfo() :
@@ -79,7 +80,7 @@ class Imperative {
     }
 
     static bool IsNone(const NDArray& arr) {
-      return arr.entry_.node == nullptr || arr.entry_.node->info.empty();
+      return arr.autograd_entry_.node == nullptr || arr.autograd_entry_.node->info.empty();
     }
 
     static bool IsVariable(const nnvm::ObjectPtr& node) {
@@ -88,6 +89,73 @@ class Imperative {
              && info.out_grads.size() == 1;
     }
   };
+
+  /*! \brief DCInfo datastructure to enable deferred computation */
+  class DCInfo {
+   public:
+    explicit DCInfo(const std::vector<NDArray *> &inputs,
+                    const std::vector<NDArray *> &outputs);
+
+    /*! \brief Compute the outputs of the associated operator. */
+    static void Compute(const NDArray &arr);
+
+    static DCInfo &Get(const nnvm::ObjectPtr &node) {
+      return dmlc::get<DCInfo>(node->info);
+    }
+
+    static bool IsNone(const NDArray &arr) {
+      return arr.deferredcompute_entry_.node == nullptr ||
+             arr.deferredcompute_entry_.node->info.empty();
+    }
+
+    static bool IsComputed(const NDArray &arr) {
+      return IsNone(arr) ||
+        dmlc::get<DCInfo>(arr.deferredcompute_entry_.node->info).is_computed_;
+    }
+
+    static DCInfo &Create(const nnvm::ObjectPtr &node,
+                          const std::vector<NDArray *> &inputs,
+                          const std::vector<NDArray *> &outputs);
+
+   private:
+    friend class Imperative;
+
+    /*! \brief Copies of input NDArrays
+     *
+     * If respective input NDArray is deallocated on the frontend, we still need
+     * to keep a copy around to facilitate deferred computation of this array.
+     * The copies share the chunk.
+     *
+     * They are automatically deallocated after computation finished.
+     */
+    std::vector<NDArray> inputs_;
+
+    /*! \brief Handles of input NDArrays used by frontend
+     *
+     * Frontend may request conversion to Symbol, specifying a list of NDArray
+     * handles corresponding to inputs and outputs of the Symbol. We store the
+     * handles used by frontend to facilitate matching in
+     * GetDeferredComputeSymbol.
+     *
+     * Note that the frontend may have deallocated the NDArray* and the
+     * input_handles stored here may point to invalid memory.
+     */
+    std::vector<const NDArray *> input_handles_;
+
+    /*! \brief Copies of output NDArrays
+     *
+     * If respective output NDArray is deallocated on the frontend, we still
+     * need to keep a copy around to facilitate deferred computation of arrays
+     * relying on the output array. The copies share the chunk.
+     *
+     * They are automatically deallocated after computation finished.
+     */
+    std::vector<NDArray> outputs_;
+
+    /*! \brief Remember if the outputs associated with this DCInfo have been computed already */
+    bool is_computed_ = false;
+  };
+
   /*! \brief whether operator recording is on. */
   bool is_training() const {
     return is_train_;
@@ -108,6 +176,14 @@ class Imperative {
       is_recording_ = is_recording;
       return old;
   }
+  /*! \brief whether deferred compute mode is on. */
+  bool is_deferred_compute() const { return is_deferred_compute_; }
+  /*! \brief turn deferred compute mode on or off (thread-local); return previous state. */
+  bool set_is_deferred_compute(bool is_deferred_compute) {
+    bool old = is_deferred_compute_;
+    is_deferred_compute_ = is_deferred_compute;
+    return old;
+  }
   /*! \brief return current numpy compatibility status,
    *  GlobalOn(2), ThreadLocalOn(1), Off(0).
    * */
@@ -143,6 +219,14 @@ class Imperative {
                 const OpStatePtr& state = OpStatePtr(),
                 std::vector<bool>* p_save_inputs = nullptr,
                 std::vector<bool>* p_save_outputs = nullptr);
+  /*! \brief record operator with its inputs and outputs for deferred compute. */
+  void RecordDeferredCompute(nnvm::NodeAttrs&& attrs,
+                             const std::vector<NDArray*>& inputs,
+                             const std::vector<NDArray*>& outputs);
+  /*! \brief obtain symbol representation of deferred compute session. */
+  nnvm::Symbol GetDeferredComputeSymbol(const std::vector<NDArray *> &outputs);
+  /*! \brief associate arrays with variables for deferred compute */
+  void SetDeferredComputeVariable(NDArrayHandle *arrays, SymbolHandle *variables, const int num);
   /*! \brief */
   OpStatePtr Invoke(const Context& default_ctx,
                     const nnvm::NodeAttrs& attrs,
@@ -204,12 +288,14 @@ class Imperative {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local bool is_train_;
   static thread_local bool is_recording_;
+  static thread_local bool is_deferred_compute_;
   // TOOD(junwu): Added numpy compatibility switch for backward compatibility.
   // Delete it in the next major release.
   static thread_local bool is_np_shape_thread_local_;
 #else
   static MX_THREAD_LOCAL bool is_train_;
   static MX_THREAD_LOCAL bool is_recording_;
+  static MX_THREAD_LOCAL bool is_deferred_compute_;
   // TOOD(junwu): Added numpy compatibility switch for backward compatibility.
   // Delete it in the next major release.
   static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index fd7cc38..81cae0f 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -83,7 +83,7 @@ class NDArray {
  public:
   /*! \brief default constructor */
   NDArray()
-    : entry_(nullptr) {
+    : autograd_entry_(nullptr) {
   }
   /*!
    * \brief constructs a new dynamic NDArray
@@ -98,7 +98,7 @@ class NDArray {
         shape_(shape),
         dtype_(dtype),
         storage_type_(kDefaultStorage),
-        entry_(nullptr) {
+        autograd_entry_(nullptr) {
   }
   /*! \brief constructor for NDArray with storage type
    */
@@ -117,7 +117,7 @@ class NDArray {
         shape_(),
         dtype_(dtype),
         storage_type_(kDefaultStorage),
-        entry_(nullptr) {
+        autograd_entry_(nullptr) {
   }
   /*!
    * \brief constructing a static NDArray that shares data with TBlob
@@ -131,7 +131,7 @@ class NDArray {
         shape_(data.shape_),
         dtype_(data.type_flag_),
         storage_type_(kDefaultStorage),
-        entry_(nullptr) {
+        autograd_entry_(nullptr) {
   }
 
   /*!
@@ -149,7 +149,7 @@ class NDArray {
         }),
         shape_(data.shape_),
         dtype_(data.type_flag_), storage_type_(kDefaultStorage),
-        entry_(nullptr) {
+        autograd_entry_(nullptr) {
   }
 
   /*! \brief create ndarray from shared memory */
@@ -158,7 +158,7 @@ class NDArray {
         shape_(shape),
         dtype_(dtype),
         storage_type_(kDefaultStorage),
-        entry_(nullptr) {
+        autograd_entry_(nullptr) {
   }
 
   /*!
@@ -177,7 +177,7 @@ class NDArray {
         shape_(shape),
         dtype_(data.type_flag_),
         storage_type_(stype),
-        entry_(nullptr) {
+        autograd_entry_(nullptr) {
   }
   /*!
    * \brief initialize the NDArray, assuming it is not assigned a meaningful shape before
@@ -190,7 +190,7 @@ class NDArray {
   /*!
    * \brief set the correct shape of NDArray directly from the storage_shape of its own chunk.
    */
-  void SetShapeFromChunk();
+  void SetShapeFromChunk() const;
   /*
    * This indicates whether an array is a view of another array (created by
    * reshape or slice). If an array is a view and the data is stored in
@@ -326,9 +326,9 @@ class NDArray {
   inline bool is_none() const {
     return ptr_.get() == nullptr;
   }
-  /*! \return updated grad state in entry_ */
+  /*! \return updated grad state in autograd_entry_ */
   bool fresh_out_grad() const;
-  /*! \return updated grad state in entry_ */
+  /*! \return updated grad state in autograd_entry_ */
   void set_fresh_out_grad(bool state) const;
   /*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
    * Throws an exception if the indices array shape is inconsistent
@@ -367,27 +367,19 @@ class NDArray {
   /*!
    * \brief Block until all the pending write operations with respect
    *    to current NDArray are finished, and read can be performed.
+   *
+   * If the array has not been computed yet (deferred compute), this will
+   * trigger computation.
    */
-  inline void WaitToRead() const {
-    if (is_none()) return;
-    Engine::Get()->WaitForVar(ptr_->var);
-  }
+  void WaitToRead() const;
   /*!
    * \brief Block until all the pending read/write operations with respect
    *    to current NDArray are finished, and write can be performed.
+   *
+   * If the array has not been computed yet (deferred compute), this will
+   * trigger computation.
    */
-  inline void WaitToWrite() const {
-    if (is_none()) return;
-    /*!
-     * Push an empty mutable function to flush all preceding reads to the
-     * variable.
-     */
-    Engine::Get()->PushAsync(
-      [](RunContext, Engine::CallbackOnComplete on_complete) {
-        on_complete();
-      }, Context{}, {}, {ptr_->var});
-    Engine::Get()->WaitForVar(ptr_->var);
-  }
+  void WaitToWrite() const;
   /*! \return the associated variable of the ndarray.*/
   inline Engine::VarHandle var() const {
     return ptr_->var;
@@ -648,11 +640,13 @@ class NDArray {
    */
   NDArray ReshapeWithRecord(const mxnet::TShape &shape);
   /*!
-   * \brief Return a copy of this NDArray without autograd history
+   * \brief Return a copy of this NDArray without autograd and deferred compute
+   * history
    */
   NDArray Detach() const {
     NDArray ret(*this);
-    ret.entry_ = nnvm::NodeEntry(nullptr);
+    ret.autograd_entry_ = nnvm::NodeEntry(nullptr);
+    ret.deferredcompute_entry_ = nnvm::NodeEntry(nullptr);
     return ret;
   }
 
@@ -1100,8 +1094,11 @@ class NDArray {
 
   /*! \brief internal data of NDArray */
   std::shared_ptr<Chunk> ptr_{nullptr};
-  /*! \brief shape of current NDArray */
-  mxnet::TShape shape_;
+  /*! \brief shape of current NDArray
+   *  \note const methods WaitToRead, WaitToWrite will set shape, if shape is
+   *        previously unknown and array is deferred computed.
+   */
+  mutable mxnet::TShape shape_;
   /*! \brief byte offset in chunk */
   size_t byte_offset_ = 0;
   /*! \brief type of data */
@@ -1111,7 +1108,9 @@ class NDArray {
   /*! \brief storage type of data */
   NDArrayStorageType storage_type_ = kUndefinedStorage;
   /*! \brief node entry for autograd */
-  nnvm::NodeEntry entry_;
+  nnvm::NodeEntry autograd_entry_;
+  /*! \brief node entry for deferred computation tracking */
+  nnvm::NodeEntry deferredcompute_entry_;
   /*!
    * \brief internal TBlob
    * \note When user access tblob_ by some const methods like
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 83cf72d..49f10aa 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -87,6 +87,8 @@ from . import test_utils
 from . import rnn
 from . import gluon
 
+from . import _deferred_compute
+
 # With the native kvstore module (such as 'dist_sync_device'), the module launches a separate
 # process when role is set to "server". This should be done after other modules are initialized.
 # Otherwise this may result in errors when unpickling custom LR scheduler/optimizers.
diff --git a/python/mxnet/_deferred_compute.py b/python/mxnet/_deferred_compute.py
new file mode 100644
index 0000000..4cb1725
--- /dev/null
+++ b/python/mxnet/_deferred_compute.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Deferred Compute for NDArray."""
+
+import ctypes
+import contextlib
+
+from .base import _LIB, check_call, SymbolHandle, _as_list
+from .symbol import Symbol
+
+__all__ = []
+
+def is_deferred_compute():
+    """Return True if deferred compute mode is enabled on the current thread."""
+    curr = ctypes.c_int()  # MXNDArrayIsDeferredCompute writes an int, not a bool
+    check_call(_LIB.MXNDArrayIsDeferredCompute(ctypes.byref(curr)))
+    return bool(curr.value)
+
+def set_deferred_compute(state):
+    """Enable / Disable deferred compute mode.
+
+    Parameters
+    ----------
+    state: bool
+
+    Returns
+    -------
+    Previous deferred compute state.
+    """
+    prev = ctypes.c_int()
+    check_call(_LIB.MXNDArraySetIsDeferredCompute(ctypes.c_int(state), ctypes.byref(prev)))
+    return bool(prev.value)
+
+
+@contextlib.contextmanager
+def context(state=True):
+    """Set deferred compute state to `state` within context. Reset afterwards to previous value."""
+    # Like other MXNet context manager, this bleeds state across concurrent
+    # code: "Context managers that have state should use Context Variables
+    # instead of threading.local() to prevent their state from bleeding to
+    # other code unexpectedly, when used in concurrent code."
+    # https://github.com/apache/incubator-mxnet/issues/17495#issuecomment-585461965
+    val = set_deferred_compute(state)
+    try:
+        yield
+    finally:
+        set_deferred_compute(val)
+
+
+def get_symbol(output_arrays, *, sym_cls=Symbol):
+    """Get symbolic representation of computation recorded in deferred compute mode.
+
+    Parameters
+    ----------
+    output_arrays: NDArray or List[NDArray]
+    sym_cls: class used to construct Symbol
+
+    Returns
+    -------
+    Symbol of sym_cls
+    """
+    output_arrays = _as_list(output_arrays)
+    # Prepare ctypes array types
+    output_handles_type = ctypes.c_void_p * len(output_arrays)
+    # Convert handles
+    output_handles = output_handles_type(*[array.handle for array in output_arrays])
+    handle = SymbolHandle()
+    check_call(_LIB.MXNDArrayGetDeferredComputeSymbol(output_handles, len(output_arrays),
+                                                      ctypes.byref(handle)))
+    return sym_cls(handle)
+
+
+def set_variable(arrays, variables):
+    """Associate variables with arrays.
+
+    Parameters
+    ----------
+    arrays: NDArray or List[NDArray]
+    variables: Symbol or List[Symbol] of variables, one per array
+    """
+    arrays = _as_list(arrays)
+    variables = _as_list(variables)
+    if len(arrays) != len(variables):
+        raise ValueError('set_variable requires one variable per array, got {} arrays and '
+                         '{} variables'.format(len(arrays), len(variables)))
+
+    # Both handle arrays have the same length, so they share one ctypes array type
+    handles_type = ctypes.c_void_p * len(arrays)
+    array_handles = handles_type(*[array.handle for array in arrays])
+    variable_handles = handles_type(*[symbol.handle for symbol in variables])
+    check_call(_LIB.MXNDArraySetDeferredComputeVariable(
+        array_handles, variable_handles, len(arrays)))
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 312358c..10c11b8 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -28,13 +28,13 @@ from collections import OrderedDict, defaultdict
 import numpy as np
 
 from ..base import mx_real_t, MXNetError
-from .. import symbol, ndarray, initializer, np_symbol
+from .. import symbol, ndarray, initializer, np_symbol, autograd, _deferred_compute as dc
 from ..symbol import Symbol
 from ..ndarray import NDArray
 from .. import name as _name
 from .. import profiler as _profiler
 from .parameter import Parameter, ParameterDict, DeferredInitializationError
-from .utils import _indent, _brief_print_list, HookHandle
+from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
 from .utils import _check_same_symbol_type, _check_all_np_ndarrays
 from .. import numpy_extension as _mx_npx
 from .. import numpy as _mx_np
@@ -248,8 +248,8 @@ class Block(object):
     :py:class:`Block` can be nested recursively in a tree structure. You can create and
     assign child :py:class:`Block` as regular attributes::
 
+        import mxnet as mx
         from mxnet.gluon import Block, nn
-        from mxnet import ndarray as F
 
         class Model(Block):
             def __init__(self, **kwargs):
@@ -260,12 +260,12 @@ class Block(object):
                     self.dense1 = nn.Dense(20)
 
             def forward(self, x):
-                x = F.relu(self.dense0(x))
-                return F.relu(self.dense1(x))
+                x = mx.nd.relu(self.dense0(x))
+                return mx.nd.relu(self.dense1(x))
 
         model = Model()
         model.initialize(ctx=mx.cpu(0))
-        model(F.zeros((10, 10), ctx=mx.cpu(0)))
+        model(mx.nd.zeros((10, 10), ctx=mx.cpu(0)))
 
 
     Child :py:class:`Block` assigned this way will be registered and :py:meth:`collect_params`
@@ -856,9 +856,9 @@ class HybridBlock(Block):
                     self.dense0 = nn.Dense(20)
                     self.dense1 = nn.Dense(20)
 
-            def hybrid_forward(self, F, x):
-                x = F.relu(self.dense0(x))
-                return F.relu(self.dense1(x))
+            def forward(self, x):
+                x = mx.nd.relu(self.dense0(x))
+                return mx.nd.relu(self.dense1(x))
 
         model = Model()
         model.initialize(ctx=mx.cpu(0))
@@ -890,6 +890,7 @@ class HybridBlock(Block):
         self._cached_op = None
         self._out_format = None
         self._in_format = None
+        self._called_infer_shape_already = False
         self._active = False
         self._flags = []
         self._callback = None
@@ -903,7 +904,7 @@ class HybridBlock(Block):
         if isinstance(value, HybridBlock):
             self._clear_cached_op()
 
-    def _get_graph(self, *args):
+    def _get_graph_v1(self, *args):
         if not self._cached_graph:
             flatten_args, self._in_format = _flatten(args, "input")
             flatten_inputs = []
@@ -936,6 +937,40 @@ class HybridBlock(Block):
 
         return self._cached_graph
 
+    def _get_graph_v2(self, *args):
+        if not self._cached_graph:
+            flatten_args, self._in_format = _flatten(args, "input")
+            flatten_args = [ele.detach() if ele is not None else None for ele in flatten_args]
+            real_args = [ele for ele in flatten_args if ele is not None]
+            if len(real_args) == 0:
+                raise ValueError('All args are None and we do not support such a case.'
+                                 ' Received args={}'.format(args))
+            if len(real_args) == 1:
+                arg_names = ['data']
+            else:
+                arg_names = ['data{}'.format(i) for i, ele in enumerate(real_args)]
+            symbol_inputs = [
+                symbol.var(name).as_np_ndarray()
+                if isinstance(arg, _mx_np.ndarray) else symbol.var(name)
+                for arg, name in zip(real_args, arg_names)
+            ]
+            dc.set_variable(real_args, symbol_inputs)
+            args = _regroup(flatten_args, self._in_format)
+            with autograd.pause(), dc.context():
+                out = super().__call__(*args)
+            flatten_out, self._out_format = _flatten(out, "output")
+            symbol_outputs = dc.get_symbol(flatten_out)
+            self._cached_graph = symbol_inputs, symbol_outputs
+        return self._cached_graph
+
+    def _get_graph(self, *args):
+        if not self._cached_graph:
+            if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward:  # Gluon 1
+                return self._get_graph_v1(*args)
+            else:  # Gluon 2 based on deferred compute mode
+                return self._get_graph_v2(*args)
+        return self._cached_graph
+
     def _build_cache(self, *args):
         data, out = self._get_graph(*args)
         data_names = {data.name: i for i, data in enumerate(data)}
@@ -1180,7 +1215,20 @@ class HybridBlock(Block):
 
     def infer_shape(self, *args):
         """Infers shape of Parameters from inputs."""
-        self._infer_attrs('infer_shape', 'shape', *args)
+        if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward:
+            # Gluon 1 based on F:  hybrid_forward is defined by user
+            self._infer_attrs('infer_shape', 'shape', *args)
+        else:
+            # In Gluon 2, users must implement infer_shape, if any deferred
+            # initialized parameters are associated with the HybridBlock
+            params = [p for p in self._reg_params.values() if not shape_is_known(p.shape)]
+            if params:
+                params_str = ", ".join("{} ({})".format(p.name, p.shape) for p in params)
+                raise RuntimeError(
+                    "{name} has parameters with unknown shape. You need to either specify the shape "
+                    "in __init__ or implement {name}.infer_shape to set the parameter shapes "
+                    "based on the first input. Parameters with unknown shapes are {params}".format(
+                        name=type(self).__name__, params=params_str))
 
     def infer_type(self, *args):
         """Infers data type of Parameters from inputs."""
@@ -1246,6 +1294,32 @@ class HybridBlock(Block):
             cld._callback = callback
             cld._monitor_all = monitor_all
 
+    def __call__(self, x, *args):
+        """Call the block using hybrid_forward (Gluon 1) or forward (Gluon 2)."""
+        if self.hybrid_forward.__func__ is not HybridBlock.hybrid_forward:
+            # Gluon 1 based on F:  hybrid_forward is defined by user
+            return super().__call__(x, *args)
+        # Gluon 2 (deferred compute). Compare functions: a bound method never `is` a function.
+        assert self.forward.__func__ is not HybridBlock.forward, (
+            'Must either define {name}.forward or {name}.hybrid_forward. '
+            'Defining {name}.hybrid_forward is deprecated.'.format(name=type(self).__name__))
+        if not self._called_infer_shape_already:
+            self.infer_shape(x, *args)
+            for p in self._reg_params.values():
+                p._finish_deferred_init()
+            self._called_infer_shape_already = True
+
+        if not self._active:
+            # Normal imperative computation of forward()
+            return super().__call__(x, *args)
+
+        if dc.is_deferred_compute():
+            # Deferred compute is already enabled. This typically means that the current
+            # HybridBlock is a child block of a HybridBlock that has been hybridized.
+            return super().__call__(x, *args)
+
+        return self._call_cached_op(x, *args)
+
     def forward(self, x, *args):
         """Defines the forward computation. Arguments can be either
         :py:class:`NDArray` or :py:class:`Symbol`."""
@@ -1259,7 +1333,8 @@ class HybridBlock(Block):
                              ' Please check the type of the args.\n')
         if has_ndarray:
             ctx = first_ctx
-            if self._active:
+            if self._active and not dc.is_deferred_compute():
+                # Use the CachedOp only when hybridized and not inside deferred compute mode.
                 if len(ctx_set) > 1:
                     raise ValueError('Find multiple contexts in the input, '
                                      'After hybridized, the HybridBlock only supports one input '
@@ -1450,6 +1525,10 @@ class SymbolBlock(HybridBlock):
         self._reg_params = {key[len_prefix:]: val for key, val in self._params.items()}
 
     def forward(self, x, *args):
+        if dc.is_deferred_compute():
+            raise RuntimeError('Calling a SymbolBlock from within HybridBlock '
+                               'is not yet supported in Gluon 2.')
+
         if isinstance(x, NDArray):
             with x.ctx:
                 return self._call_cached_op(x, *args)
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 55b0f4a..06b6150 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -28,7 +28,7 @@ import warnings
 import numpy as np
 
 from ..base import mx_real_t, MXNetError
-from .. import symbol, ndarray, initializer, context
+from .. import symbol, ndarray, initializer, context, _deferred_compute as dc
 from ..context import Context, cpu
 from .. import autograd
 from .utils import _indent, _brief_print_list, shape_is_known
@@ -335,7 +335,7 @@ class Parameter(object):
             "in_channels, etc for `Block`s."%(
                 self.name, str(self.shape))
 
-        with autograd.pause():
+        with autograd.pause(), dc.context(False):
             if data is None:
                 kwargs = {'shape': self.shape, 'dtype': self.dtype, 'ctx': context.cpu()}
                 if is_np_array():
@@ -568,7 +568,9 @@ class Parameter(object):
             raise RuntimeError("Cannot return a copy of Parameter '%s' on ctx %s via data() " \
                                "because its storage type is %s. Please use row_sparse_data() " \
                                "instead." % (self.name, str(ctx), self._stype))
-        return self._check_and_get(self._data, ctx)
+        data = self._check_and_get(self._data, ctx)
+        dc.set_variable(data, self.var())
+        return data
 
     def list_data(self):
         """Returns copies of this parameter on all contexts, in the same order
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index f9d04df..49a4406 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2614,12 +2614,12 @@ fixed-size items.
         <type 'numpy.int32'>
         """
 
+        if dtype is None:
+            dtype = mx_real_t
         if not copy and np.dtype(dtype) == self.dtype:
             return self
 
-        res = empty(self.shape, ctx=self.ctx, dtype=dtype)
-        self.copyto(res)
-        return res
+        return op.cast(self, dtype=dtype)
 
     def copyto(self, other):
         """Copies the value of this array to another array.
@@ -4635,6 +4635,11 @@ def concatenate(arrays, axis=0, always_copy=True):
     NDArray
         An `NDArray` that lives on the same context as `arrays[0].context`.
     """
+    # Unsupported in deferred compute mode because this function uses in-place operations.
+    from .._deferred_compute import is_deferred_compute  # pylint: disable=wrong-import-position
+    assert not is_deferred_compute(), 'nd.concatenate is deprecated and ' \
+        'unsupported in deferred compute mode. Use nd.concat instead.'
+
     assert isinstance(arrays, list)
     assert len(arrays) > 0
     assert isinstance(arrays[0], NDArray)
diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
index b0238e3..eddf840 100644
--- a/python/mxnet/ndarray/sparse.py
+++ b/python/mxnet/ndarray/sparse.py
@@ -230,6 +230,7 @@ class BaseSparseNDArray(NDArray):
         if not copy and np.dtype(dtype) == self.dtype:
             return self
 
+        # Use copyto for casting, as op.cast(self, dtype=dtype) doesn't support sparse stype
         res = zeros(shape=self.shape, ctx=self.context,
                     dtype=dtype, stype=self.stype)
         self.copyto(res)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 25d4691..ee9df30 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -32,6 +32,7 @@ from array import array as native_array
 import ctypes
 import warnings
 import numpy as _np
+from .. import _deferred_compute as dc
 from ..autograd import is_recording
 from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _GRAD_REQ_MAP
 from ..ndarray import indexing_key_expand_implicit_axes, get_indexing_dispatch_code,\
@@ -614,8 +615,11 @@ class ndarray(NDArray):
                     key = new_key
             except Exception as err:
                 raise TypeError('{}'.format(str(err)))
-        if isinstance(key, _np.ndarray) and key.dtype == _np.bool_:
-            key = array(key, dtype='bool', ctx=self.ctx)
+        if isinstance(key, _np.ndarray):
+            if dc.is_deferred_compute():
+                raise TypeError('Indexing with a numpy array is not supported in HybridBlock.')
+            if key.dtype == _np.bool_:
+                key = array(key, dtype='bool', ctx=self.ctx)
 
         # Handle single boolean index of matching dimensionality and size first for higher speed
         # If the boolean array is mixed with other idices, it is instead expanded into (multiple)
@@ -671,6 +675,8 @@ class ndarray(NDArray):
                 key = (_np.newaxis,) + key
             return self._get_np_basic_indexing(key)
         elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
+            if dc.is_deferred_compute():
+                raise TypeError('Advanced indexing is not supported in HybridBlock.')
             if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE:
                 return empty((0,) + self._get_np_adanced_indexing(key).shape,
                              dtype=self.dtype, ctx=self.ctx)
@@ -1273,12 +1279,12 @@ class ndarray(NDArray):
             raise ValueError('casting must be equal to \'unsafe\'')
         if not subok:
             raise ValueError('subok must be equal to True')
+        if dtype is None:
+            dtype = _np.float32
         if not copy and _np.dtype(dtype) == self.dtype:
             return self
 
-        res = empty(self.shape, dtype=dtype, ctx=self.ctx)
-        self.copyto(res)
-        return res
+        return _npi.cast(self, dtype=dtype)
 
     def copyto(self, other):
         """Copies the value of this array to another array.
diff --git a/src/api/operator/utils.cc b/src/api/operator/utils.cc
index 3d84012..79e94cf 100644
--- a/src/api/operator/utils.cc
+++ b/src/api/operator/utils.cc
@@ -38,9 +38,11 @@ void SetInOut(std::vector<NDArray*>* ndinputs,
   for (int i = 0; i < num_inputs; ++i) {
     NDArray* inp = reinterpret_cast<NDArray*>(inputs[i]);
     if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
-      CHECK_LT(inp->shape().Size(), (int64_t{1} << 31) - 1) <<
-                "[SetNDInputsOutputs] Size of tensor you are trying to allocate is larger than "
-                "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+      if (shape_is_known(inp->shape())) {  // Shape may be unknown after dynamic shape operators
+        CHECK_LT(inp->shape().Size(), (int64_t{1} << 31) - 1)
+          << "[SetInOut] Size of tensor you are trying to allocate is larger than "
+               "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+      }
     }
     ndinputs->emplace_back(inp);
   }
@@ -80,9 +82,16 @@ std::vector<NDArray*> Invoke(const nnvm::Op* op,
   SetInOut(&ndinputs, &ndoutputs, num_inputs, inputs,
       num_outputs, infered_num_outputs, num_visible_outputs, outputs);
 
-  auto state = Imperative::Get()->Invoke(Context::CPU(), *attrs, ndinputs, ndoutputs);
-  if (Imperative::Get()->is_recording()) {
-    Imperative::Get()->RecordOp(std::move(*attrs), ndinputs, ndoutputs, state);
+  if (Imperative::Get()->is_deferred_compute()) {
+    Imperative::Get()->RecordDeferredCompute(std::move(*attrs), ndinputs, ndoutputs);
+  } else {
+    for (NDArray *input : ndinputs) {
+      Imperative::DCInfo::Compute(*input);
+    }
+    auto state = Imperative::Get()->Invoke(Context::CPU(), *attrs, ndinputs, ndoutputs);
+    if (Imperative::Get()->is_recording()) {
+      Imperative::Get()->RecordOp(std::move(*attrs), ndinputs, ndoutputs, state);
+    }
   }
   for (int i = *num_outputs; i < infered_num_outputs; ++i) delete ndoutputs[i];
   return ndoutputs;
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index fe00a9a..949a594 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1531,12 +1531,21 @@ inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim
                      MXAPIThreadLocalEntry<dtype>* ret) {
   NDArray* arr = static_cast<NDArray*>(handle);
   if (!arr->is_none()) {
+    mxnet::TShape s = arr->shape();
+    // Deferred compute array with unknown shape: compute it first to obtain the shape
+    if (!Imperative::DCInfo::IsNone(*arr)) {
+      if (!shape_is_known(s) && !Imperative::DCInfo::IsComputed(*arr)) {
+        Imperative::DCInfo::Compute(*arr);
+        s = arr->shape();
+      }
+    }
+
     if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
-      CHECK_LT(arr->shape().Size(), (int64_t{1} << 31) - 1) <<
-                      "[Get Shape] Size of tensor you are trying to allocate is larger than "
-                      "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+      CHECK_LT(s.Size(), (int64_t{1} << 31) - 1) <<
+        "[Get Shape] Size of tensor you are trying to allocate is larger than "
+        "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
     }
-    mxnet::TShape s = arr->shape();
+
     if (!Imperative::Get()->is_np_shape()) {
       common::ConvertToLegacyShape(&s);
     }
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index ef03fe6..45cb71a 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -58,9 +58,11 @@ void SetNDInputsOutputs(const nnvm::Op* op,
   for (int i = 0; i < num_inputs; ++i) {
     NDArray* inp = reinterpret_cast<NDArray*>(inputs[i]);
     if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
-      CHECK_LT(inp->shape().Size(), (int64_t{1} << 31) - 1) <<
-                "[SetNDInputsOutputs] Size of tensor you are trying to allocate is larger than "
-                "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+      if (shape_is_known(inp->shape())) {  // Shape may be unknown after dynamic shape operators
+        CHECK_LT(inp->shape().Size(), (int64_t{1} << 31) - 1) <<
+          "[SetNDInputsOutputs] Size of tensor you are trying to allocate is larger than "
+          "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+      }
     }
     ndinputs->emplace_back(inp);
   }
@@ -112,9 +114,16 @@ void MXImperativeInvokeImpl(AtomicSymbolCreator creator,
   SetNDInputsOutputs(op, &ndinputs, &ndoutputs, num_inputs, inputs,
       num_outputs, infered_num_outputs, num_visible_outputs, outputs);
 
-  auto state = Imperative::Get()->Invoke(Context::CPU(), attrs, ndinputs, ndoutputs);
-  if (Imperative::Get()->is_recording()) {
-    Imperative::Get()->RecordOp(std::move(attrs), ndinputs, ndoutputs, state);
+  if (Imperative::Get()->is_deferred_compute()) {
+    Imperative::Get()->RecordDeferredCompute(std::move(attrs), ndinputs, ndoutputs);
+  } else {
+    for (NDArray* input : ndinputs) {
+      Imperative::DCInfo::Compute(*input);
+    }
+    auto state = Imperative::Get()->Invoke(Context::CPU(), attrs, ndinputs, ndoutputs);
+    if (Imperative::Get()->is_recording()) {
+      Imperative::Get()->RecordOp(std::move(attrs), ndinputs, ndoutputs, state);
+    }
   }
 
   for (int i = *num_outputs; i < infered_num_outputs; ++i) delete ndoutputs[i];
@@ -433,3 +442,37 @@ int MXCachedOpRegisterOpHook(NDArrayHandle handle,
   op->RegisterOpHook(clbk, monitor_all);
   API_END();
 }
+
+int MXNDArrayIsDeferredCompute(int *curr) {
+  API_BEGIN();
+  *curr = Imperative::Get()->is_deferred_compute();
+  API_END();
+}
+
+int MXNDArraySetIsDeferredCompute(int deferred_compute, int *prev) {
+  API_BEGIN();
+  *prev = Imperative::Get()->set_is_deferred_compute(static_cast<bool>(deferred_compute));
+  API_END();
+}
+
+int MXNDArraySetDeferredComputeVariable(NDArrayHandle *arrays, SymbolHandle *variables, int num) {
+  API_BEGIN();
+  Imperative::Get()->SetDeferredComputeVariable(arrays, variables, num);
+  API_END();
+}
+
+int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle *output_handles, int num_outputs,
+                                      SymbolHandle *out) {
+  nnvm::Symbol *s = new nnvm::Symbol();
+  API_BEGIN();
+  std::vector<NDArray *> outputs;
+  outputs.reserve(num_outputs);
+  for (int i = 0; i < num_outputs; ++i) {
+    NDArray *array = reinterpret_cast<NDArray *>(output_handles[i]);
+    outputs.emplace_back(array);
+  }
+  // Build a Symbol from the deferred compute history of the output arrays
+  *s = Imperative::Get()->GetDeferredComputeSymbol(outputs);
+  *out = s;
+  API_END_HANDLE_ERROR(delete s;);
+}
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index d3db4ba..731ba2e 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -300,7 +300,7 @@ void SetInputIndices(const nnvm::Graph& fwd_graph,
   const auto& indexed_graph = fwd_graph.indexed_graph();
   if (data_indices->ndim() || param_indices.ndim()) {
     CHECK_EQ(data_indices->ndim() + param_indices.ndim(),
-             indexed_graph.input_nodes().size());
+             static_cast<int>(indexed_graph.input_nodes().size()));
   } else {
     std::vector<uint32_t> tmp;
     tmp.reserve(indexed_graph.input_nodes().size());
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 97a09fd..14fedc9 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -16,19 +16,28 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#include <unordered_set>
+#include <algorithm>
 #include <iostream>
+#include <unordered_map>
+#include <unordered_set>
+
 #include "./imperative_utils.h"
 #include "./cached_op.h"
 
+namespace nnvm {
+ObjectPtr CreateVariableNode(const std::string &name);
+}
+
 namespace mxnet {
 #if DMLC_CXX11_THREAD_LOCAL
 thread_local bool Imperative::is_train_ = false;
 thread_local bool Imperative::is_recording_ = false;
+thread_local bool Imperative::is_deferred_compute_ = false;
 thread_local bool Imperative::is_np_shape_thread_local_ = false;
 #else
 MX_THREAD_LOCAL bool Imperative::is_train_ = false;
 MX_THREAD_LOCAL bool Imperative::is_recording_ = false;
+MX_THREAD_LOCAL bool Imperative::is_deferred_compute_ = false;
 MX_THREAD_LOCAL bool Imperative::is_np_shape_thread_local_ = false;
 #endif
 
@@ -120,6 +129,8 @@ OpStatePtr Imperative::Invoke(
   return ret;
 }
 
+// Create nnvm::NodeEntry for variables' and gradients' autograd_entry_
+// attribute and associate AGInfo with its info attribute
 void Imperative::MarkVariables(
     const std::vector<NDArray*>& variables,
     const std::vector<uint32_t>& grad_reqs,
@@ -127,17 +138,17 @@ void Imperative::MarkVariables(
   for (uint32_t i = 0; i < variables.size(); ++i) {
     std::string str_c(std::to_string(variable_count_++));
 
-    variables[i]->entry_ = nnvm::NodeEntry{
+    variables[i]->autograd_entry_ = nnvm::NodeEntry{
         nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0};
-    AGInfo& info = AGInfo::Create(variables[i]->entry_.node);
+    AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node);
     info.outputs.emplace_back(variables[i]->Detach());
     info.out_grads.emplace_back(gradients[i]->Detach());
     info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
     info.ctx = variables[i]->ctx();
 
-    gradients[i]->entry_ = nnvm::NodeEntry{
+    gradients[i]->autograd_entry_ = nnvm::NodeEntry{
         nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
-    AGInfo& grad_info = AGInfo::Create(gradients[i]->entry_.node);
+    AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node);
     grad_info.outputs.emplace_back(gradients[i]->Detach());
     grad_info.ctx = gradients[i]->ctx();
   }
@@ -199,6 +210,9 @@ void Imperative::RecordOp(
     std::vector<bool>* p_save_outputs) {
   MXAPIThreadLocalEntry<> *local_buff = MXAPIThreadLocalStore<>::Get();
 
+  CHECK(!is_deferred_compute())
+      << "Autograd recording is not supported during deferred compute mode.";
+
   for (auto output : outputs) {
     CHECK(AGInfo::IsNone(*output))
       << "Assigning to NDArrays that are already in a computational graph "
@@ -250,17 +264,18 @@ void Imperative::RecordOp(
         input_info.outputs.back().dtype_ = inputs[i]->dtype();
         input_info.outputs.back().storage_type_ = inputs[i]->storage_type();
       }
-      inputs[i]->entry_ = std::move(entry);  // assign last to prevent cyclic reference
+      inputs[i]->autograd_entry_ = std::move(entry);  // assign last to prevent cyclic reference
     } else if (save_inputs[i]) {
-      AGInfo::Get(inputs[i]->entry_.node).outputs[inputs[i]->entry_.index] = inputs[i]->Detach();
+      nnvm::NodeEntry& entry = inputs[i]->autograd_entry_;
+      AGInfo::Get(entry.node).outputs[entry.index] = inputs[i]->Detach();
     }
-    node->inputs[i] = inputs[i]->entry_;
+    node->inputs[i] = inputs[i]->autograd_entry_;
   }
 
   for (auto output : outputs) {
     CHECK(AGInfo::IsNone(*output))
-      << "Inplace operations (+=, -=, x[:]=, etc) are not supported when "
-      << "recording with autograd.";
+        << "NotImplementedError: Inplace operations (+=, -=, x[:]=, etc) "
+        << "are not supported when recording with autograd.";
   }
 
   for (uint32_t i = 0; i < outputs.size(); ++i) {
@@ -273,7 +288,88 @@ void Imperative::RecordOp(
       info.outputs.back().dtype_ = outputs[i]->dtype();
       info.outputs.back().storage_type_ = outputs[i]->storage_type();
     }
-    outputs[i]->entry_ = nnvm::NodeEntry{node, i, 0};
+    outputs[i]->autograd_entry_ = nnvm::NodeEntry{node, i, 0};
+  }
+}
+
+void Imperative::RecordDeferredCompute(nnvm::NodeAttrs &&attrs,
+                                       const std::vector<NDArray *> &inputs,
+                                       const std::vector<NDArray *> &outputs) {
+  CHECK(!is_recording())
+      << "MXNetError: Autograd recording is not supported during deferred compute mode.";
+
+  for (const NDArray *input : inputs) {
+    CHECK(!DCInfo::IsNone(*input))
+        << "ValueError: All inputs to deferred compute recording must be associated "
+        << "with a symbolic variable or be the output of a deferred compute operator.";
+  }
+  for (const NDArray *output : outputs) {
+    CHECK(DCInfo::IsNone(*output))
+        << "NotImplementedError: Inplace operations (+=, -=, x[:]=, etc) "
+        << "are not supported when recording in deferred compute mode.";
+  }
+  DispatchMode dispatch_mode = DispatchMode::kUndefined;
+  Context ctx = imperative::GetContext(attrs, inputs, outputs, Context::CPU());
+  imperative::SetShapeType(ctx, attrs, inputs, outputs, &dispatch_mode);
+
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+  node->inputs.reserve(inputs.size());
+  // Get NodeEntries for inputs
+  for (const NDArray *array : inputs) {
+    CHECK(array->deferredcompute_entry_.node);  // Must not be nullptr
+    node->inputs.emplace_back(array->deferredcompute_entry_);
+  }
+  node->attrs = std::move(attrs);
+  // Need to support NameManager in imperative API to better name node->attrs.name
+  node->attrs.name = "node_" + std::to_string(node_count_++);
+
+  for (uint32_t i = 0; i < outputs.size(); ++i) {
+    outputs[i]->deferredcompute_entry_ = nnvm::NodeEntry{node, i, 0};
+  }
+
+  DCInfo::Create(node, inputs, outputs);
+}
+
+nnvm::Symbol Imperative::GetDeferredComputeSymbol(const std::vector<NDArray *> &outputs) {
+  Symbol s;
+  s.outputs.reserve(outputs.size());
+  for (NDArray * ndoutput : outputs) {
+    CHECK(!Imperative::DCInfo::IsNone(*ndoutput))
+        << "ValueError: output_arrays for GetDeferredComputeSymbol "
+        << "must have a deferred compute history associated with them.";
+    s.outputs.emplace_back(ndoutput->deferredcompute_entry_);
+  }
+  return s.Copy();
+}
+
+void Imperative::SetDeferredComputeVariable(NDArrayHandle *arrays,
+                                            SymbolHandle *variables, const int num) {
+  // Sanity check all inputs
+  for (int i = 0; i < num; i++) {
+    nnvm::Symbol *s = reinterpret_cast<nnvm::Symbol *>(variables[i]);
+    NDArray *nd = reinterpret_cast<NDArray *>(arrays[i]);
+    CHECK_EQ(s->outputs.size(), 1U)
+        << "MXNDArraySetDeferredComputeVariable expects variables as input. "
+        << "Instead got a Symbol with " << s->outputs.size()
+        << " outputs as input " << i;
+    CHECK(s->outputs[0].node->is_variable())
+        << "MXNDArraySetDeferredComputeVariable expects variables as input. "
+        << "Instead got a Symbol associated with an operator as input " << i;
+    CHECK(DCInfo::IsNone(*nd) || nd->deferredcompute_entry_.node == s->outputs[0].node)
+        << "ValueError: array " << i << " is already associated with a different variable. "
+        << "You can call array.detach() to obtain a copy without the variable";
+  }
+
+  // Store variables in DCInfo of arrays
+  for (int i = 0; i < num; i++) {
+    nnvm::Symbol *s = reinterpret_cast<nnvm::Symbol *>(variables[i]);
+    NDArray *nd = reinterpret_cast<NDArray *>(arrays[i]);
+    nd->deferredcompute_entry_ = nnvm::NodeEntry{s->outputs[0].node, 0, 0};
+
+    std::vector<NDArray *> inputs;
+    std::vector<NDArray *> outputs;  // No need to specify outputs, as we will set is_computed_
+    Imperative::DCInfo& info = Imperative::DCInfo::Create(s->outputs[0].node, inputs, outputs);
+    info.is_computed_ = true;
   }
 }
 
@@ -297,7 +393,7 @@ std::vector<NDArray*> Imperative::Backward(
       << "You need to set is_recording to true or use autograd.record() to save "
       << "computational graphs for backward. If you want to differentiate the same "
       << "graph twice, you need to pass retain_graph=True to backward.";
-    graph.outputs.emplace_back(i->entry_);
+    graph.outputs.emplace_back(i->autograd_entry_);
   }
   size_t num_forward_outputs = graph.outputs.size();
 
@@ -333,10 +429,10 @@ std::vector<NDArray*> Imperative::Backward(
     x_reqs.reserve(variables.size());
     for (size_t i = 0; i < variables.size(); ++i) {
       CHECK(!AGInfo::IsNone(*variables[i]) &&
-            AGInfo::IsVariable(variables[i]->entry_.node))
+            AGInfo::IsVariable(variables[i]->autograd_entry_.node))
           << "Cannot differentiate with respect to the " << i+1 << "-th variable"
           << " because it does not require gradient.";
-      xs.emplace_back(variables[i]->entry_);
+      xs.emplace_back(variables[i]->autograd_entry_);
       x_grads.push_back(new NDArray());
       x_reqs.push_back(kWriteTo);
     }
@@ -402,7 +498,7 @@ std::vector<NDArray*> Imperative::Backward(
         size_t nid = idx.node_id(n.get());
         size_t eid = idx.entry_id(nid, i);
         buff[eid] = info.outputs[i];
-        buff[eid].entry_ = NodeEntry{n, i, 0};
+        buff[eid].autograd_entry_ = NodeEntry{n, i, 0};
         ref_count[eid] = 1;
       }
     });
@@ -411,7 +507,7 @@ std::vector<NDArray*> Imperative::Backward(
       if (!idx.exist(ograd_entry.node.get())) continue;
       size_t eid = idx.entry_id(ograd_entry);
       buff[eid] = info.outputs[0];
-      buff[eid].entry_ = ograd_entry;
+      buff[eid].autograd_entry_ = ograd_entry;
     }
   } else {
     states.reserve(num_forward_nodes);
@@ -544,4 +640,75 @@ std::vector<NDArray*> Imperative::Backward(
   return {};
 }
 
+Imperative::DCInfo::DCInfo(const std::vector<NDArray *> &inputs,
+                           const std::vector<NDArray *> &outputs) {
+  this->inputs_.reserve(inputs.size());
+  this->input_handles_.reserve(inputs.size());
+  for (const NDArray *arr : inputs) {
+    CHECK(!arr->is_none());
+    this->inputs_.push_back(*arr);
+    this->input_handles_.push_back(arr);
+  }
+
+  this->outputs_.reserve(outputs.size());
+  for (const NDArray *arr : outputs) {
+    CHECK(!arr->is_none());
+    this->outputs_.push_back(*arr);
+  }
+}
+
+Imperative::DCInfo &
+Imperative::DCInfo::Create(const nnvm::ObjectPtr &node,
+                           const std::vector<NDArray *> &inputs,
+                           const std::vector<NDArray *> &outputs) {
+  node->info.construct<DCInfo>(inputs, outputs);
+  return Imperative::DCInfo::Get(node);
+}
+
+void Imperative::DCInfo::Compute(const NDArray &arr) {
+  if (Imperative::DCInfo::IsComputed(arr)) {
+    if (!shape_is_known(arr.shape())) {
+      // We can't call arr.WaitToRead(); here, as WaitToRead calls Compute
+      // leading to an infinite loop.
+      Engine::Get()->WaitForVar(arr.ptr_->var);
+      if (shape_is_known(arr.ptr_->storage_shape)) {
+        arr.SetShapeFromChunk();
+      } else {
+        CHECK(shape_is_known(arr.shape()));
+      }
+    }
+    return;
+  }
+
+  DCInfo &info = Imperative::DCInfo::Get(arr.deferredcompute_entry_.node);
+  info.is_computed_ = true;  // We will Invoke at the end of this function.
+
+  // Recursively compute input arrays
+  for (const NDArray &input : info.inputs_) {
+    Compute(input);
+  }
+
+  // Prepare pointers
+  std::vector<NDArray *> ndinputs, ndoutputs;
+  ndinputs.reserve(info.inputs_.size());
+  ndoutputs.reserve(info.outputs_.size());
+  for (NDArray &input : info.inputs_)
+    ndinputs.push_back(&input);
+  for (NDArray &output : info.outputs_)
+    ndoutputs.push_back(&output);
+
+  // Compute this array
+  Imperative::Get()->Invoke(Context::CPU(),
+                            arr.deferredcompute_entry_.node->attrs, ndinputs,
+                            ndoutputs);
+  if (!shape_is_known(arr.shape())) {
+    arr.WaitToRead();
+    arr.SetShapeFromChunk();
+  }
+
+  // Deallocate copies
+  info.inputs_.clear();
+  info.outputs_.clear();
+}
+
 }  // namespace mxnet
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 21d5298..12546ae 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -96,7 +96,12 @@ inline Context GetContext(const nnvm::NodeAttrs& attrs,
   return ctx;
 }
 
-// Set the shape, dtype, storage type and dispatch mode via the attribute inference functions
+/*! \brief Set the shape, dtype, storage type and dispatch mode via the
+ * attribute inference functions
+ *
+ * Inferred information is stored in MXAPIThreadLocalEntry. Existing information
+ * is overwritten.
+ */
 inline void SetShapeType(const Context& ctx,
                          const nnvm::NodeAttrs& attrs,
                          const std::vector<NDArray*>& inputs,
@@ -123,6 +128,17 @@ inline void SetShapeType(const Context& ctx,
   if (!infershape.count(attrs.op)) {
     is_dynamic_shape_existing = true;
   } else {
+    // If any of the inputs is a deferred computed array with unknown shape, we
+    // can't infer shapes.
+    for (const NDArray *i : inputs) {
+      if (!shape_is_known(i->shape()) && !Imperative::DCInfo::IsNone(*i)) {
+        is_dynamic_shape_existing = true;
+        break;
+      }
+    }
+  }
+
+  if (!is_dynamic_shape_existing) {
     if (!Imperative::Get()->is_np_shape()) {
       common::ConvertToNumpyShape(&in_shapes);
       common::ConvertToNumpyShape(&out_shapes);
@@ -202,7 +218,8 @@ inline void SetShapeType(const Context& ctx,
 
   for (size_t i = 0; i < outputs.size(); ++i) {
     NDArrayStorageType storage_type = static_cast<NDArrayStorageType>(out_storage_types[i]);
-    if (outputs[i]->is_none() || mxnet::op::shape_is_none(outputs[i]->shape())) {
+    if (outputs[i]->is_none() || (mxnet::op::shape_is_none(outputs[i]->shape()) &&
+                                   Imperative::DCInfo::IsNone(*outputs[i]))) {
       if (is_dynamic_shape_existing) {
         // once there is dynamic shape somewhere, we could not pre-determine the shape.
         *outputs[i] = NDArray(ctx, out_types[i]);
@@ -214,19 +231,35 @@ inline void SetShapeType(const Context& ctx,
         *outputs[i] = NDArray(storage_type, out_shapes[i], ctx, true, out_types[i]);
         outputs[i]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name);
       }
+    } else if (mxnet::op::shape_is_none(outputs[i]->shape()) &&
+               !Imperative::DCInfo::IsNone(*outputs[i])) {
+      // For deferred compute arrays with unknown shape (e.g. following a dynamic
+      // shape operator), don't use copy assignment as it would destroy the
+      // deferredcompute metadata.
+      if (!is_dynamic_shape_existing) {
+        outputs[i]->Init(out_shapes[i]);
+      }
+      CHECK_EQ(outputs[i]->dtype(), out_types[i])
+        << i << "-th output has invalid dtype. "
+        << "Expecting " << out_types[i] << " got " << outputs[i]->dtype()
+        << " in operator " << attrs.op->name;
     } else {
       CHECK_EQ(outputs[i]->shape(), out_shapes[i])
         << i << "-th output has invalid shape. "
         << "Expecting " << out_shapes[i] << " got "
         << outputs[i]->shape() << " in operator " << attrs.op->name;
       CHECK_EQ(outputs[i]->dtype(), out_types[i])
-        << i << "-th output has invalid shape. "
+        << i << "-th output has invalid dtype. "
         << "Expecting " << out_types[i] << " got "
         << outputs[i]->dtype()  << " in operator " << attrs.op->name;
     }
   }
 }
 
+/*! \brief Set read and write vars, resource requests and mutate_idx
+ *
+ * For the inputs and outputs arguments, only NDArray::var() is accessed.
+ */
 inline void SetDependency(const nnvm::NodeAttrs& attrs,
                    const Context& ctx,
                    const std::vector<NDArray*>& inputs,
@@ -300,6 +333,11 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs,
   Engine::Get()->DeduplicateVarHandle(&read_vars, &write_vars);
 }
 
+/*! \brief Reset vector of OpReqType *req based on input and output NDArrays.
+ *
+ * Set to kWriteInplace if corresponding output shares variable with any input
+ * NDArray. Set to kWriteTo otherwise.
+ */
 inline void SetWriteInplaceReq(const std::vector<NDArray*>& inputs,
                         const std::vector<NDArray*>& outputs,
                         std::vector<OpReqType> *req) {
@@ -385,6 +423,9 @@ inline void SetNumOutputs(const nnvm::Op *op,
   }
 }
 
+/*!
+ * \brief Copy-construct NDArrays referenced by inputs and outputs to p_inputs and p_outputs
+ */
 inline void DerefInputOutput(const std::vector<NDArray*>& inputs,
                              const std::vector<NDArray*>& outputs,
                              std::vector<NDArray>* p_inputs,
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index f851383..8f5612c 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -51,7 +51,7 @@ namespace mxnet {
 NDArray::NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, Context ctx,
     bool delay_alloc, int dtype, std::vector<int> aux_types,
     mxnet::ShapeVector aux_shapes, mxnet::TShape storage_shape) : shape_(shape),
-  dtype_(dtype), storage_type_(stype), entry_(nullptr) {
+  dtype_(dtype), storage_type_(stype), autograd_entry_(nullptr) {
   // Assign default aux types if not given
   if (aux_types.size() == 0
       && stype != kDefaultStorage) {
@@ -113,7 +113,7 @@ void NDArray::AssignStorageInfo(const std::string& profiler_scope,
   }
 }
 
-void NDArray::SetShapeFromChunk() {
+void NDArray::SetShapeFromChunk() const {
   if (Imperative::Get()->is_np_shape() ||
       !(ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0)) {
     shape_ = ptr_->storage_shape;
@@ -182,7 +182,7 @@ void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape &shape, int dtype) {
 
 NDArray NDArray::grad() const {
   if (Imperative::AGInfo::IsNone(*this)) return NDArray();
-  Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
+  Imperative::AGInfo& info = Imperative::AGInfo::Get(autograd_entry_.node);
   if (info.out_grads.size()) {
     CHECK_EQ(info.out_grads.size(), 1);
     return info.out_grads[0];
@@ -194,14 +194,14 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
   CHECK(!Imperative::AGInfo::IsNone(*this))
     << "NDArray is not part of a computation graph. Did you forget to turn on recording?";
   nnvm::Symbol ret;
-  ret.outputs.emplace_back(entry_);
+  ret.outputs.emplace_back(autograd_entry_);
   return ret;
 }
 
 #if MXNET_USE_MKLDNN == 1
 
 NDArray::NDArray(const mkldnn::memory::desc &md)
-    : storage_type_(kDefaultStorage), entry_(nullptr) {
+    : storage_type_(kDefaultStorage), autograd_entry_(nullptr) {
   shape_ = mxnet::TShape(md.data.dims, md.data.dims + md.data.ndims);
   dtype_ = get_mxnet_type(md.data.data_type);
   ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
@@ -210,7 +210,7 @@ NDArray::NDArray(const mkldnn::memory::desc &md)
 }
 
 NDArray::NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem)
-    : storage_type_(kDefaultStorage), entry_(nullptr) {
+    : storage_type_(kDefaultStorage), autograd_entry_(nullptr) {
   auto mem_desc = mkldnn_mem->get_desc();
   shape_ = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
   dtype_ = get_mxnet_type(mem_desc.data.data_type);
@@ -284,12 +284,34 @@ NDArray NDArray::Reshape(const mxnet::TShape &shape) const {
 }
 
 NDArray NDArray::ReshapeWithRecord(const mxnet::TShape &shape) {
-  NDArray ret = this->Reshape(shape);
-  if (!Imperative::Get()->is_recording()) return ret;
+  bool is_recording = Imperative::Get()->is_recording();
+  bool is_deferred_compute = Imperative::Get()->is_deferred_compute();
+  NDArray ret;
+  if (!is_deferred_compute) {
+    // The new array shares memory with this array, thus make sure this array
+    // has already been computed. (noop if this array is not deferred)
+    Imperative::DCInfo::Compute(*this);
+    ret = this->Reshape(shape);
+    if (!is_recording) {
+      return ret;
+    }
+  } else {
+    if (shape_is_known(this->shape())) {
+      // Imperative reshape only works if shape is already known.
+      ret = this->Reshape(shape);
+    } else {
+      // Reshape called after a dynamic shape operator (shape not yet known).
+      ret = this->Detach();
+    }
+  }
+
+  if (!is_deferred_compute || shape_is_known(this->shape())) {
+    CHECK_EQ(shape_.Size(), shape.Size())
+        << "NDArray.Reshape: target shape must have the same size as "
+        << "current shape when recording with autograd "
+        << "or in deferred compute mode.";
+  }
 
-  CHECK_EQ(shape_.Size(), shape.Size())
-    << "NDArray.Reshape: target shape must have the same size as "
-    << "current shape when recording with autograd.";
   nnvm::NodeAttrs attrs;
   attrs.op = nnvm::Op::Get("Reshape");;
   std::ostringstream os;
@@ -297,7 +319,12 @@ NDArray NDArray::ReshapeWithRecord(const mxnet::TShape &shape) {
   attrs.dict.insert({"shape", os.str()});
   attrs.op->attr_parser(&attrs);
   std::vector<NDArray*> inputs(1, this), outputs(1, &ret);
-  Imperative::Get()->RecordOp(std::move(attrs), inputs, outputs);
+
+  if (is_recording) {
+    Imperative::Get()->RecordOp(std::move(attrs), inputs, outputs);
+  } else if (is_deferred_compute) {
+    Imperative::Get()->RecordDeferredCompute(std::move(attrs), inputs, outputs);
+  }
   return ret;
 }
 
@@ -318,8 +345,27 @@ NDArray NDArray::Slice(index_t begin, index_t end) const {
 }
 
 NDArray NDArray::SliceWithRecord(index_t begin, index_t end) {
-  NDArray ret = this->Slice(begin, end);
-  if (!Imperative::Get()->is_recording()) return ret;
+  bool is_recording = Imperative::Get()->is_recording();
+  bool is_deferred_compute = Imperative::Get()->is_deferred_compute();
+  NDArray ret;
+  if (!is_deferred_compute) {
+    // The new array shares memory with this array, thus make sure this array
+    // has already been computed. (noop if this array is not deferred)
+    Imperative::DCInfo::Compute(*this);
+    ret = this->Slice(begin, end);
+    if (!is_recording) {
+      return ret;
+    }
+  } else {
+    if (shape_is_known(this->shape())) {
+      // Imperative slice only works if shape is already known.
+      ret = this->Slice(begin, end);
+    } else {
+      // Slice called after a dynamic shape operator (shape not yet known).
+      ret = this->Detach();
+    }
+  }
+
   // fake a slice op
   nnvm::NodeAttrs attrs;
   attrs.op = nnvm::Op::Get("slice");
@@ -327,7 +373,13 @@ NDArray NDArray::SliceWithRecord(index_t begin, index_t end) {
   attrs.dict.insert({"end", std::to_string(end)});
   attrs.op->attr_parser(&attrs);
   std::vector<NDArray*> inputs(1, this), outputs(1, &ret);
-  Imperative::Get()->RecordOp(std::move(attrs), inputs, outputs);
+
+  if (is_recording) {
+    Imperative::Get()->RecordOp(std::move(attrs), inputs, outputs);
+  } else if (is_deferred_compute) {
+    Imperative::Get()->RecordDeferredCompute(std::move(attrs), inputs, outputs);
+  }
+
   return ret;
 }
 
@@ -406,7 +458,7 @@ NDArray NDArray::FromDLPack(const DLManagedTensor* tensor, bool transient_handle
 
 bool NDArray::fresh_out_grad() const {
+  // Arrays without autograd info carry no gradient state; report false.
   if (Imperative::AGInfo::IsNone(*this)) return false;
-  Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
+  Imperative::AGInfo& info = Imperative::AGInfo::Get(autograd_entry_.node);
   return info.fresh_out_grad;
 }
 
@@ -414,7 +466,7 @@ bool NDArray::fresh_out_grad() const {
 void NDArray::set_fresh_out_grad(bool state) const {
+  // Only arrays that carry autograd info (e.g. marked variables) have this flag.
   CHECK(!Imperative::AGInfo::IsNone(*this))
     << "NDArray has not been marked as a variable and does not have gradient state";
-  Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
+  Imperative::AGInfo& info = Imperative::AGInfo::Get(autograd_entry_.node);
   info.fresh_out_grad = state;
 }
 
@@ -2057,8 +2109,9 @@ void NDArray::SyncCopyToCPU(void *data, size_t size) const {
   }
   TBlob dst(data, dshape, cpu::kDevMask, this->dtype_, 0); // NOLINT(*)
 
+  this->WaitToRead();
+
   if (this->ctx().dev_mask() == cpu::kDevMask) {
-    this->WaitToRead();
     RunContext rctx{this->ctx(), nullptr, nullptr, false};
     NDArray src = *this;
 #if MXNET_USE_MKLDNN == 1
@@ -2119,6 +2172,22 @@ void NDArray::SyncCheckFormat(const bool full_check) const {
   CHECK_EQ(err, kNormalErr) << "Check the validity of this sparse NDArray";
 }
 
+// Block until this array's data can be read. Returns immediately for a none
+// array; otherwise the array is first passed to DCInfo::Compute and then we
+// wait on its engine variable.
+void NDArray::WaitToRead() const {
+  if (is_none()) return;
+  // Trigger computation of a deferred-compute array before waiting on it.
+  Imperative::DCInfo::Compute(*this);
+  Engine::Get()->WaitForVar(ptr_->var);
+}
+
+// Block until this array can be safely written: returns immediately for a
+// none array, otherwise triggers DCInfo::Compute, pushes an empty operation
+// that mutates the array's engine variable and waits on that variable.
+void NDArray::WaitToWrite() const {
+  if (is_none()) return;
+  Imperative::DCInfo::Compute(*this);
+  // Push an empty mutable function to flush all preceding reads to the variable.
+  Engine::Get()->PushAsync(
+      [](RunContext, Engine::CallbackOnComplete on_complete) { on_complete(); },
+      Context{}, {}, {ptr_->var});
+  Engine::Get()->WaitForVar(ptr_->var);
+}
+
 #if MXNET_PREDICT_ONLY == 0
 // register API function
 // those with underscore will be registered at NDArray
@@ -2148,6 +2217,17 @@ void CopyFromToSimple(
   CopyFromTo(inputs[0], outputs[0], 0, true);
 }
 
+// FInferType for _copyto. The input dtype is copied to the output only when
+// the output dtype is still unknown (-1); an already-set output dtype is left
+// untouched, and the input dtype is never inferred from the output.
+// Returns whether the output dtype is known after inference.
+bool CopyToType(const nnvm::NodeAttrs &attrs, std::vector<int> *in_attrs,
+                   std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  int in_type = in_attrs->at(0);
+  if (out_attrs->at(0) == -1) {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_type);
+  }
+  return out_attrs->at(0) != -1;
+}
+
 // copy function is special
 // that we need to remove kAcceptEmptyMutateTarget from it
 NNVM_REGISTER_OP(_copyto)
@@ -2155,10 +2235,7 @@ NNVM_REGISTER_OP(_copyto)
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr<mxnet::FInferShape>("FInferShape", op::ElemwiseShape<1, 1>)
-.set_attr<nnvm::FInferType>("FInferType",
-  [](const NodeAttrs& attrs, std::vector<int> *in_type, std::vector<int> *out_type) {
-    return !op::type_is_none((*in_type)[0]) && !op::type_is_none((*out_type)[0]);
-  })
+.set_attr<nnvm::FInferType>("FInferType", CopyToType)
 .set_attr<FInferStorageType>("FInferStorageType",
   [](const NodeAttrs& attrs,
      const int dev_mask,
diff --git a/tests/python/gpu/test_deferred_compute_gpu.py b/tests/python/gpu/test_deferred_compute_gpu.py
new file mode 100644
index 0000000..7503c7b
--- /dev/null
+++ b/tests/python/gpu/test_deferred_compute_gpu.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import sys
+
+import mxnet as mx
+# Make the first GPU the default context for every test imported below.
+mx.test_utils.set_default_context(mx.gpu(0))
+
+curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+# Make ../unittest importable so the CPU test module can be reused here.
+sys.path.insert(0, os.path.join(curr_path, '../unittest'))
+# We import all tests from ../unittest/test_deferred_compute.py
+# They will be detected by nose, as long as the current file has a different filename
+from test_deferred_compute import *
+
+
+if __name__ == "__main__":
+    import nose
+    nose.runmodule()
diff --git a/tests/python/unittest/test_deferred_compute.py b/tests/python/unittest/test_deferred_compute.py
new file mode 100644
index 0000000..cebb690
--- /dev/null
+++ b/tests/python/unittest/test_deferred_compute.py
@@ -0,0 +1,536 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import functools
+import operator
+
+import numpy as np
+from nose.tools import raises
+
+import mxnet as mx
+import mxnet._deferred_compute as dc
+from mxnet.base import MXNetError
+
+
+def _all_same(arrays1, arrays2, message=''):
+    """Raise AssertionError unless both sequences hold element-wise equal arrays.
+
+    A length mismatch (e.g. a computation returning fewer outputs than the
+    reference) is reported as a failure instead of being silently truncated
+    by ``zip``.
+    """
+    same = len(arrays1) == len(arrays2) and all(
+        np.array_equal(a1, a2) for a1, a2 in zip(arrays1, arrays2))
+    if not same:
+        raise AssertionError('Arrays not equal ({}):\n{}\n\n{}'.format(message, arrays1, arrays2))
+
+
+def _assert_dc(setup, compute, mode='all', setup_is_deterministic=True, numpy=True):
+    """Compare results of deferred compute and normal imperative mode.
+
+    Parameters
+    ----------
+    setup : callable
+        Setup function computing inputs for compute function. Always called
+        outside of deferred compute.
+    compute : callable
+        Compute function. We compare the output between normal computation and
+        deferred compute.
+    mode : {'all', 'symbolic', 'imperative', 'imperativewithnondccompute'}
+        Which deferred compute results are compared with the imperative
+        reference. 'imperative' converts the deferred outputs directly
+        (asnumpy()); 'imperativewithnondccompute' first applies a regular
+        operation (``+ 0``) to each output outside of deferred compute;
+        'symbolic' binds and runs the exported symbol; 'all' checks all three.
+    setup_is_deterministic : bool
+        If True, setup function may be called multiple times. If False, will
+        only be called once.
+    numpy : bool
+        If True, use mx.np. Otherwise mx.nd.
+
+    """
+    # Validate before doing any work and before the global numpy-semantics
+    # flag is changed.
+    assert mode in ('all', 'symbolic', 'imperative', 'imperativewithnondccompute')
+    try:
+        nd = mx.np if numpy else mx.nd
+        if numpy:
+            mx.npx.set_np()
+
+        xs = setup(nd=nd)
+        ys = compute(*xs, nd=nd)
+
+        ys_np = [y.asnumpy() for y in ys]
+
+        if setup_is_deterministic:
+            xs = setup(nd=nd)
+
+        xs_names = list(map(str, range(len(xs))))
+        symbol_inputs = [
+            mx.symbol.var(name).as_np_ndarray()
+            if numpy else mx.symbol.var(name)
+            for name in xs_names
+        ]
+        dc.set_variable(xs, symbol_inputs)
+        with dc.context():
+            ys_dc = compute(*xs, nd=nd)
+
+        if mode in ('all', 'imperativewithnondccompute'):
+            # ``+ 0`` is executed outside of dc.context(), i.e. as a normal
+            # imperative operation on the deferred outputs.
+            ys_dc_np = [(y + 0).asnumpy() for y in ys_dc]
+            _all_same(ys_np, ys_dc_np)
+
+        if mode in ('all', 'imperative'):
+            ys_dc_np = [y.asnumpy() for y in ys_dc]
+            _all_same(ys_np, ys_dc_np)
+
+        if mode in ('all', 'symbolic'):
+            sym = dc.get_symbol(ys_dc, sym_cls=mx.sym.np._Symbol if numpy else mx.sym.Symbol)
+
+            if setup_is_deterministic:
+                xs = setup(nd=nd)
+
+            args = {name: x for name, x in zip(xs_names, xs)}
+            ys_sym = sym.bind(mx.context.current_context(), args=args).forward()
+
+            ys_sym_np = [y.asnumpy() for y in ys_sym]
+            _all_same(ys_np, ys_sym_np)
+    finally:
+        if numpy:
+            mx.npx.reset_np()
+
+
+def _all_assert_dc(setup, compute, setup_is_deterministic=True, numpy=(False, True)):
+    """Run ``_assert_dc`` for every comparison mode and each requested array API.
+
+    ``numpy`` is a collection of booleans: True selects mx.np, False mx.nd.
+    ``setup_is_deterministic`` is forwarded unchanged to ``_assert_dc``.
+    """
+    for mode in ('all', 'symbolic', 'imperative', 'imperativewithnondccompute'):
+        for numpy_ in numpy:
+            _assert_dc(setup, compute, mode=mode,
+                       setup_is_deterministic=setup_is_deterministic, numpy=numpy_)
+
+
+###############################################################################
+# Test cases without inputs
+###############################################################################
+def _dc_empty_setup(*, nd):
+    """Setup for computations without inputs: provides no input arrays."""
+    del nd  # Unused; accepted to match the signature of the other setups.
+    return list()
+
+
+def test_dc_no_inputs_single_output():
+    # All arrays are created inside the computation (no external inputs); the
+    # shape of an intermediate array (a.shape[0]) is read while tracing.
+    def f(*, nd):
+        a = nd.arange(10)
+        b = a + nd.arange(a.shape[0])
+        c = b - 1
+        return [c]
+
+    _all_assert_dc(_dc_empty_setup, f)
+
+
+def test_dc_no_inputs_reshape():
+    # Several reshapes (including -1) of an intermediate array without inputs.
+    def f(*, nd):
+        a = nd.arange(10)
+        b = a + nd.arange(a.shape[0])
+        c = b.reshape((5, 2))
+        d = b.reshape((2, 5))
+        e = (c.reshape((-1, )) + d.reshape((-1, ))) / 2
+        return [c + 1, d + 1, e]
+
+    _all_assert_dc(_dc_empty_setup, f)
+
+
+def test_dc_no_inputs_slice():
+    # Basic slicing of an intermediate array followed by concatenation;
+    # mx.nd and mx.np expose concatenation under different names/arguments.
+    def f(*, nd):
+        a = nd.arange(10)
+        b = a[:5]
+        if nd is mx.nd:
+            c = nd.concat(b, b, dim=0)
+        else:
+            c = nd.concatenate([b, b], axis=0)
+        return [c + a]
+
+    _all_assert_dc(_dc_empty_setup, f)
+
+
+def test_dc_no_inputs_subset_of_output():
+    # Only one of the two split outputs is returned; ``c`` is intentionally
+    # unused to cover returning part of a multi-output operator's results.
+    def f(*, nd):
+        a = nd.arange(10)
+        if nd is mx.nd:
+            b, c = mx.nd.split(a, 2, axis=0)
+        else:
+            b, c = mx.np.array_split(a, 2, axis=0)
+        return [b]
+
+    _all_assert_dc(_dc_empty_setup, f)
+
+
+###############################################################################
+# Test cases with inputs
+###############################################################################
+def _dc_simple_setup(shape=(10, ), *, nd):
+    """Return a single input: 0, 1, 2, ... laid out in the given ``shape``."""
+    num_elements = 1
+    for dim in shape:
+        num_elements *= dim
+    flat = nd.arange(num_elements)
+    return [flat.reshape(shape)]
+
+
+def test_dc_single_output():
+    # Elementwise ops on a setup-provided input; reads the input's shape.
+    def f(a, *, nd):
+        b = a + nd.arange(a.shape[0])
+        c = b - 1
+        return [c]
+
+    _all_assert_dc(_dc_simple_setup, f)
+
+
+def test_dc_reshape():
+    # Same reshapes as the no-input variant, but starting from an input array.
+    def f(a, *, nd):
+        b = a + nd.arange(a.shape[0])
+        c = b.reshape((5, 2))
+        d = b.reshape((2, 5))
+        e = (c.reshape((-1, )) + d.reshape((-1, ))) / 2
+        return [c + 1, d + 1, e]
+
+    _all_assert_dc(_dc_simple_setup, f)
+
+
+def test_dc_slice():
+    # Basic slicing of the input array itself, then concatenation.
+    def f(a, *, nd):
+        b = a[:5]
+        if nd is mx.nd:
+            c = nd.concat(b, b, dim=0)
+        else:
+            c = nd.concatenate([b, b], axis=0)
+        return [c + a]
+
+    _all_assert_dc(_dc_simple_setup, f)
+
+
+def test_dc_subset_of_output():
+    # Split the input and return only the first part; ``c`` is intentionally unused.
+    def f(a, *, nd):
+        if nd is mx.nd:
+            b, c = mx.nd.split(a, 2, axis=0)
+        else:
+            b, c = mx.np.array_split(a, 2, axis=0)
+        return [b]
+
+    _all_assert_dc(_dc_simple_setup, f)
+
+
+@raises(MXNetError)  # Should raise NotImplementedError https://github.com/apache/incubator-mxnet/issues/17522
+def test_dc_inplace():
+    # In-place assignment into an input during deferred compute is expected
+    # to be rejected (currently surfaced as MXNetError).
+    def f(a, *, nd):
+        a[:5] = 0
+        b = a + 1
+        return [a, b]
+
+    _all_assert_dc(_dc_simple_setup, f)
+
+
+###############################################################################
+# Special cases
+###############################################################################
+def test_dc_input_part_of_output():
+    # An array registered via set_variable may itself be requested as one of
+    # the outputs passed to get_symbol.
+    a = mx.np.arange(10)
+    dc.set_variable(a, mx.sym.var('a'))
+    with dc.context():
+        b = a + 1
+    dc.get_symbol([a, b])
+
+
+def test_dc_get_symbol_called_twice():
+    # get_symbol can be called repeatedly on the same output; both symbols
+    # must list the registered variable 'a' as their only input.
+    a = mx.np.arange(10)
+    dc.set_variable(a, mx.sym.var('a'))
+    with dc.context():
+        b = a + 1
+    sym1 = dc.get_symbol(b)
+    sym2 = dc.get_symbol(b)
+    assert sym1.list_inputs() == ['a']
+    assert sym2.list_inputs() == ['a']
+
+
+@raises(MXNetError)  # Should raise ValueError https://github.com/apache/incubator-mxnet/issues/17522
+def test_dc_set_variable_called_twice():
+    # Assigning a second symbol to an array that already has one is rejected.
+    a = mx.np.arange(10)
+    dc.set_variable(a, mx.sym.var('a'))
+    dc.set_variable(a, mx.sym.var('b'))
+
+
+def test_dc_no_inputs_context_switch():
+    # Moves data to cpu(1) and back inside the computation. _assert_dc is
+    # called with its default numpy=True, so only the mx.np branch of f runs
+    # here; the mx.nd branch is currently not exercised.
+    def f(*, nd):
+        a = nd.arange(10)
+        if nd is mx.nd:
+            b = a.as_in_context(mx.cpu(1))
+            c = (b - 1).as_in_context(mx.context.current_context())
+        else:
+            b = a.as_in_ctx(mx.cpu(1))
+            c = (b - 1).as_in_ctx(mx.context.current_context())
+        return [c]
+
+    _assert_dc(_dc_empty_setup, f)
+
+
+def test_dc_context_switch():
+    # Context round trip on an input array. As above, only the mx.np branch
+    # runs because _assert_dc defaults to numpy=True.
+    def f(a, *, nd):
+        if nd is mx.nd:
+            b = a.as_in_context(mx.cpu(1))
+            c = (b - 1).as_in_context(mx.context.current_context())
+        else:
+            b = a.as_in_ctx(mx.cpu(1))
+            c = (b - 1).as_in_ctx(mx.context.current_context())
+        return [c]
+
+    _assert_dc(_dc_simple_setup, f)
+
+
+def test_dc_astype():
+    # dtype cast during deferred compute, followed by an op (zeros_like) that
+    # depends on the cast array's dtype. Runs with mx.np only (default).
+    def f(a, *, nd):
+        a = a.astype(np.int32)
+        b = nd.zeros_like(a)
+        return [a + b]
+
+    _assert_dc(_dc_simple_setup, f)
+
+
+def test_dc_dynamic_shape():
+    # flatnonzero has a data-dependent output shape; only the imperative
+    # comparison modes are checked (symbolic mode skipped, see issue below).
+    def f(a, *, nd):
+        return [mx.nd.np.flatnonzero(a)]
+
+    # Skip GraphExecutor test due to https://github.com/apache/incubator-mxnet/issues/17810
+    for mode in ('imperative', 'imperativewithnondccompute'):
+        _assert_dc(_dc_simple_setup, f, mode=mode, numpy=True)
+
+
+###############################################################################
+# Indexing specific tests
+###############################################################################
+def test_dc_integer_indexing():
+    # Single integer index into an input array.
+    def f(a, *, nd):
+        return [a[1] + 1]
+
+    _all_assert_dc(_dc_simple_setup, f)
+
+
+def test_dc_slice_indexing():
+    # Mixed slice/integer index on a reshaped (5, 2) array.
+    def f(a, *, nd):
+        b = a.reshape((5, 2))
+        return [b[:2, 1] + 1]
+
+    _all_assert_dc(_dc_simple_setup, f)
+
+
+def test_dc_tuple_indexing():
+    # Explicit tuple of integers as the index on a reshaped (5, 2) array.
+    def f(a, *, nd):
+        b = a.reshape((5, 2))
+        return [b[(1, 1)] + 1]
+
+    _all_assert_dc(_dc_simple_setup, f)
+
+
+def test_dc_simple_boolean_indexing():
+    # Boolean-mask indexing (data-dependent output shape) followed by a
+    # reshape. Only the imperative comparison modes are checked.
+    if mx.test_utils.default_context() == mx.gpu(0) and mx.runtime.Features().is_enabled("TVM_OP"):
+        # Skip due to https://github.com/apache/incubator-mxnet/issues/17886
+        # Raise SkipTest so the test is reported as skipped, not as passed.
+        from unittest import SkipTest
+        raise SkipTest('https://github.com/apache/incubator-mxnet/issues/17886')
+
+    def setup(*, nd):
+        assert nd is mx.np
+        x = mx.np.array([[0, 1], [1, 1], [2, 2]])
+        return [x, x < 2]
+
+    def f(a, idx, *, nd):
+        assert nd is mx.np
+        return [a[idx].reshape((2, 2))]
+
+    # Skip GraphExecutor test due to https://github.com/apache/incubator-mxnet/issues/17810
+    for mode in ('imperative', 'imperativewithnondccompute'):
+        _assert_dc(setup, f, mode=mode)
+
+
+@raises(TypeError)  # Advanced indexing
+def test_dc_list_indexing():
+    # Indexing with a Python list (advanced indexing) is expected to raise
+    # TypeError. Once the TypeError propagates out of _assert_dc, the remaining
+    # modes in the loop are not executed.
+    def f(a, *, nd):
+        assert nd is mx.np
+        return [a[[1, 2, 3]]]
+
+    for mode in ('all', 'symbolic', 'imperative', 'imperativewithnondccompute'):
+        _assert_dc(_dc_simple_setup, f, mode=mode)
+
+
+@raises(TypeError)  # Advanced indexing
+def test_dc_numpy_indexing():
+    # Indexing with a NumPy integer array (advanced indexing) is expected to
+    # raise TypeError; as above, the loop stops at the first raise.
+    def f(a, *, nd):
+        assert nd is mx.np
+        return [a[np.array([1, 2, 3])]]
+
+    for mode in ('all', 'symbolic', 'imperative', 'imperativewithnondccompute'):
+        _assert_dc(_dc_simple_setup, f, mode=mode)
+
+
+###############################################################################
+# Gluon
+###############################################################################
+def _assert_dc_gluon(setup, net, setup_is_deterministic=True, numpy=True, autograd=True):
+    """Compare outputs of a Gluon block before and after hybridization.
+
+    The block is first called as-is, then hybridized in place and called
+    again; both calls must produce equal outputs.
+
+    Parameters
+    ----------
+    setup : callable
+        Setup function computing inputs for the block. Always called
+        outside of deferred compute.
+    net : Block
+        Initialized block under test. It is hybridized in place.
+    setup_is_deterministic : bool
+        If True, setup function may be called multiple times. If False, will
+        only be called once.
+    numpy : bool
+        If True, use mx.np. Otherwise mx.nd.
+    autograd : bool
+        If True, call the hybridized block under autograd recording, run
+        backward and access every parameter gradient.
+
+    """
+    def _as_list(outputs):
+        # A block may return a single array or a tuple/list of arrays; compare
+        # whole outputs rather than iterating over the rows of a single array.
+        return list(outputs) if isinstance(outputs, (list, tuple)) else [outputs]
+
+    nd = mx.np if numpy else mx.nd
+
+    xs = setup(nd=nd)
+    ys = _as_list(net(*xs))
+    ys_np = [y.asnumpy() for y in ys]
+
+    net.hybridize()
+    if setup_is_deterministic:
+        xs = setup(nd=nd)
+
+    if autograd:
+        with mx.autograd.record():
+            ys_hybrid = _as_list(net(*xs))
+        mx.autograd.backward(ys_hybrid)
+        # Smoke check: every parameter gradient is accessible after backward.
+        for param in net.collect_params().values():
+            param.grad()
+    else:
+        ys_hybrid = _as_list(net(*xs))
+    ys_hybrid_np = [y.asnumpy() for y in ys_hybrid]
+
+    _all_same(ys_np, ys_hybrid_np)
+
+
+def _dc_gluon_simple_setup(shape=(8, 10), *, nd):
+    """Return a single all-ones input of ``shape`` on the current context."""
+    ctx = mx.context.current_context()
+    ones = nd.ones(shape=shape, ctx=ctx)
+    return [ones]
+
+
+def test_dc_hybridblock():
+    # HybridBlock implementing ``forward`` (no F argument) whose parameter
+    # shapes are fully specified up front; checked with mx.nd and, inside an
+    # np_array scope, with mx.np.
+    class MyBlock(mx.gluon.HybridBlock):
+        def __init__(self, *, prefix=None, params=None):
+            super().__init__(prefix, params)
+            with self.name_scope():
+                self.dense = mx.gluon.nn.Dense(units=10, in_units=10)
+                self.weight = self.params.get('weight', shape=(10, ))
+
+        def forward(self, x):
+            assert x.shape[1] == 10  # due to in_units=10 above
+            return self.dense(x) + self.weight.data(x.context)
+
+    net = MyBlock()
+    net.initialize()
+    _assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=False)
+    with mx.util.np_array(True):
+        net = MyBlock()
+        net.initialize()
+        _assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=True)
+
+
+@raises(RuntimeError)
+def test_dc_hybridblock_deferred_init_no_infer_shape():
+    # ``weight`` uses deferred initialization without a shape, and the block
+    # does not override infer_shape, so calling it is expected to raise.
+    class MyBlock(mx.gluon.HybridBlock):
+        def __init__(self, *, prefix=None, params=None):
+            super().__init__(prefix, params)
+            with self.name_scope():
+                self.dense = mx.gluon.nn.Dense(units=10)
+                self.weight = self.params.get('weight', allow_deferred_init=True)
+
+        def forward(self, x):
+            return self.dense(x) + self.weight.data(x.context)
+
+    net = MyBlock()
+    net.initialize()
+    data = mx.nd.ones(shape=(8, 10), ctx=mx.context.current_context())
+    net(data)  # Raises RuntimeError
+
+
+def test_dc_hybridblock_deferred_init():
+    # Same block as above, but infer_shape derives the deferred parameter's
+    # shape from the input, so hybridized and imperative calls must agree.
+    class MyBlock(mx.gluon.HybridBlock):
+        def __init__(self, *, prefix=None, params=None):
+            super().__init__(prefix, params)
+            with self.name_scope():
+                self.dense = mx.gluon.nn.Dense(units=10)
+                self.weight = self.params.get('weight', allow_deferred_init=True)
+
+        def infer_shape(self, x):
+            self.weight.shape = (x.shape[1], )
+
+        def forward(self, x):
+            return self.dense(x) + self.weight.data(x.context)
+
+    net = MyBlock()
+    net.initialize()
+    _assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=False)
+    with mx.util.np_array(True):
+        net = MyBlock()
+        net.initialize()
+        _assert_dc_gluon(_dc_gluon_simple_setup, net, numpy=True)
+
+
+def test_dc_hybridblock_dynamic_shape():
+    # HybridBlock whose outputs have data-dependent shapes (boolean-mask
+    # indexing and flatnonzero), compared before and after hybridization.
+    if mx.test_utils.default_context() == mx.gpu(0) and mx.runtime.Features().is_enabled("TVM_OP"):
+        # Skip due to https://github.com/apache/incubator-mxnet/issues/17886
+        # Raise SkipTest so the test is reported as skipped, not as passed.
+        from unittest import SkipTest
+        raise SkipTest('https://github.com/apache/incubator-mxnet/issues/17886')
+
+    class MyBlock(mx.gluon.HybridBlock):
+        def __init__(self, *, prefix=None, params=None):
+            super().__init__(prefix, params)
+            with self.name_scope():
+                self.dense = mx.gluon.nn.Dense(units=10)
+
+        def forward(self, x, idx):
+            return x[idx].reshape((2, 2)), mx.np.flatnonzero(self.dense(x))
+
+    def setup(*, nd):
+        assert nd is mx.np
+        x = mx.np.array([[0, 1], [1, 1], [2, 2]])
+        return [x, x < 2]
+
+    with mx.util.np_array(True):
+        net = MyBlock()
+        net.initialize()
+        _assert_dc_gluon(setup, net, numpy=True)
+
+
+@raises(RuntimeError)
+def test_dc_hybridblock_symbolblock():
+    # A SymbolBlock nested inside a HybridBlock that implements ``forward``.
+    # The non-hybridized call provides the reference result; the hybridized
+    # call is the one expected to raise RuntimeError.
+    model = mx.gluon.nn.HybridSequential()
+    model.add(mx.gluon.nn.Dense(128, activation='tanh'))
+    model.add(mx.gluon.nn.Dropout(0.5))
+    model.add(mx.gluon.nn.Dense(64, activation='tanh'),
+              mx.gluon.nn.Dense(32, in_units=64))
+    model.add(mx.gluon.nn.Activation('relu'))
+    model.initialize()
+    inputs = mx.sym.var('data')
+    outputs = model(inputs).get_internals()
+    smodel = mx.gluon.SymbolBlock(outputs, inputs, params=model.collect_params())
+    assert len(smodel(mx.nd.zeros((16, 10)))) == 14
+
+    class Net(mx.gluon.HybridBlock):
+        def __init__(self, model):
+            super(Net, self).__init__()
+            self.model = model
+
+        def forward(self, x):
+            out = self.model(x)
+            return mx.nd.add_n(*[i.sum() for i in out])
+
+    net = Net(smodel)
+    data = mx.nd.zeros((16, 10))
+    # Reference result computed before hybridization.
+    out = net(data)
+    out.asnumpy()
+
+    net.hybridize()
+    out_hybrid = net(data)  # Raises RuntimeError
+
+    _all_same([out], [out_hybrid])
+
+
+if __name__ == "__main__":
+    # Allow running this test module directly with nose.
+    import nose
+    nose.runmodule()