Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/29 15:12:19 UTC

[GitHub] zheng-da closed pull request #11450: [just for testing] Foreach1

zheng-da closed pull request #11450: [just for testing] Foreach1
URL: https://github.com/apache/incubator-mxnet/pull/11450
 
 
   

This is a pull request from a forked repository. As GitHub hides the
original diff once the pull request is closed, it is reproduced below for
the sake of provenance:

diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/rnn.py
new file mode 100644
index 00000000000..5e41b7508b6
--- /dev/null
+++ b/benchmark/python/control_flow/rnn.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import subprocess
+import mxnet as mx
+from mxnet import gluon
+import time
+import copy
+
+def get_gpus():
+    """
+    return a list of GPUs
+    """
+    try:
+        re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True)
+    except OSError:
+        return []
+    return range(len([i for i in re.split('\n') if 'GPU' in i]))
+
+class TestRNNLayer(gluon.HybridBlock):
+    def __init__(self, cell, prefix=None, params=None):
+        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+        self.cell = cell
+
+    def hybrid_forward(self, F, inputs, states):
+        out, states = F.contrib.foreach(self.cell, inputs, states)
+        return out
+
+def benchmark_rnn(cell, rnn_data, states):
+    ctx = rnn_data.context
+    num_batches = 20
+
+    # Imperative
+    cell0 = copy.deepcopy(cell)
+    layer0 = TestRNNLayer(cell0)
+    layer0.initialize(ctx=ctx)
+
+    # Hybridize the cell only
+    cell1 = copy.deepcopy(cell)
+    cell1.hybridize()
+    layer1 = TestRNNLayer(cell1)
+    layer1.initialize(ctx=ctx)
+
+    # Hybridize the whole layer, including the foreach loop
+    cell2 = copy.deepcopy(cell)
+    layer2 = TestRNNLayer(cell2)
+    layer2.initialize(ctx=ctx)
+    layer2.hybridize()
+    layer2(rnn_data, states)
+
+    # Hybridize the cell with static memory allocation
+    cell3 = copy.deepcopy(cell)
+    cell3.hybridize(static_alloc=True)
+    layer3 = TestRNNLayer(cell3)
+    layer3.initialize(ctx=ctx)
+
+    tic = time.time()
+    for i in range(num_batches):
+        res0 = layer0(rnn_data, states)
+        mx.nd.waitall()
+    print("Imperative inference takes " + str(time.time() - tic))
+
+    tic = time.time()
+    for i in range(num_batches):
+        res1 = layer1(rnn_data, states)
+        mx.nd.waitall()
+    print("Hybrid-cell inference takes " + str(time.time() - tic))
+
+    tic = time.time()
+    for i in range(num_batches):
+        res3 = layer3(rnn_data, states)
+        mx.nd.waitall()
+    print("Static-hybrid-cell inference takes " + str(time.time() - tic))
+
+    tic = time.time()
+    for i in range(num_batches):
+        res2 = layer2(rnn_data, states)
+        mx.nd.waitall()
+    print("Hybrid inference takes " + str(time.time() - tic))
+
+    layer2.export("foreach_rnn")
+    symnet = mx.symbol.load('foreach_rnn-symbol.json')
+    args1 = {}
+    params = layer2.collect_params()
+    for key in params.keys():
+        args1[key] = params[key].data()
+    args1['data0'] = rnn_data
+    for i in range(len(states)):
+        args1['data' + str(i + 1)] = states[i]
+    exe = symnet.bind(ctx=ctx, args=args1)
+    tic = time.time()
+    for i in range(num_batches):
+        exe.forward(is_train=False)
+        mx.nd.waitall()
+    print("Symbol inference takes " + str(time.time() - tic))
+
+    tic = time.time()
+    for i in range(num_batches):
+        with mx.autograd.record():
+            res0 = layer0(rnn_data, states)
+        res0.backward()
+        mx.nd.waitall()
+    print("Imperative training takes " + str(time.time() - tic))
+
+    tic = time.time()
+    for i in range(num_batches):
+        with mx.autograd.record():
+            res1 = layer1(rnn_data, states)
+        res1.backward()
+        mx.nd.waitall()
+    print("Hybrid-cell training takes " + str(time.time() - tic))
+
+    tic = time.time()
+    for i in range(num_batches):
+        with mx.autograd.record():
+            res3 = layer3(rnn_data, states)
+        res3.backward()
+        mx.nd.waitall()
+    print("Static-hybrid-cell training takes " + str(time.time() - tic))
+
+    tic = time.time()
+    for i in range(num_batches):
+        with mx.autograd.record():
+            res2 = layer2(rnn_data, states)
+        res2.backward()
+        mx.nd.waitall()
+    print("Hybrid training takes " + str(time.time() - tic))
+
+    # Gradient arrays for the backward pass of the foreach symbol
+    args_grad1 = {}
+    for key in args1.keys():
+        args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
+    exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
+    tic = time.time()
+    for i in range(num_batches):
+        exe.forward(is_train=True)
+        exe.backward(res2)
+        mx.nd.waitall()
+    print("Symbol training takes " + str(time.time() - tic))
+    print("")
+
+if __name__ == '__main__':
+    ndim = 512
+    seq_len = 100
+    batch_sizes = [1, 32]
+    cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'),
+             gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
+    ctxs = [mx.cpu(0), mx.gpu(0)]
+    for cell in cells:
+        for ctx in ctxs:
+            for batch_size in batch_sizes:
+                if len(get_gpus()) == 0 and ctx == mx.gpu(0):
+                    continue
+
+                # Create the inputs on the benchmark context so that the GPU
+                # runs actually exercise the GPU.
+                if isinstance(cell, gluon.rnn.GRUCell):
+                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim),
+                                            ctx=ctx)
+                    states = []
+                    states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim),
+                                               ctx=ctx))
+                elif isinstance(cell, gluon.rnn.LSTMCell):
+                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim),
+                                            ctx=ctx)
+                    states = []
+                    states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim),
+                                               ctx=ctx))
+                    states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim),
+                                               ctx=ctx))
+                if ctx == mx.gpu(0):
+                    dev = "GPU"
+                else:
+                    dev = "CPU"
+                print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev,
+                                                                   batch_size))
+                benchmark_rnn(cell, rnn_data, states)
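
For context, the core pattern this script times can be reduced to the short
sketch below: an illustrative, standalone use of mx.nd.contrib.foreach with a
Gluon RNN cell (the shapes and the single GRUCell configuration are chosen
for illustration, not taken from the script above):

    import mxnet as mx
    from mxnet import gluon

    cell = gluon.rnn.GRUCell(512, prefix='rnn_')
    cell.initialize()
    # (seq_len, batch_size, ndim); foreach slices along dimension 0
    rnn_data = mx.nd.normal(shape=(100, 1, 512))
    states = cell.begin_state(batch_size=1)
    outs, last_states = mx.nd.contrib.foreach(cell, rnn_data, states)
    print(outs.shape)  # (100, 1, 512)
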
diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md
index b017c601208..36a2c151e85 100644
--- a/docs/api/python/ndarray/contrib.md
+++ b/docs/api/python/ndarray/contrib.md
@@ -52,6 +52,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
     fft
     ifft
     quantize
+    foreach
 ```
 
 ## API Reference
diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md
index f2bb3f15dee..66471656050 100644
--- a/docs/api/python/symbol/contrib.md
+++ b/docs/api/python/symbol/contrib.md
@@ -52,6 +52,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
     fft
     ifft
     quantize
+    foreach
 ```
 
 ## API Reference
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 4dd858a51c4..6c7626b917a 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1051,6 +1051,28 @@ MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
  */
 MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
                                           const char **name);
+
+/*!
+ * \brief Get the input symbols of the graph.
+ * \param sym The graph.
+ * \param inputs The input symbols of the graph.
+ * \param input_size The number of input symbols returned.
+ */
+MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **inputs,
+                                      int *input_size);
+
+/*!
+ * \brief Cut a subgraph whose nodes are marked with a subgraph attribute.
+ * The input graph will be modified. A variable node will be created for each
+ * edge that connects to nodes outside the subgraph. The outside nodes that
+ * connect to the subgraph will be returned.
+ * \param sym The graph.
+ * \param inputs The nodes that connect to the subgraph.
+ * \param input_size The number of such nodes.
+ */
+MXNET_DLL int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **inputs,
+                                  int *input_size);
+
 /*!
  * \brief Get the detailed information about atomic symbol.
  * \param creator the AtomicSymbolCreator.
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index f4694efad29..2bb2462d486 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -64,8 +64,10 @@ enum OpReqType {
  * \sa Resource
  */
 struct OpContext {
+  /*! \brief whether there is a backward phase to compute gradients. */
+  bool need_grad;
   /*! \brief whether it is training phase */
-  int is_train;
+  bool is_train;
   /*! \brief RunContext related resources */
   RunContext run_ctx;
   /*! \brief the callback when operation completes, used by asynchronize ops */
@@ -98,7 +100,12 @@ enum class ExecType {
    *  In current implementation, copy operator is specially handled by executor.
    *  This flag is used for special case treatment and future extension of different copy ops.
    */
-  kCrossDeviceCopy
+  kCrossDeviceCopy,
+  /*!
+   * \brief A subgraph execution should happen in the main thread, instead of
+   *  in the execution engine.
+   */
+  kSubgraphExec,
 };
 
 /*! \brief the dispatch mode of the operator */
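
From the Python side, need_grad corresponds to whether autograd is recording
(the imperative code paths later in this diff set it from
Imperative::Get()->is_recording()). A minimal sketch of the distinction:

    import mxnet as mx

    x = mx.nd.ones((2, 2))
    x.attach_grad()
    y = x * 2                # runs with need_grad = false
    with mx.autograd.record():
        z = x * 2            # runs with need_grad = true
    z.backward()
    print(x.grad)            # each entry is 2.0
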
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index ba402e6f3f8..3914b53825a 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -21,6 +21,8 @@
 import math
 from ..context import current_context
 from ..random import uniform
+from ..base import _as_list
+from . import ndarray
 try:
     from .gen_contrib import *
 except ImportError:
@@ -96,3 +98,97 @@ def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
     expected_count_sampled = expected_prob_sampled * num_sampled
     return sampled_classes, expected_count_true, expected_count_sampled
 # pylint: enable=line-too-long
+
+def foreach(body, data, init_states):
+    """Run a for loop with user-defined computation over NDArrays on dimension 0.
+
+    This operator simulates a for loop, where body defines the computation for
+    one iteration. It runs that computation on each slice along dimension 0 of
+    the input NDArrays.
+
+    body takes two arguments as input and outputs a tuple of two elements,
+    as illustrated below:
+
+    out, states = body(data1, states)
+
+    data1 can be either an NDArray or a list of NDArrays. If data is an NDArray,
+    data1 is an NDArray. Otherwise, data1 is a list of NDArrays and has the same
+    size as data. states is a list of NDArrays and has the same size as init_states.
+    Similarly, out can be either an NDArray or a list of NDArrays, which are concatenated
+    as the first output of foreach; states from the last execution of body
+    are the second output of foreach.
+
+    The computation done by this operator is equivalent to the pseudo code below
+    when the input data is an NDArray:
+
+    states = init_states
+    outs = []
+    for i in range(data.shape[0]):
+        s = data[i]
+        out, states = body(s, states)
+        outs.append(out)
+    outs = stack(*outs)
+
+
+    Parameters
+    ----------
+    body : a Python function.
+        Defines the computation for one iteration.
+    data: an NDArray or a list of NDArrays.
+        The input data.
+    init_states: an NDArray or a list of NDArrays.
+        The initial values of the loop states.
+
+    Returns
+    -------
+    outputs: an NDArray or a list of NDArrays.
+        The output data concatenated from the output of all iterations.
+    states: a list of NDArrays.
+        The loop states in the last iteration.
+
+    Examples
+    --------
+    >>> step = lambda data, states: (data + states[0], [states[0] * 2])
+    >>> data = mx.nd.random.uniform(shape=(2, 10))
+    >>> states = [mx.nd.random.uniform(shape=(10))]
+    >>> outs, states = mx.nd.contrib.foreach(step, data, states)
+    """
+
+    def check_input(inputs, in_type, msg):
+        is_NDArray_or_list = True
+        if isinstance(inputs, list):
+            for i in inputs:
+                if not isinstance(i, in_type):
+                    is_NDArray_or_list = False
+                    break
+        else:
+            is_NDArray_or_list = isinstance(inputs, in_type)
+        assert is_NDArray_or_list, msg
+
+    check_input(data, ndarray.NDArray, "data should be an NDArray or a list of NDArrays")
+    check_input(init_states, ndarray.NDArray,
+                "init_states should be an NDArray or a list of NDArrays")
+
+    not_data_list = isinstance(data, ndarray.NDArray)
+    num_iters = data.shape[0] if not_data_list else data[0].shape[0]
+    states = init_states
+    outputs = []
+    for i in range(num_iters):
+        if not_data_list:
+            eles = data[i]
+        else:
+            eles = [d[i] for d in data]
+        outs, states = body(eles, states)
+        outs = _as_list(outs)
+        outputs.append(outs)
+    # Stack the per-iteration outputs along a new leading (iteration) axis.
+    outputs = [ndarray.op.stack(*out) for out in zip(*outputs)]
+
+    if not_data_list and len(outputs) == 1:
+        outputs = outputs[0]
+    return (outputs, states)
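
A minimal sketch checking the equivalence stated in the docstring above, with
the hand-written loop spelled out (the shapes are illustrative):

    import mxnet as mx

    step = lambda data, states: (data + states[0], [states[0] * 2])
    data = mx.nd.arange(6).reshape((3, 2))
    outs, last_states = mx.nd.contrib.foreach(step, data, [mx.nd.ones((2,))])

    # the equivalent hand-written loop
    states = [mx.nd.ones((2,))]
    ref = []
    for i in range(data.shape[0]):
        out, states = step(data[i], states)
        ref.append(out)
    ref = mx.nd.stack(*ref)
    assert (outs.asnumpy() == ref.asnumpy()).all()
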
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 83e90e68732..4c03753e183 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -19,6 +19,9 @@
 # pylint: disable=wildcard-import, unused-wildcard-import
 """Contrib Symbol API of MXNet."""
 import math
+import ctypes
+import copy
+
 from .random import uniform
 from .symbol import Symbol
 try:
@@ -26,7 +29,12 @@
 except ImportError:
     pass
 
-__all__ = ["rand_zipfian"]
+from . import symbol
+from ..base import _LIB, check_call
+from ..base import SymbolHandle, _as_list
+from ..attribute import AttrScope
+
+__all__ = ["rand_zipfian", "foreach"]
 
 def rand_zipfian(true_classes, num_sampled, range_max):
     """Draw random samples from an approximately log-uniform or Zipfian distribution.
@@ -91,3 +99,236 @@ def rand_zipfian(true_classes, num_sampled, range_max):
     expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range
     expected_count_sampled = expected_prob_sampled * num_sampled
     return sampled_classes, expected_count_true, expected_count_sampled
+
+def _get_graph_inputs(subg):
+    num_handles = ctypes.c_int(0)
+    handles = ctypes.POINTER(SymbolHandle)()
+    check_call(_LIB.MXSymbolGetInputSymbols(subg.handle, ctypes.byref(handles),
+                                            ctypes.byref(num_handles)))
+
+    syms = []
+    for i in range(num_handles.value):
+        s = Symbol(ctypes.cast(handles[i], SymbolHandle))
+        syms.append(s)
+    return syms
+
+def _cut_subgraph(subg):
+    num_handles = ctypes.c_int(0)
+    handles = ctypes.POINTER(SymbolHandle)()
+    check_call(_LIB.MXSymbolCutSubgraph(subg.handle, ctypes.byref(handles),
+                                        ctypes.byref(num_handles)))
+
+    syms = []
+    for i in range(num_handles.value):
+        s = Symbol(ctypes.cast(handles[i], SymbolHandle))
+        syms.append(s)
+    return syms
+
+# This constructs a subgraph for the given output nodes.
+# If an output node is also one of the input nodes, we call identity to make
+# sure that output nodes are distinct from input nodes.
+def _construct_subgraph(sym_out, sym_states):
+    sym_out = _as_list(sym_out)
+    sym_states = _as_list(sym_states)
+    all_outputs = []
+    all_outputs.extend(sym_out)
+    all_outputs.extend(sym_states)
+    g = symbol.Group(all_outputs)
+
+    flat_out = []
+    all_input_names = g.list_inputs()
+    output_names = [o.name for o in sym_out]
+    for o in sym_out:
+        if o.name in all_input_names:
+            flat_out.append(symbol.op.identity(o))
+        else:
+            flat_out.append(o)
+
+    for s in sym_states:
+        if s.name in all_input_names or s.name in output_names:
+            # A state output must not alias an input or one of the data
+            # outputs. By calling identity, we make sure that all output
+            # symbols refer to different NDArrays.
+            flat_out.append(symbol.op.identity(s))
+        else:
+            flat_out.append(s)
+    return symbol.Group(flat_out)
+
+def foreach(body, data, init_states, name="foreach"):
+    """Run a for loop with user-defined computation over Symbols on dimension 0.
+
+    This operator simulates a for loop, where body defines the computation for
+    one iteration. It runs that computation on each slice along dimension 0 of
+    the input data.
+
+    body takes two arguments as input and outputs a tuple of two elements,
+    as illustrated below:
+
+    out, states = body(data1, states)
+
+    data1 can be either a symbol or a list of symbols. If data is a symbol,
+    data1 is a symbol. Otherwise, data1 is a list of symbols and has the same
+    size as data. states is a list of symbols and has the same size as init_states.
+    Similarly, out can be either a symbol or a list of symbols, which are concatenated
+    as the first output of foreach; states from the last execution of body
+    are the second output of foreach.
+
+    The computation done by this operator is equivalent to the pseudo code below
+    when the input data is an NDArray:
+
+    states = init_states
+    outs = []
+    for i in range(data.shape[0]):
+        s = data[i]
+        out, states = body(s, states)
+        outs.append(out)
+    outs = stack(*outs)
+
+
+    Parameters
+    ----------
+    body : a Python function.
+        Defines the computation for one iteration.
+    data: a symbol or a list of symbols.
+        The input data.
+    init_states: a symbol or a list of symbols.
+        The initial values of the loop states.
+    name: string.
+        The name of the operator.
+
+    Returns
+    -------
+    outputs: a Symbol or a list of Symbols.
+        The output data concatenated from the output of all iterations.
+    states: a list of Symbols.
+        The loop states in the last iteration.
+
+    Examples
+    --------
+    >>> step = lambda data, states: (data + states[0], [states[0] * 2])
+    >>> data = mx.sym.var('data')
+    >>> states = [mx.sym.var('state')]
+    >>> outs, states = mx.sym.contrib.foreach(step, data, states)
+    """
+
+    def check_data(inputs, in_type, msg):
+        is_NDArray_or_list = True
+        if isinstance(inputs, list):
+            for i in inputs:
+                if not isinstance(i, in_type):
+                    is_NDArray_or_list = False
+                    break
+        else:
+            is_NDArray_or_list = isinstance(inputs, in_type)
+        assert is_NDArray_or_list, msg
+
+    check_data(data, symbol.Symbol, "data should be a symbol or a list of symbols")
+    check_data(init_states, symbol.Symbol, "init_states should be a symbol or a list of symbols")
+    not_state_list = isinstance(init_states, symbol.Symbol)
+
+    # If the input Python function references symbols defined outside of it,
+    # we need to prune the computation graph constructed from the function.
+    # One way of doing this is to mark the nodes in the computation graph with
+    # AttrScope and prune the nodes without the special attribute.
+    with AttrScope(__subgraph_name__=name):
+        if isinstance(data, list):
+            in_eles = [symbol.var(sym.name) for sym in data]
+        else:
+            in_eles = symbol.var(data.name)
+        if isinstance(init_states, list):
+            states = [symbol.var(s.name) for s in init_states]
+        else:
+            states = symbol.var(init_states.name)
+        sym_out, sym_states = body(in_eles, states)
+
+        check_data(sym_out, symbol.Symbol,
+                   "the output should be a Symbol or a list of Symbols")
+        check_data(sym_states, symbol.Symbol,
+                   "the output states should be a Symbol or a list of Symbols")
+        if isinstance(sym_states, list):
+            assert isinstance(init_states, list) and len(sym_states) == len(init_states), \
+                    "the number of output states (%d) should be the same as input states (%d)" \
+                    % (len(sym_states), len(init_states))
+        num_out_data = len(sym_out)
+        num_states = len(sym_states)
+        num_outputs = num_out_data + num_states
+        g = _construct_subgraph(sym_out, sym_states)
+
+    # _cut_subgraph modifies the graph in place, so the graph inputs must be
+    # collected after the subgraph has been cut.
+    cut_syms = _cut_subgraph(g)
+    input_syms = _get_graph_inputs(g)
+
+    # Here we need to find out how the input symbols are ordered as well as
+    # where the loop states are located in the list of inputs.
+
+    # This dict contains the input symbols of the subgraph.
+    input_syms = {sym.name:sym for sym in input_syms}
+    gin_names = input_syms.keys()
+    # This array contains the symbols for the inputs of foreach.
+    # They are ordered according to the inputs of the subgraph.
+    init_states = _as_list(init_states)
+    state_names = [sym.name for sym in init_states]
+    data_syms = _as_list(data)
+    data_names = [sym.name for sym in data_syms]
+    cut_var_map = {sym.list_outputs()[0]:sym for sym in cut_syms}
+    cut_var_names = cut_var_map.keys()
+
+    subg_input_names = g.list_inputs()
+    # ordered_ins contains input symbols in the following order:
+    # data_syms, state_syms, followed by cut_vars and vars in the closure.
+    ordered_ins = list(data_syms)  # copy, so extending it doesn't mutate the caller's list
+    # this defines the location of data_syms in the list of subgraph inputs
+    in_data_locs = []
+    for dname in data_names:
+        # Every data array has to be used in the loop body.
+        if dname in subg_input_names:
+            in_data_locs.append(subg_input_names.index(dname))
+        else:
+            raise AssertionError("the data arrays have to be used in the loop body")
+
+    ordered_ins.extend(init_states)
+    # this defines the location of state_syms in the list of subgraph inputs.
+    in_state_locs = []
+    for sname in state_names:
+        # Every state array has to be used in the loop body.
+        if sname in subg_input_names:
+            in_state_locs.append(subg_input_names.index(sname))
+        else:
+            raise AssertionError("the state arrays have to be used in the loop body")
+
+    remain_locs = []
+    for in_name in subg_input_names:
+        assert in_name in gin_names, "The input variable %s can't be found in graph inputs: %s" \
+                % (in_name, str(gin_names))
+        if in_name in cut_var_names:
+            ordered_ins.append(cut_var_map[in_name])
+            remain_locs.append(subg_input_names.index(in_name))
+        elif in_name not in data_names and in_name not in state_names:
+            # The remaining inputs are the variable nodes created inside the UDF.
+            # The subgraph can't have nodes shared with the main graph. As such,
+            # we need to make a copy of these variable nodes.
+            assert in_name in gin_names
+            ordered_ins.append(copy.deepcopy(input_syms[in_name]))
+            remain_locs.append(subg_input_names.index(in_name))
+
+    ret = symbol._internal._foreach(g, *ordered_ins, num_outputs=num_outputs,
+                                    num_out_data=num_out_data, in_state_locs=in_state_locs,
+                                    in_data_locs=in_data_locs, remain_locs=remain_locs)
+    if num_outputs - num_states > 1:
+        outs = []
+        for i in range(num_outputs - num_states):
+            outs.append(ret[i])
+    elif num_outputs - num_states == 1:
+        outs = ret[0]
+    else:
+        outs = []
+    states = []
+    for i in range(num_states):
+        states.append(ret[num_outputs - num_states + i])
+
+    if not_state_list:
+        # If there is only one input state, there should be only one output state.
+        assert len(states) == 1
+        states = states[0]
+
+    return (outs, states)
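
A minimal sketch of running the symbolic foreach above through an executor
(simple_bind and the shapes are assumptions of this sketch, not part of the
change):

    import mxnet as mx

    step = lambda data, states: (data + states[0], [states[0] * 2])
    data = mx.sym.var('data')
    outs, states = mx.sym.contrib.foreach(step, data, [mx.sym.var('state')])
    g = mx.sym.Group([outs] + states)

    exe = g.simple_bind(ctx=mx.cpu(), data=(3, 2), state=(2,))
    res = exe.forward(is_train=False,
                      data=mx.nd.arange(6).reshape((3, 2)),
                      state=mx.nd.ones((2,)))
    print(res[0].shape)  # (3, 2): per-iteration outputs stacked on axis 0
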
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index 9d51ddcb674..ca50a741012 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -33,7 +33,7 @@ private[mxnet] object CToScalaUtils {
       case "double" | "doubleorNone" => "Double"
       case "string" => "String"
       case "boolean" | "booleanorNone" => "Boolean"
-      case "tupleof<float>" | "tupleof<double>" | "ptr" | "" => "Any"
+      case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" => "Any"
       case default => throw new IllegalArgumentException(
         s"Invalid type for args: $default, $argType")
     }
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index e5e9b522890..c27a59a67c6 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -344,6 +344,77 @@ int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
   API_END();
 }
 
+namespace mxnet {
+
+extern std::vector<nnvm::Symbol *> GetInputSymbols(const nnvm::Symbol &sym);
+extern bool CutGraphInputs(const std::vector<nnvm::NodeEntry *> &input_entries,
+                           bool skip_var, std::vector<nnvm::NodeEntry> *orig_entries);
+
+}
+
+int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **input_arr, int *input_size) {
+  API_BEGIN();
+  nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
+  std::vector<nnvm::Symbol *> input_syms = mxnet::GetInputSymbols(*s);
+  *input_size = input_syms.size();
+
+  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  ret->ret_handles.clear();
+  ret->ret_handles.reserve(*input_size);
+  for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]);
+  *input_arr = reinterpret_cast<SymbolHandle*>(dmlc::BeginPtr(ret->ret_handles));
+  API_END_HANDLE_ERROR();
+}
+
+int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols,
+                        int *input_size) {
+  // Given a graph, we want to fetch the nodes that have been marked as part of
+  // a subgraph.
+  API_BEGIN();
+  nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
+  std::string subg_attr = "__subgraph_name__";
+  auto out_node = s->outputs[0].node;
+  auto it = out_node->attrs.dict.find(subg_attr);
+  if (it != out_node->attrs.dict.end()) {
+    std::string subg_name = it->second;
+    std::vector<nnvm::NodeEntry *> input_entries;
+    DFSVisit(s->outputs, [subg_attr, subg_name, &input_entries]
+             (nnvm::NodePtr n) {
+      // If the node itself isn't in the subgraph, we ignore it.
+      auto it = n->attrs.dict.find(subg_attr);
+      if (it == n->attrs.dict.end() || it->second != subg_name)
+        return;
+
+      // We search for nodes whose node entries aren't in the subgraph.
+      for (size_t j = 0; j < n->inputs.size(); j++) {
+        auto in_node = n->inputs[j].node;
+        auto it = in_node->attrs.dict.find(subg_attr);
+        if (it == in_node->attrs.dict.end() || it->second != subg_name)
+          input_entries.push_back(&n->inputs[j]);
+      }
+    });
+
+    std::vector<nnvm::NodeEntry> orig_entries;
+    CutGraphInputs(input_entries, false, &orig_entries);
+    std::vector<nnvm::Symbol *> input_syms(orig_entries.size());
+    for (size_t i = 0; i < input_syms.size(); i++) {
+      input_syms[i] = new nnvm::Symbol();
+      input_syms[i]->outputs.push_back(orig_entries[i]);
+    }
+    *input_size = input_syms.size();
+
+    MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+    ret->ret_handles.clear();
+    ret->ret_handles.reserve(*input_size);
+    for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]);
+    *input_symbols = reinterpret_cast<SymbolHandle*>(dmlc::BeginPtr(ret->ret_handles));
+  } else {
+    *input_size = 0;
+  }
+
+  API_END_HANDLE_ERROR();
+}
+
 int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) {
   nnvm::Symbol *s = new nnvm::Symbol();
   API_BEGIN();
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 831b5f90023..7386de4d12e 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -39,6 +39,7 @@ namespace exec {
 
 GraphExecutor::GraphExecutor() {
   log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
+  need_grad_ = false;
 }
 
 GraphExecutor::~GraphExecutor() {
@@ -257,11 +258,11 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
 
   nnvm::Graph g;
   g.outputs = symbol.outputs;
-  bool need_grad = false;
+  need_grad_ = false;
   for (OpReqType req : grad_req_types) {
-    if (req != kNullOp) need_grad = true;
+    if (req != kNullOp) need_grad_ = true;
   }
-  if (!need_grad) return g;
+  if (!need_grad_) return g;
   for (size_t i = 0; i < g.outputs.size(); ++i) {
     NodeEntry ngrad{nnvm::Node::Create(), 0, 0};
     head_grad_entry_.emplace_back(AttrHint(ngrad, g.outputs[i]));
@@ -1591,6 +1592,7 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
     const auto& inode = idx[nid];
     if (inode.source->is_variable()) continue;
     opnode.exec->op_ctx.is_train = is_train;
+    opnode.exec->op_ctx.need_grad = need_grad_;
   }
 
   // Push Ops
@@ -1609,11 +1611,15 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
     OpNode& opnode = op_nodes_[nid];
     if (op_nodes_[nid].skip_exec_node) continue;
     opnode.exec->op_ctx.is_train = is_train;
+    opnode.exec->op_ctx.need_grad = need_grad_;
     if (opnode.exec->exec_type() == ExecType::kCrossDeviceCopy) {
       CHECK_EQ(inode.inputs.size(), 1U);
       CHECK_EQ(opnode.exec->in_array.size(), 1U);
       CHECK_EQ(opnode.exec->out_array.size(), 1U);
       CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
+    } else if (opnode.exec->exec_type() == ExecType::kSubgraphExec) {
+      // If the node contains a subgraph, we can't execute it in the engine.
+      opnode.exec->Run(opnode.exec->op_ctx.run_ctx, false);
     } else if (opnode.cached_opr != nullptr) {
       bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
       Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 24f98894912..bfc415b4526 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -213,6 +213,8 @@ class GraphExecutor : public Executor {
   // perform bulking and segmentation on a training graph
   void BulkTrainingOpSegs(size_t total_num_nodes);
 
+  // indicate whether there is a backward graph for gradients.
+  bool need_grad_;
   // internal graph
   nnvm::Graph graph_;
   // operator node
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 5a3d44c04ce..5e48c5a26f7 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -591,6 +591,7 @@ void CachedOp::StaticRunOps(
     const Context& default_ctx,
     const nnvm::Graph& g,
     const OpStatePtr& state_ptr,
+    const std::vector<NDArray *> &state_arrays,
     size_t start_nid,
     size_t end_nid) {
   static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
@@ -624,7 +625,7 @@ void CachedOp::StaticRunOps(
       ndinputs.clear();
       ndinputs.reserve(node.inputs.size());
       for (const auto& j : node.inputs) {
-        ndinputs.emplace_back(state.arrays[idx.entry_id(j)]);
+        ndinputs.emplace_back(state_arrays[idx.entry_id(j)]);
         CHECK(!ndinputs.back()->is_none());
       }
       ndoutputs.clear();
@@ -633,7 +634,7 @@ void CachedOp::StaticRunOps(
       req.reserve(num_outputs);
       for (size_t j = 0; j < num_outputs; ++j) {
         size_t eid = idx.entry_id(i, j);
-        ndoutputs.emplace_back(state.arrays[eid]);
+        ndoutputs.emplace_back(state_arrays[eid]);
         req.push_back(state.array_reqs[eid]);
         CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
       }
@@ -688,25 +689,29 @@ OpStatePtr CachedOp::StaticForward(
     StaticAllocMemory(state_ptr, recording, false);
   }
 
+  // We are going to add input and output arrays to the array list.
+  // The input and output arrays should only be valid for this run,
+  // so we shouldn't modify the state's array list.
+  auto arrays = state.arrays;
   if (config_.static_shape) {
     for (auto i : config_.param_indices) {
       auto nid = idx.input_nodes()[i];
-      if (!state.arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
+      if (!arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
         match = false;
         auto ptr = &state.buff[idx.entry_id(nid, 0)];
-        CHECK_EQ(state.arrays[idx.entry_id(nid, 0)], ptr);
-        *state.arrays[idx.entry_id(nid, 0)] = *inputs[i];
+        CHECK_EQ(arrays[idx.entry_id(nid, 0)], ptr);
+        *arrays[idx.entry_id(nid, 0)] = *inputs[i];
         state.dynamic_entries[idx.entry_id(nid, 0)] = false;
       }
     }
     for (auto i : config_.data_indices) {
       auto eid = idx.entry_id(idx.input_nodes()[i], 0);
-      state.arrays[eid] = inputs[i];
+      arrays[eid] = inputs[i];
     }
   } else {
     for (size_t i = 0; i < num_inputs(); ++i) {
       auto nid = idx.input_nodes()[i];
-      state.arrays[idx.entry_id(nid, 0)] = inputs[i];
+      arrays[idx.entry_id(nid, 0)] = inputs[i];
     }
   }
 
@@ -720,13 +725,16 @@ OpStatePtr CachedOp::StaticForward(
 
   for (size_t i = 0; i < outputs.size(); ++i) {
     auto eid = idx.entry_id(idx.outputs()[i]);
-    state.arrays[eid] = outputs[i];
+    // An input and an output may share the same array.
+    if (!arrays[eid]->is_none())
+      *outputs[i] = arrays[eid]->Detach();
+    arrays[eid] = outputs[i];
     if (!outputs[i]->is_none()) continue;
     *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
                           shapes[eid], default_ctx, true, dtypes[eid]);
   }
 
-  StaticRunOps(default_ctx, g, state_ptr, 0, idx.num_nodes());
+  StaticRunOps(default_ctx, g, state_ptr, arrays, 0, idx.num_nodes());
 
   return recording ? state_ptr : OpStatePtr();
 }
@@ -810,7 +818,7 @@ OpStatePtr CachedOp::DynamicForward(
   return op_state;
 }
 
-void CachedOp::Forward(
+OpStatePtr CachedOp::Forward(
     const std::shared_ptr<CachedOp>& op_ptr,
     const std::vector<NDArray*>& inputs,
     const std::vector<NDArray*>& outputs) {
@@ -850,6 +858,7 @@ void CachedOp::Forward(
         std::move(attrs), inputs, outputs, op_state,
         &save_inputs(), &save_outputs());
   }
+  return op_state;
 }
 
 
@@ -891,7 +900,11 @@ void CachedOp::DynamicBackward(
   }
   for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) {
     if (reqs[i] == kNullOp) continue;
-    arrays[idx.entry_id(idx.outputs()[j++])] = outputs[i];
+    auto eid = idx.entry_id(idx.outputs()[j++]);
+    // An input and an output may share the same array.
+    if (!arrays[eid]->is_none())
+      *outputs[i] = arrays[eid]->Detach();
+    arrays[eid] = outputs[i];
   }
 
   // Allocate NDArrays
@@ -952,6 +965,15 @@ void CachedOp::StaticBackward(
     StaticAllocMemory(state_ptr, true, true);
   }
 
+  // We are going to add input and output arrays to the array list.
+  // The input and output arrays should only be valid for this run,
+  // so we shouldn't modify the state's array list.
+  auto arrays = state.arrays;
+  for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
+    auto eid = state.info.bwd_input_eid[i];
+    if (state.dynamic_entries[eid]) arrays[eid] = inputs[i];
+  }
+
   if (config_.static_shape) {
     for (auto i : config_.param_indices) {
       const auto iter = fwd_input_to_grad_output_.find(i);
@@ -959,11 +981,14 @@ void CachedOp::StaticBackward(
       auto entry = grad_graph_.outputs[iter->second];
       if (!idx.exist(entry.node.get())) continue;
       auto eid = idx.entry_id(entry);
-      if (!state.arrays[eid]->IsSame(*outputs[iter->second]) ||
+      if (!arrays[eid]->IsSame(*outputs[iter->second]) ||
           !(state.array_reqs[eid] == reqs[iter->second])) {
         match = false;
         state.array_reqs[eid] = reqs[iter->second];
-        *state.arrays[eid] = *outputs[iter->second];
+        // An input and an output may share the same array.
+        if (!arrays[eid]->is_none())
+          *outputs[iter->second] = arrays[eid]->Detach();
+        *arrays[eid] = *outputs[iter->second];
         state.dynamic_entries[eid] = false;
       }
     }
@@ -974,7 +999,10 @@ void CachedOp::StaticBackward(
       if (!idx.exist(entry.node.get())) continue;
       auto eid = idx.entry_id(entry);
       state.array_reqs[eid] = reqs[iter->second];
-      state.arrays[eid] = outputs[iter->second];
+      // An input and an output may share the same array.
+      if (!arrays[eid]->is_none())
+        *outputs[iter->second] = arrays[eid]->Detach();
+      arrays[eid] = outputs[iter->second];
     }
   } else {
     for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
@@ -982,7 +1010,10 @@ void CachedOp::StaticBackward(
       if (!idx.exist(entry.node.get())) continue;
       auto eid = idx.entry_id(entry);
       state.array_reqs[eid] = reqs[i];
-      state.arrays[eid] = outputs[i];
+      // An input and an output may share the same array.
+      if (!arrays[eid]->is_none())
+        *outputs[i] = arrays[eid]->Detach();
+      arrays[eid] = outputs[i];
     }
   }
 
@@ -990,12 +1021,7 @@ void CachedOp::StaticBackward(
     StaticInitExec(state_ptr, true, true);
   }
 
-  for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
-    auto eid = state.info.bwd_input_eid[i];
-    if (state.dynamic_entries[eid]) state.arrays[eid] = inputs[i];
-  }
-
-  StaticRunOps(default_ctx, g, state_ptr, num_forward_nodes, idx.num_nodes());
+  StaticRunOps(default_ctx, g, state_ptr, arrays, num_forward_nodes, idx.num_nodes());
 }
 
 void CachedOp::Backward(
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 6b94c67a94e..4f4dfdcc14d 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -92,7 +92,7 @@ class CachedOp {
   std::vector<nnvm::NodeEntry> Gradient(
       const nnvm::NodePtr& node,
       const std::vector<nnvm::NodeEntry>& ograds) const;
-  void Forward(
+  OpStatePtr Forward(
       const std::shared_ptr<CachedOp>& op_ptr,
       const std::vector<NDArray*>& inputs,
       const std::vector<NDArray*>& outputs);
@@ -154,6 +154,7 @@ class CachedOp {
       const Context& default_ctx,
       const nnvm::Graph& g,
       const OpStatePtr& state_ptr,
+      const std::vector<NDArray *> &state_arrays,
       size_t start_nid,
       size_t end_nid);
   OpStatePtr StaticForward(
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index faff5f173fe..2331d7be155 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -373,6 +373,7 @@ inline void PushFCompute(const FCompute& fn,
   static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
 
   bool is_train = Imperative::Get()->is_training();
+  bool need_grad = Imperative::Get()->is_recording();
   ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync;
   CHECK(exec_type == ExecType::kSync);
   std::vector<NDArray> inputs, outputs;
@@ -393,7 +394,7 @@ inline void PushFCompute(const FCompute& fn,
                              &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst,
                              &post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx);
       // setup context
-      OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
+      OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested};
       bool is_gpu = ctx.dev_mask() == gpu::kDevMask;
       // pre-fcompute fallback, cast to default storage type
       CastNonDefaultStorage(pre_temp_src, pre_temp_dst, opctx, is_gpu);
@@ -420,11 +421,12 @@ inline void PushFComputeEx(const FComputeEx& fn,
   static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
 
   bool is_train = Imperative::Get()->is_training();
+  bool need_grad = Imperative::Get()->is_recording();
   ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync;
   std::vector<NDArray> inputs, outputs;
   DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
   const auto& run = [=](RunContext rctx) {
-      OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
+      OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested};
 #if MXNET_USE_MKLDNN == 1
       InvalidateOutputs(outputs, req);
 #endif
@@ -459,6 +461,7 @@ inline void PushOperator(const OpStatePtr& state,
   static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
 
   bool is_train = Imperative::Get()->is_training();
+  bool need_grad = Imperative::Get()->is_recording();
   ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync;
   std::vector<NDArray> inputs, outputs;
   DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
@@ -470,17 +473,23 @@ inline void PushOperator(const OpStatePtr& state,
   if (fcompute_ex != nullptr && dispatch_mode == DispatchMode::kFComputeEx) {
     const auto& run = [=](RunContext rctx,
                           engine::CallbackOnComplete on_complete) {
-      OpContext opctx{is_train, rctx, on_complete, requested};
+      OpContext opctx{need_grad, is_train, rctx, on_complete, requested};
 #if MXNET_USE_MKLDNN == 1
       InvalidateOutputs(outputs, req);
 #endif
       fcompute_ex(state, opctx, inputs, req, outputs);
-      if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) {
+      if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync
+          && rctx.get_stream<gpu>()) {
         rctx.get_stream<gpu>()->Wait();
       }
     };
 
-    if (exec_type == ExecType::kSync) {
+    // For operators with subgraphs, we need to invoke them in the main thread
+    // instead of the threaded engine.
+    if (exec_type == ExecType::kSubgraphExec) {
+      RunContext rctx{ctx, nullptr};
+      run(rctx, engine::CallbackOnComplete());
+    } else if (exec_type == ExecType::kSync) {
       Engine::Get()->PushSync(
           [=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); },
           ctx, read_vars, write_vars, FnProperty::kNormal, 0,
@@ -497,7 +506,7 @@ inline void PushOperator(const OpStatePtr& state,
         << "for stateful operator " << op->name;
 
     const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) {
-        OpContext opctx{is_train, rctx, on_complete, requested};
+        OpContext opctx{need_grad, is_train, rctx, on_complete, requested};
 
         std::vector<TBlob> input_blobs, output_blobs;
         // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
@@ -519,12 +528,16 @@ inline void PushOperator(const OpStatePtr& state,
         fcompute(state, opctx, input_blobs, tmp_req, output_blobs);
         // post-fcompute fallback, cast to original storage type, if necessary
         CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
-        if (is_gpu && exec_type == ExecType::kSync) {
+        if (is_gpu && exec_type == ExecType::kSync
+            && rctx.get_stream<gpu>()) {
           rctx.get_stream<gpu>()->Wait();
         }
       };
 
-    if (exec_type == ExecType::kSync) {
+    if (exec_type == ExecType::kSubgraphExec) {
+      RunContext rctx{ctx, nullptr};
+      run(rctx, engine::CallbackOnComplete());
+    } else if (exec_type == ExecType::kSync) {
       Engine::Get()->PushSync(
           [=](RunContext rctx) {
             run(rctx, engine::CallbackOnComplete());
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index e90fb6319d7..0b2beed3391 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1109,7 +1109,8 @@ void CopyFromToImpl(const NDArray& from, const NDArray& to,
   const Context to_ctx = to.ctx();
   bool is_train = Imperative::Get()->is_training();
 
-  OpContext opctx{is_train,
+  OpContext opctx{Imperative::Get()->is_recording(),
+                  is_train,
                   rctx,
                   engine::CallbackOnComplete(),
                   requested};
diff --git a/src/nnvm/graph_editor.cc b/src/nnvm/graph_editor.cc
new file mode 100644
index 00000000000..1dee3c14ee4
--- /dev/null
+++ b/src/nnvm/graph_editor.cc
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_editor.cc
+ * The functions in this file edit an NNVM graph. Potentially,
+ * these functions should be moved to NNVM in the future.
+ */
+
+#include <nnvm/symbolic.h>
+#include <nnvm/graph.h>
+#include <nnvm/node.h>
+
+namespace nnvm {
+NodePtr CreateVariableNode(const std::string& name);
+}
+
+namespace mxnet {
+
+/*
+ * Given a computation graph, this function finds the input nodes of the graph
+ * and creates symbols for the input nodes. It returns the input symbols.
+ */
+std::vector<nnvm::Symbol *> GetInputSymbols(const nnvm::Symbol &sym) {
+  nnvm::Graph g;
+  std::vector<nnvm::Symbol *> input_syms;
+  g.outputs = sym.outputs;
+  const nnvm::IndexedGraph& idx = g.indexed_graph();
+  // Go through all nodes and return the ones representing variables.
+  for (size_t i = 0; i < idx.num_nodes(); i++) {
+    const nnvm::Node &n = *idx[i].source;
+    for (const nnvm::NodeEntry &e : n.inputs) {
+      auto p = e.node;
+      if (p->is_variable()) {
+        nnvm::Symbol *s = new nnvm::Symbol();
+        s->outputs.push_back(e);
+        input_syms.push_back(s);
+      }
+    }
+  }
+  return input_syms;
+}
+
+/*
+ * Given a computation graph and a set of input node entries, this function cuts
+ * the node entries and creates new variable nodes as the input nodes of the
+ * subgraph. The original node entries that connect to the subgraph are
+ * returned through orig_entries.
+ */
+bool CutGraphInputs(const std::vector<nnvm::NodeEntry *> &input_entries,
+                    bool skip_var, std::vector<nnvm::NodeEntry> *orig_entries) {
+  struct pred_entry {
+    nnvm::NodeEntry e;
+    explicit pred_entry(const nnvm::NodeEntry &_e): e(_e) {}
+    bool operator()(const nnvm::NodeEntry e1) {
+      return e.node == e1.node && e.index == e1.index;
+    }
+  };
+
+  std::vector<nnvm::NodePtr> var_nodes;
+  orig_entries->clear();
+  orig_entries->reserve(input_entries.size());
+  for (size_t i = 0; i < input_entries.size(); i++) {
+    nnvm::NodeEntry *e = input_entries[i];
+    // If the node is a variable itself, we may want to skip the node.
+    if (e->node->is_variable() && skip_var)
+      continue;
+
+    auto it = std::find_if(orig_entries->begin(), orig_entries->end(),
+                           pred_entry(*e));
+    bool exist = (it != orig_entries->end());
+    orig_entries->push_back(*e);
+    nnvm::NodePtr n;
+    // If we haven't seen the entry before, we need to create a new var node
+    // for the node entry.
+    if (!exist) {
+      nnvm::Symbol sym;
+      sym.outputs.push_back(*e);
+      n = nnvm::CreateVariableNode(sym.ListOutputNames()[0]);
+    } else {
+      // Otherwise, we use the var node created before.
+      size_t idx = it - orig_entries->begin();
+      CHECK_LT(idx, var_nodes.size());
+      n = var_nodes[idx];
+    }
+    var_nodes.push_back(n);
+    *e = nnvm::NodeEntry{n, 0, 0};
+  }
+  return true;
+}
+
+}  // namespace mxnet
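
The effect of this graph cutting is visible from Python when the loop body
closes over a symbol defined outside of it: the captured symbol is cut out of
the subgraph and surfaces as an ordinary input of the resulting foreach node.
A small illustrative sketch (the variable names are hypothetical):

    import mxnet as mx

    weight = mx.sym.var('weight')  # defined outside the loop body
    step = lambda data, states: (data * weight + states[0], [states[0] * 2])
    data = mx.sym.var('data')
    outs, states = mx.sym.contrib.foreach(step, data, [mx.sym.var('state')])
    # 'weight' appears among the arguments of the main graph
    print(outs.list_arguments())  # includes 'data', 'state' and 'weight'
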
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
new file mode 100644
index 00000000000..c091fdb67e0
--- /dev/null
+++ b/src/operator/control_flow.cc
@@ -0,0 +1,545 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <mxnet/io.h>
+#include <mxnet/base.h>
+#include <mxnet/ndarray.h>
+#include <mxnet/operator.h>
+#include <mxnet/operator_util.h>
+#include <dmlc/logging.h>
+#include <dmlc/optional.h>
+#include "./operator_common.h"
+#include "./elemwise_op_common.h"
+#include "../imperative/imperative_utils.h"
+#include "./subgraph_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct ForeachParam : public dmlc::Parameter<ForeachParam> {
+  int num_args;
+  int num_outputs;
+  int num_out_data;
+  // The location of states in the subgraph inputs.
+  nnvm::Tuple<dim_t> in_state_locs;
+  // The location of data arrays in the subgraph inputs.
+  nnvm::Tuple<dim_t> in_data_locs;
+  // The location of remaining arrays in the subgraph inputs.
+  nnvm::Tuple<dim_t> remain_locs;
+  DMLC_DECLARE_PARAMETER(ForeachParam) {
+    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
+    .describe("Number of inputs.");
+    DMLC_DECLARE_FIELD(num_outputs)
+    .describe("The number of outputs of the subgraph.");
+    DMLC_DECLARE_FIELD(num_out_data)
+    .describe("The number of output data of the subgraph.");
+    DMLC_DECLARE_FIELD(in_state_locs)
+    .describe("The locations of loop states among the inputs.");
+    DMLC_DECLARE_FIELD(in_data_locs)
+    .describe("The locations of input data among the inputs.");
+    DMLC_DECLARE_FIELD(remain_locs)
+    .describe("The locations of remaining data among the inputs.");
+  }
+};  // struct ForeachParam
+
+DMLC_REGISTER_PARAMETER(ForeachParam);
+
+class ForeachState: public LoopState {
+ public:
+  ForeachParam params;
+  int num_iterations;
+
+  ForeachState(const Symbol &g, const ForeachParam &params) : LoopState(g) {
+    this->params = params;
+  }
+};
+
+static void ForeachComputeExCPU(const OpStatePtr& state_ptr,
+                                const OpContext& ctx,
+                                const std::vector<NDArray>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<NDArray>& outputs) {
+  ForeachState &state = state_ptr.get_state<ForeachState>();
+  const ForeachParam& params = state.params;
+  const size_t iter_dim = 0;
+  CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
+  CHECK_GT(params.in_data_locs.ndim(), 0);
+  size_t len = inputs[0].shape()[iter_dim];
+  state.num_iterations = len;
+  for (size_t i = 1; i < params.in_data_locs.ndim(); i++)
+    CHECK_EQ(inputs[i].shape()[iter_dim], len);
+  for (size_t i = 0; i < (size_t) params.num_out_data; i++)
+    CHECK_EQ(len, outputs[i].shape()[iter_dim]);
+  for (const auto &arr : outputs)
+    CHECK_EQ(arr.storage_type(), kDefaultStorage)
+        << "The for operator doesn't support the sparse format";
+
+  // Initializing the outputs of the subgraph is a little trickier.
+  // The states output by one iteration are used as the inputs of the next
+  // iteration, so we have to maintain two sets of output arrays and alternate
+  // between them, so that the inputs and outputs of the subgraph in one
+  // iteration never share the same memory.
+  std::vector<NDArray> subg_outputs1(outputs.size());
+  std::vector<NDArray> subg_outputs2(outputs.size());
+  std::vector<NDArray> *subg_outputs[2]{&subg_outputs1, &subg_outputs2};
+  // If the length is an odd number, the last iteration will use the first set
+  // of outputs. In this way, we don't need to copy the results from the
+  // subgraph to the final outputs of the loop.
+  if (len % 2 == 1) {
+    for (size_t i = params.num_out_data; i < subg_outputs1.size(); i++) {
+      subg_outputs1[i] = outputs[i];
+      subg_outputs2[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true,
+                                 outputs[i].dtype());
+    }
+  } else {
+    // Otherwise, we'll use the second set of outputs.
+    for (size_t i = params.num_out_data; i < subg_outputs1.size(); i++) {
+      subg_outputs1[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true,
+                                 outputs[i].dtype());
+      subg_outputs2[i] = outputs[i];
+    }
+  }
+
+  // Initialize the inputs for the subgraph.
+  // In each iteration, we need to update the subgraph inputs for input data
+  // and the loop states.
+  std::vector<NDArray> subg_inputs(inputs.size());
+  // The remaining arrays (other than input data and states) only need to be set once.
+  for (size_t j = 0; j < params.remain_locs.ndim(); j++) {
+    CHECK_LT(params.remain_locs[j], subg_inputs.size());
+    subg_inputs[params.remain_locs[j]] = inputs[j + params.in_data_locs.ndim()
+        + params.in_state_locs.ndim()];
+  }
+
+  // Here we iterate over the first dimension of the first input array.
+  for (size_t i = 0; i < len; i++) {
+    // Initialize outputs for the subgraph.
+    std::vector<NDArray> *subg_out_curr = subg_outputs[i % 2];
+    std::vector<NDArray> *subg_out_prev = subg_outputs[(i + 1) % 2];
+    for (int j = 0; j < params.num_out_data; j++)
+      (*subg_out_curr)[j] = outputs[j].At(i);
+    // When recording for backward computation, we should make sure
+    // that output arrays are actually different in each iteration.
+    if (ctx.need_grad && i < len - 1) {
+      for (size_t j = params.num_out_data; j < subg_out_curr->size(); j++)
+        (*subg_out_curr)[j] = NDArray(outputs[j].shape(), outputs[j].ctx(),
+                                      true, outputs[j].dtype());
+    } else if (ctx.need_grad && i == len - 1) {
+      // For the last iteration, we need to write data to the output array
+      // directly.
+      for (size_t j = params.num_out_data; j < subg_out_curr->size(); j++)
+        (*subg_out_curr)[j] = outputs[j];
+    }
+
+    // Initialize inputs for the subgraph.
+    // Get a slice from the input data arrays.
+    for (size_t j = 0; j < params.in_data_locs.ndim(); j++) {
+      size_t loc = params.in_data_locs[j];
+      subg_inputs[loc] = inputs[j].At(i);
+    }
+    // For the rest of the iterations, the states are the outputs
+    // from the previous iteration.
+    if (i > 0) {
+      for (size_t j = params.num_out_data; j < subg_out_prev->size(); j++) {
+        size_t idx = j - params.num_out_data;
+        CHECK_LT(params.in_state_locs[idx], subg_inputs.size());
+        subg_inputs[params.in_state_locs[idx]] = (*subg_out_prev)[j];
+      }
+    } else {
+      for (size_t j = 0; j < params.in_state_locs.ndim(); j++) {
+        CHECK_LT(params.in_state_locs[j], subg_inputs.size());
+        subg_inputs[params.in_state_locs[j]] = inputs[j + params.in_data_locs.ndim()];
+      }
+    }
+
+    state.Forward(i, subg_inputs, req, *subg_out_curr, ctx.need_grad);
+  }
+}
+
+static void ForeachGradComputeExCPU(const OpStatePtr& state_ptr,
+                                    const OpContext& ctx,
+                                    const std::vector<NDArray>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<NDArray>& outputs) {
+  ForeachState &state = state_ptr.get_state<ForeachState>();
+  const ForeachParam& params = state.params;
+  CHECK_EQ(outputs.size(), (size_t) params.num_args - 1);
+  CHECK_GT(params.in_data_locs.ndim(), 0);
+  for (const auto &arr : outputs)
+    CHECK_EQ(arr.storage_type(), kDefaultStorage)
+        << "The for operator doesn't support the sparse format";
+  int len = state.num_iterations;
+  size_t num_output_data = params.num_out_data;
+
+  // In the backward computation, we need to run the iterations in reverse order.
+  std::vector<NDArray> subg_ograds(params.num_outputs);
+  std::vector<NDArray> subg_igrads(outputs.size());
+  for (size_t i = num_output_data; i < subg_ograds.size(); i++)
+    subg_ograds[i] = inputs[i];
+  std::vector<OpReqType> subg_req(req.size());
+  for (auto r : req)
+    CHECK_NE(r, kWriteInplace);
+
+  // There are three types of arrays in igrads.
+  // * data gradients.
+  // * loop variable gradients.
+  // * remaining variable gradients.
+  // They are in the following order:
+  // [data vars], [loop vars], [remaining vars]
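+  // For example, if the loop has two data arrays, one loop state and one
+  // remaining variable, the operator's input gradients are laid out as
+  // [grad(data0), grad(data1), grad(state0), grad(remain0)]; within the
+  // subgraph they are scattered to the positions given by in_data_locs,
+  // in_state_locs and remain_locs.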
+
+  // [remaining vars]
+  for (size_t i = 0; i < params.remain_locs.ndim(); i++) {
+    size_t loc = params.remain_locs[i];
+    size_t orig_loc = i + params.in_data_locs.ndim() + params.in_state_locs.ndim();
+    subg_igrads[loc] = outputs[orig_loc];
+    subg_req[loc] = req[orig_loc];
+  }
+
+  for (int iter_num = len - 1; iter_num >= 0; iter_num--) {
+    for (int i = 0; i < params.num_out_data; i++)
+      subg_ograds[i] = inputs[i].At(iter_num);
+    if (iter_num < len - 1) {
+      // For the rest of the iterations, we should add gradients to the
+      // remaining vars.
+      for (size_t i = 0; i < params.remain_locs.ndim(); i++) {
+        size_t loc = params.remain_locs[i];
+        subg_req[loc] = kAddTo;
+      }
+    }
+
+    // [data vars]
+    for (size_t i = 0; i < params.in_data_locs.ndim(); i++) {
+      size_t loc = params.in_data_locs[i];
+      subg_igrads[loc] = outputs[i].At(iter_num);
+      subg_req[loc] = req[i];
+    }
+    // [loop vars]
+    for (size_t i = 0; i < params.in_state_locs.ndim(); i++) {
+      size_t loc = params.in_state_locs[i];
+      const NDArray &output = outputs[i + params.in_data_locs.ndim()];
+      if (iter_num != 0) {
+        // For state gradients, we need to allocate new NDArrays
+        // because intermediate state gradients won't be returned to the users.
+        subg_igrads[loc] = NDArray(output.shape(), output.ctx(), true, output.dtype());
+      } else {
+        subg_igrads[loc] = output;
+      }
+      // For the first iteration, we need to use the request provided by
+      // the user to write state gradients to the outputs.
+      subg_req[loc] = iter_num != 0 ? kWriteTo : req[i + params.in_data_locs.ndim()];
+    }
+
+    state.Backward(iter_num, subg_ograds, subg_req, subg_igrads);
+
+    size_t num_states = subg_ograds.size() - num_output_data;
+    for (size_t i = 0; i < num_states; i++) {
+      size_t loc = params.in_state_locs[i];
+      CHECK_LT(loc, subg_igrads.size());
+      subg_ograds[i + num_output_data] = subg_igrads[loc];
+    }
+  }
+  state.Cleanup();
+}
+
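+// Scatter a contiguous range of operator inputs/attributes, starting at
+// `start`, into the subgraph input vector at the positions given by `locs`.
+// For example, remap(op_in, 2, {0, 3}, &subg_in) copies op_in[2] to
+// subg_in[0] and op_in[3] to subg_in[3].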
+template<typename T>
+static void remap(const std::vector<T> &op_in, size_t start,
+                  const nnvm::Tuple<dim_t> &locs, std::vector<T> *subg_in) {
+  auto op_in_it = op_in.begin() + start;
+  for (size_t i = 0; i < locs.ndim(); i++) {
+    dim_t loc = locs[i];
+    subg_in->at(loc) = *(op_in_it + i);
+  }
+}
+
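+// Drop the first (iteration) dimension of a shape, e.g. (3, 2, 4) -> (2, 4).
+// A 1D shape degenerates to the shape (1,).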
+static inline TShape SliceFirstDim(const TShape &s) {
+  if (s.ndim() > 1) {
+    return TShape(s.begin() + 1, s.end());
+  } else {
+    return TShape(mshadow::Shape1(1));
+  }
+}
+
+static bool ForeachShape(const nnvm::NodeAttrs& attrs,
+                         std::vector<TShape> *in_shape,
+                         std::vector<TShape> *out_shape) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 1U);
+
+  std::vector<TShape> subg_in_shape(in_shape->size());
+  // data shape
+  std::vector<bool> data_1d(params.in_data_locs.ndim(), false);
+  for (size_t i = 0; i < params.in_data_locs.ndim(); i++) {
+    size_t loc = params.in_data_locs[i];
+    if (in_shape->at(i).ndim() == 1)
+      data_1d[i] = true;
+    subg_in_shape[loc] = SliceFirstDim(in_shape->at(i));
+  }
+  // state shape
+  remap(*in_shape, params.in_data_locs.ndim(), params.in_state_locs,
+        &subg_in_shape);
+  // remaining shape
+  remap(*in_shape, params.in_data_locs.ndim() + params.in_state_locs.ndim(),
+        params.remain_locs, &subg_in_shape);
+
+  std::vector<TShape> subg_out_shape = *out_shape;
+  for (int i = 0; i < params.num_out_data; i++) {
+    TShape shape = subg_out_shape[i];
+    // If we don't have shape info, we don't need to do anything.
+    if (shape.ndim() == 0)
+      continue;
+    subg_out_shape[i] = SliceFirstDim(shape);
+  }
+
+  bool infer_success = InferSubgraphShape(*attrs.subgraphs[0],
+                                          &subg_in_shape, &subg_out_shape);
+
+  // After inference, we need to move inferred information back to in_shape and
+  // out_shape.
+
+  // For the shape of output data.
+  size_t len = in_shape->at(0)[0];
+  CHECK_GT(len, 0);
+  for (int i = 0; i < params.num_out_data; i++) {
+    // If the output shape isn't inferred, we don't need to propagate the info.
+    const auto& g_out_shape = subg_out_shape[i];
+    if (g_out_shape.ndim() == 0)
+      continue;
+
+    auto out = TShape(g_out_shape.ndim() + 1);
+    out[0] = len;
+    for (size_t j = 1; j < out.ndim(); j++)
+      out[j] = g_out_shape[j - 1];
+    SHAPE_ASSIGN_CHECK(*out_shape, i, out);
+  }
+  // For the shape of output states.
+  for (size_t i = params.num_out_data; i < subg_out_shape.size(); i++)
+    SHAPE_ASSIGN_CHECK(*out_shape, i, subg_out_shape[i]);
+
+  // For the shape of input data.
+  for (size_t i = 0; i < params.in_data_locs.ndim(); i++) {
+    size_t loc = params.in_data_locs[i];
+    const auto &shape = subg_in_shape[loc];
+    // If the input data shape isn't inferred, we don't need to propagate the
+    // info.
+    if (shape.ndim() == 0)
+      continue;
+
+    if (data_1d[i]) {
+      TShape s(1);
+      s[0] = len;
+      SHAPE_ASSIGN_CHECK(*in_shape, i, s);
+    } else {
+      auto in = TShape(shape.ndim() + 1);
+      in[0] = len;
+      for (size_t j = 1; j < in.ndim(); j++)
+        in[j] = shape[j - 1];
+      SHAPE_ASSIGN_CHECK(*in_shape, i, in);
+    }
+  }
+  // For the shape of state.
+  for (size_t i = 0; i < params.in_state_locs.ndim(); i++) {
+    size_t loc = params.in_state_locs[i];
+    SHAPE_ASSIGN_CHECK(*in_shape, i + params.in_data_locs.ndim(),
+                       subg_in_shape[loc]);
+  }
+  // For the shape of remaining data.
+  for (size_t i = 0; i < params.remain_locs.ndim(); i++) {
+    size_t loc = params.remain_locs[i];
+    SHAPE_ASSIGN_CHECK(*in_shape,
+                       i + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
+                       subg_in_shape[loc]);
+  }
+
+  if (infer_success) {
+    size_t num_states = out_shape->size() - params.num_out_data;
+    for (size_t i = 0; i < num_states; i++) {
+      CHECK_EQ((*out_shape)[i + params.num_out_data],
+               (*in_shape)[i + params.in_data_locs.ndim()]);
+    }
+  }
+  // Check if we have inferred the shapes correctly.
+  return infer_success;
+}
+
+static bool ForeachType(const nnvm::NodeAttrs& attrs,
+                        std::vector<int> *in_type, std::vector<int> *out_type) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 1U);
+  std::vector<int> subg_in_type(in_type->size(), 0);
+  remap(*in_type, 0, params.in_data_locs, &subg_in_type);
+  remap(*in_type, params.in_data_locs.ndim(), params.in_state_locs, &subg_in_type);
+  remap(*in_type, params.in_data_locs.ndim() + params.in_state_locs.ndim(),
+        params.remain_locs, &subg_in_type);
+  bool success = InferSubgraphDataType(*attrs.subgraphs[0], &subg_in_type, out_type);
+  for (size_t i = 0; i < params.in_data_locs.ndim(); i++) {
+    size_t loc = params.in_data_locs[i];
+    TYPE_ASSIGN_CHECK(*in_type, i, subg_in_type[loc]);
+  }
+  for (size_t i = 0; i < params.in_state_locs.ndim(); i++) {
+    size_t loc = params.in_state_locs[i];
+    TYPE_ASSIGN_CHECK(*in_type, i + params.in_data_locs.ndim(), subg_in_type[loc]);
+  }
+  for (size_t i = 0; i < params.remain_locs.ndim(); i++) {
+    size_t loc = params.remain_locs[i];
+    TYPE_ASSIGN_CHECK(*in_type, i + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
+                      subg_in_type[loc]);
+  }
+  return success;
+}
+
+static bool ForeachStorageType(const nnvm::NodeAttrs& attrs,
+                               const int dev_mask,
+                               DispatchMode* dispatch_mode,
+                               std::vector<int> *in_attrs,
+                               std::vector<int> *out_attrs) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
+  CHECK_EQ(attrs.subgraphs.size(), 1U);
+  std::vector<int> subg_in_attrs(in_attrs->size(), kUndefinedStorage);
+  remap(*in_attrs, 0, params.in_data_locs, &subg_in_attrs);
+  remap(*in_attrs, params.in_data_locs.ndim(), params.in_state_locs, &subg_in_attrs);
+  remap(*in_attrs, params.in_data_locs.ndim() + params.in_state_locs.ndim(),
+        params.remain_locs, &subg_in_attrs);
+  bool success = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask,
+                                      dispatch_mode, &subg_in_attrs, out_attrs);
+  for (size_t i = 0; i < params.in_data_locs.ndim(); i++) {
+    size_t loc = params.in_data_locs[i];
+    STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i, subg_in_attrs[loc]);
+  }
+  for (size_t i = 0; i < params.in_state_locs.ndim(); i++) {
+    size_t loc = params.in_state_locs[i];
+    STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i + params.in_data_locs.ndim(),
+                              subg_in_attrs[loc]);
+  }
+  for (size_t i = 0; i < params.remain_locs.ndim(); i++) {
+    size_t loc = params.remain_locs[i];
+    STORAGE_TYPE_ASSIGN_CHECK(*in_attrs,
+                              i + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
+                              subg_in_attrs[loc]);
+  }
+  return success;
+}
+
+static bool BackwardForeachStorageType(const nnvm::NodeAttrs& attrs,
+                                       const int dev_mask,
+                                       DispatchMode* dispatch_mode,
+                                       std::vector<int> *in_attrs,
+                                       std::vector<int> *out_attrs) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  CHECK_EQ(out_attrs->size(), (size_t) params.num_args - 1);
+  CHECK_EQ(in_attrs->size(), (size_t) params.num_args - 1 + params.num_outputs * 2);
+  CHECK_EQ(attrs.subgraphs.size(), 1U);
+  CachedOp op(*attrs.subgraphs[0],
+              std::vector<std::pair<std::string, std::string> >());
+  // Map the operator inputs to the subgraph inputs.
+  std::vector<int> subg_forward_ins(params.num_args - 1, kUndefinedStorage);
+  remap(*in_attrs, params.num_outputs, params.in_data_locs, &subg_forward_ins);
+  remap(*in_attrs, params.num_outputs + params.in_data_locs.ndim(),
+        params.in_state_locs, &subg_forward_ins);
+  remap(*in_attrs, params.num_outputs + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
+        params.remain_locs, &subg_forward_ins);
+
+  // Copy backward input storage to backward subgraph input storage.
+  std::vector<int> subg_in_attrs = *in_attrs;
+  for (size_t i = 0; i < subg_forward_ins.size(); i++)
+    subg_in_attrs[i + params.num_outputs] = subg_forward_ins[i];
+  return op.BackwardStorageType(attrs, dev_mask, dispatch_mode,
+                                &subg_in_attrs, out_attrs);
+}
+
+static OpStatePtr CreateForeachState(const NodeAttrs& attrs,
+                                     Context ctx,
+                                     const std::vector<TShape>& ishape,
+                                     const std::vector<int>& itype) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  return OpStatePtr::Create<ForeachState>(*attrs.subgraphs[0], params);
+}
+
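+// Create the gradient node for _foreach. ElemwiseGradUseInOut passes the
+// output gradients, followed by the forward inputs and the forward outputs,
+// to _backward_foreach, which needs the same subgraph attribute as the
+// forward node.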
+static std::vector<nnvm::NodeEntry>
+ForeachGradient(const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+  ElemwiseGradUseInOut fgrad{"_backward_foreach"};
+  std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
+  entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
+  return entries;
+}
+
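+// The foreach operator iterates over the first dimension of the input data
+// arrays and feeds each slice, together with the loop states, to the
+// subgraph `fn`. In the Python frontend it is exposed as
+// mx.sym.contrib.foreach and mx.nd.contrib.foreach.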
+NNVM_REGISTER_OP(_foreach)
+.MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation")
+.set_attr_parser(ParamParser<ForeachParam>)
+.set_attr<FInferStorageType>("FInferStorageType", ForeachStorageType)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  return params.num_args;
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  return params.num_outputs;
+})
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+    [](const NodeAttrs& attrs) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  std::vector<std::string> names;
+  names.push_back("fn");
+  for (int i = 0; i < params.num_args - 1; i++)
+    names.push_back("data" + std::to_string(i));
+  return names;
+})
+.set_attr<nnvm::FInputGraph>("FInputGraph",
+    [](const NodeAttrs& attrs) {
+  return std::vector<uint32_t>{0};
+})
+.set_attr<nnvm::FGradient>("FGradient", ForeachGradient)
+.set_attr<FCreateOpState>("FCreateOpState", CreateForeachState)
+.set_attr<nnvm::FInferShape>("FInferShape", ForeachShape)
+.set_attr<nnvm::FInferType>("FInferType", ForeachType)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachComputeExCPU)
+// Foreach operator works like an executor. Its code will always run on CPU.
+// So the same code can be registered for both CPU and GPU.
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", ForeachComputeExCPU)
+.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
+  return ExecType::kSubgraphExec;
+})
+.set_attr<std::string>("key_var_num_args", "num_args")
+.add_argument("fn", "Symbol", "Input graph.")
+.add_argument("data", "NDArray-or-Symbol[]",
+              "The input arrays that include data arrays and states.")
+.add_arguments(ForeachParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_foreach)
+.set_num_inputs([](const NodeAttrs& attrs){
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  return params.num_outputs * 2 + params.num_args - 1;
+  })
+.set_num_outputs([](const NodeAttrs& attrs){
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  return params.num_args - 1;
+  })
+.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
+  return ExecType::kSubgraphExec;
+})
+.set_attr<FInferStorageType>("FInferStorageType", BackwardForeachStorageType)
+.set_attr_parser(ParamParser<ForeachParam>)
+.set_attr<bool>("TIsLayerOpBackward", true)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachGradComputeExCPU)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", ForeachGradComputeExCPU);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc
new file mode 100644
index 00000000000..71a9a21c28c
--- /dev/null
+++ b/src/operator/subgraph_op_common.cc
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./subgraph_op_common.h"
+#include "./operator_common.h"
+#include "../imperative/imperative_utils.h"
+
+namespace mxnet {
+namespace op {
+
+bool InferSubgraphDataType(const nnvm::Symbol &subgraph,
+                           std::vector<int> *in_types,
+                           std::vector<int> *out_types) {
+  nnvm::Graph g;
+  g.outputs = subgraph.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_types->size());
+  CHECK_EQ(idx_g.outputs().size(), out_types->size());
+
+  // Put the input and output data types to the dtype vector.
+  nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_types->size());
+  for (size_t i = 0; i < in_types->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    types[eid] = in_types->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_types->size());
+  for (size_t i = 0; i < out_types->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    types[eid] = out_types->at(i);
+  }
+
+  // Infer data type of the graph.
+  g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
+  g = exec::InferType(std::move(g));
+
+  const auto& types1 = g.GetAttr<nnvm::DTypeVector>("dtype");
+  // assign to in_types
+  for (size_t i = 0; i < in_types->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    TYPE_ASSIGN_CHECK(*in_types, i, types1[eid]);
+  }
+  // assign to out_types
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    TYPE_ASSIGN_CHECK(*out_types, i, types1[eid]);
+  }
+  // Check if we have inferred the dtypes correctly.
+  return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
+}
+
+bool InferSubgraphStorage(const nnvm::Symbol &subgraph,
+                          const int dev_mask,
+                          DispatchMode* dispatch_mode,
+                          std::vector<int> *in_stypes,
+                          std::vector<int> *out_stypes) {
+  nnvm::Graph g;
+  g.outputs = subgraph.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_stypes->size());
+  CHECK_EQ(idx_g.outputs().size(), out_stypes->size());
+  exec::DevMaskVector dev_masks(idx_g.num_node_entries(), dev_mask);
+
+  // Put the input and output storages to the storage vector.
+  nnvm::StorageVector stypes(idx_g.num_node_entries(), exec::kBadStorageID);
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_stypes->size());
+  for (size_t i = 0; i < in_stypes->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    stypes[eid] = in_stypes->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_stypes->size());
+  for (size_t i = 0; i < out_stypes->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    stypes[eid] = out_stypes->at(i);
+  }
+
+  // Infer storage type of the graph.
+  bool dev_match = g.attrs.count("dev_mask") &&
+                   g.GetAttr<exec::DevMaskVector>("dev_mask") == dev_masks;
+  if (!dev_match) {
+    g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(dev_masks));
+  }
+  g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(stypes));
+  g = exec::InferStorageType(std::move(g));
+
+  const auto& stypes1 = g.GetAttr<StorageTypeVector>("storage_type");
+  // assign to in_types
+  for (size_t i = 0; i < in_stypes->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, stypes1[eid]);
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  // assign to out_types
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, stypes1[eid]);
+  }
+  // Check if we have inferred the storages correctly.
+  return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
+}
+
+bool InferSubgraphShape(const nnvm::Symbol &subgraph,
+                        std::vector<TShape> *in_shape,
+                        std::vector<TShape> *out_shape) {
+  nnvm::Graph g;
+  g.outputs = subgraph.outputs;
+  const auto& idx = g.indexed_graph();
+  CHECK_EQ(idx.input_nodes().size(), in_shape->size());
+  CHECK_EQ(idx.outputs().size(), out_shape->size());
+
+  // Put the input and output shapes to the shape vector.
+  nnvm::ShapeVector shapes(idx.num_node_entries());
+  const auto &input_nids = idx.input_nodes();
+  CHECK_EQ(input_nids.size(), in_shape->size());
+  for (size_t i = 0; i < in_shape->size(); i++) {
+    auto eid = idx.entry_id(input_nids[i], 0);
+    shapes[eid] = in_shape->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_shape->size());
+  for (size_t i = 0; i < out_shape->size(); i++) {
+    auto eid = idx.entry_id(g.outputs[i]);
+    shapes[eid] = out_shape->at(i);
+  }
+
+  // Infer shape of the graph.
+  g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
+  g = exec::InferShape(std::move(g));
+
+  const auto& shapes1 = g.GetAttr<nnvm::ShapeVector>("shape");
+  // Shape inference inside the subgraph may also infer the shapes of its
+  // inputs, so we copy the inferred input shapes back.
+  CHECK_EQ(input_nids.size(), in_shape->size());
+  for (size_t i = 0; i < in_shape->size(); i++) {
+    auto eid = idx.entry_id(input_nids[i], 0);
+    SHAPE_ASSIGN_CHECK(*in_shape, i, shapes1[eid]);
+  }
+
+  for (size_t i = 0; i < g.outputs.size(); i++) {
+    uint32_t eid = idx.entry_id(g.outputs[i]);
+    SHAPE_ASSIGN_CHECK(*out_shape, i, shapes1[eid]);
+  }
+  return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
+}
+
+LoopState::LoopState(const Symbol &g) {
+  this->subgraph_sym = g;
+  this->subgraph.outputs = g.outputs;
+
+  std::vector<std::pair<std::string, std::string> > kwargs;
+  kwargs.push_back(std::pair<std::string, std::string>("inline_limit", "0"));
+  // We turn on static_alloc for two reasons:
+  // it avoids the overhead of unnecessary memory allocation, and
+  // only static_alloc supports nested calls of CachedOp.
+  kwargs.push_back(std::pair<std::string, std::string>("static_alloc", "1"));
+  iter_op = std::make_shared<CachedOp>(subgraph_sym, kwargs);
+}
+
+void LoopState::Forward(int iter_no,
+                        const std::vector<NDArray> &cinputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<NDArray> &coutputs,
+                        bool is_recording) {
+  using namespace nnvm;
+  using namespace imperative;
+
+  bool orig_is_record;
+  if (is_recording)
+    orig_is_record = Imperative::Get()->set_is_recording(true);
+  else
+    orig_is_record = Imperative::Get()->is_recording();
+
+  std::vector<NDArray> in_bufs = cinputs;
+  std::vector<NDArray> out_bufs = coutputs;
+  std::vector<NDArray *> inputs(cinputs.size());
+  std::vector<NDArray *> outputs(coutputs.size());
+  for (size_t i = 0; i < inputs.size(); i++)
+    inputs[i] = &in_bufs[i];
+  for (size_t i = 0; i < outputs.size(); i++)
+    outputs[i] = &out_bufs[i];
+
+  OpStatePtr state = iter_op->Forward(nullptr, inputs, outputs);
+  // If an input and an output share the same array, the output array may be
+  // replaced by CachedOp, so we need to copy the data back to the real output.
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    if (!out_bufs[i].IsSame(coutputs[i]))
+      CopyFromTo(out_bufs[i], coutputs[i]);
+  if (is_recording) {
+    all_inputs.push_back(cinputs);
+    all_outputs.push_back(coutputs);
+    all_states.push_back(state);
+  }
+
+  Imperative::Get()->set_is_recording(orig_is_record);
+}
+
+void LoopState::Backward(int iter_no,
+                         const std::vector<NDArray> &ograds,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<NDArray> &igrads) {
+  using namespace nnvm;
+  using namespace imperative;
+
+  CHECK_GT(all_states.size(), iter_no)
+      << "We didn't record the computation for iteration " << iter_no;
+  auto op = iter_op;
+  std::vector<NDArray *> inputs;
+  std::vector<NDArray *> outputs;
+  inputs.reserve(op->num_backward_inputs());
+  outputs.reserve(op->num_inputs());
+  std::vector<NDArray> ograd_bufs = ograds;
+  std::vector<NDArray> igrad_bufs = igrads;
+  for (size_t i = 0; i < ograds.size(); i++)
+    inputs.push_back(&ograd_bufs[i]);
+
+  const std::vector<bool> &save_inputs = op->save_inputs();
+  const std::vector<bool> &save_outputs = op->save_outputs();
+  CHECK_EQ(save_inputs.size(), all_inputs[iter_no].size());
+  CHECK_EQ(op->num_outputs(), all_outputs[iter_no].size());
+  for (size_t i = 0; i < all_inputs[iter_no].size(); i++) {
+    if (save_inputs[i])
+      inputs.push_back(&all_inputs[iter_no][i]);
+  }
+  for (size_t i = 0; i < all_outputs[iter_no].size(); i++) {
+    if (save_outputs[i])
+      inputs.push_back(&all_outputs[iter_no][i]);
+  }
+  CHECK_EQ(inputs.size(), op->num_backward_inputs());
+  for (size_t i = 0; i < igrads.size(); i++)
+    outputs.push_back(&igrad_bufs[i]);
+  CHECK_EQ(outputs.size(), op->num_inputs());
+  auto state = all_states[iter_no];
+  op->Backward(false, state, inputs, req, outputs);
+  // If an input and an output share the same array, the output array may be
+  // replaced by CachedOp, so we need to copy the data back to the real output.
+  for (size_t i = 0; i < igrads.size(); i++)
+    if (!igrads[i].IsSame(igrad_bufs[i]))
+      CopyFromTo(igrad_bufs[i], igrads[i]);
+}
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h
new file mode 100644
index 00000000000..79078409e21
--- /dev/null
+++ b/src/operator/subgraph_op_common.h
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_OP_COMMON_H_
+#define MXNET_OPERATOR_SUBGRAPH_OP_COMMON_H_
+
+#include <mxnet/io.h>
+#include <mxnet/base.h>
+#include <mxnet/op_attr_types.h>
+#include <vector>
+#include "../imperative/cached_op.h"
+#include "../imperative/imperative_utils.h"
+
+namespace mxnet {
+namespace op {
+
+/*
+ * Infer the data types of inputs and outputs of an operator that contains a
+ * subgraph.
+ */
+bool InferSubgraphDataType(const nnvm::Symbol &subgraph, std::vector<int> *in_type,
+                           std::vector<int> *out_type);
+
+/*
+ * Infer the shape of inputs and outputs of an operator that contains a
+ * subgraph.
+ */
+bool InferSubgraphShape(const nnvm::Symbol &subgraph,
+                        std::vector<TShape> *in_shape,
+                        std::vector<TShape> *out_shape);
+
+/*
+ * Infer the storage types of inputs and outputs of an operator that contains a
+ * subgraph.
+ */
+bool InferSubgraphStorage(const nnvm::Symbol &subgraph,
+                          const int dev_mask,
+                          DispatchMode* dispatch_mode,
+                          std::vector<int> *in_attrs,
+                          std::vector<int> *out_attrs);
+
+/*
+ * This contains the states for running a loop and provides methods
+ * of running the subgraph computation for an iteration.
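+ *
+ * A typical use calls Forward() once per iteration (with is_recording set
+ * when gradients are needed), then Backward() for each iteration in reverse
+ * order, and finally Cleanup() to release the recorded arrays and op states.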
+ */
+class LoopState {
+  // These are the input and output arrays recorded from all iterations;
+  // the op state of the CachedOp for each iteration is kept in all_states.
+  std::vector<std::vector<NDArray> > all_outputs;
+  std::vector<std::vector<NDArray> > all_inputs;
+  // For inference, there should be only one cached op because we want to
+  // share the memory across iterations.
+  // For training, each iteration keeps its own op state because it needs to
+  // maintain a set of memory buffers for all computation states, which will
+  // be used in the backward pass.
+  CachedOpPtr iter_op;
+  std::vector<OpStatePtr> all_states;
+  Symbol subgraph_sym;
+  nnvm::Graph subgraph;
+
+ public:
+  explicit LoopState(const Symbol &g);
+
+  void Forward(int iter_no,
+               const std::vector<NDArray> &inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray> &outputs,
+               bool is_recording);
+  void Backward(int iter_no,
+                const std::vector<NDArray> &ograds,
+                const std::vector<OpReqType> &req,
+                const std::vector<NDArray> &igrads);
+  void Cleanup() {
+    all_outputs.clear();
+    all_inputs.clear();
+    all_states.clear();
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_SUBGRAPH_OP_COMMON_H_
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 8dd6934ff23..f3f66c453ff 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1362,6 +1362,56 @@ def test_hybrid_static_memory_recording():
     net(x)
 
 
+def test_share_inputs_outputs():
+    class TestIOBackward(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestIOBackward, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, in1, in2):
+            return in1 + in2
+
+    class TestIOForward(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestIOForward, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, in1):
+            return in1
+
+    d1 = mx.nd.arange(10)
+    d2 = mx.nd.arange(10)
+
+    params = [{'inline_limit': 0},
+              {'inline_limit': 0, 'static_alloc': True},
+              {'inline_limit': 0, 'static_alloc': True, 'static_shape': True}]
+    # Test the case that inputs and outputs of a forward graph share NDArrays.
+    for param in params:
+        t = TestIOForward()
+        t.hybridize(**param)
+        for i in range(5):
+            d1.attach_grad()
+            out_grad = mx.nd.random.uniform(shape=(10))
+            res = t(d1)
+            assert_almost_equal(res.asnumpy(), d1.asnumpy())
+
+    param = deepcopy(params[2])
+    param['param_indices'] = (1)
+    param['data_indices'] = (0)
+    params.append(param)
+    # Test the case that inputs and outputs of a backward graph share NDArrays.
+    for param in params:
+        t = TestIOBackward()
+        t.hybridize(**param)
+        for i in range(5):
+            d1.attach_grad()
+            d2.attach_grad()
+            out_grad = mx.nd.random.uniform(shape=(10))
+            with mx.autograd.record():
+                res = t(d1, d2)
+            res.backward(out_grad=out_grad)
+            assert_almost_equal(out_grad.asnumpy(), d1.grad.asnumpy())
+            assert_almost_equal(out_grad.asnumpy(), d2.grad.asnumpy())
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 9dbcb3b3be8..6f661ac5e35 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -18,9 +18,10 @@
 import mxnet as mx
 from mxnet import gluon
 import numpy as np
+import copy
 from numpy.testing import assert_allclose
 import unittest
-from mxnet.test_utils import almost_equal
+from mxnet.test_utils import almost_equal, assert_almost_equal
 
 
 def test_rnn():
@@ -28,13 +29,69 @@ def test_rnn():
     inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
     outputs, _ = cell.unroll(3, inputs)
     outputs = mx.sym.Group(outputs)
-    assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
+    assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight',
+                                                    'rnn_i2h_bias', 'rnn_i2h_weight']
     assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
 
     args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
     assert outs == [(10, 100), (10, 100), (10, 100)]
 
 
+class TestRNNLayer(gluon.HybridBlock):
+    def __init__(self, cell_type, hidden_size, prefix=None, params=None):
+        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+        self.cell = cell_type(hidden_size, prefix='rnn_')
+
+    def hybrid_forward(self, F, inputs, states):
+        out, states = F.contrib.foreach(self.cell, inputs, states)
+        return out
+
+def check_contrib_rnn(cell_type, num_states):
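+    """Run one training step with the imperative layer and one with the
+    hybridized layer, starting from the same parameters, and check that
+    the outputs and the updated weights match."""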
+    batch_size = 10
+    hidden_size = 100
+    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(5, batch_size, 50))
+    state_shape = (batch_size, hidden_size)
+    states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(num_states)]
+    layer = TestRNNLayer(cell_type, hidden_size)
+    layer.initialize(ctx=mx.cpu(0))
+    res1 = layer(rnn_data, states)
+    params1 = layer.collect_params()
+    orig_params1 = copy.deepcopy(params1)
+
+    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
+    with mx.autograd.record():
+        res1 = layer(rnn_data, states)
+    res1.backward()
+    trainer.step(batch_size)
+
+    layer = TestRNNLayer(cell_type, hidden_size)
+    layer.initialize(ctx=mx.cpu(0))
+    layer.hybridize()
+    res2 = layer(rnn_data, states)
+    params2 = layer.collect_params()
+    for key, val in orig_params1.items():
+        params2[key].set_data(val.data())
+
+    trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
+    with mx.autograd.record():
+        res2 = layer(rnn_data, states)
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+    res2.backward()
+    trainer.step(batch_size)
+
+    for key, val in params1.items():
+        weight1 = val.data()
+        weight2 = params2[key].data()
+        assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
+def test_contrib_rnn():
+    cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2),
+            (gluon.rnn.GRUCell, 1)]
+    for cell_type, num_states in cell_types:
+        check_contrib_rnn(cell_type, num_states)
+
+
 def test_lstm():
     cell = gluon.rnn.LSTMCell(100, prefix='rnn_')
     inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1a1c548c595..0da52577f5e 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -19,12 +19,13 @@
 from __future__ import print_function
 import numpy as np
 import mxnet as mx
+import copy
 import math
 import random
 import itertools
 from numpy.testing import assert_allclose, assert_array_equal
 from mxnet.test_utils import *
-from mxnet.base import py_str, MXNetError
+from mxnet.base import py_str, MXNetError, _as_list
 from common import setup_module, with_seed, teardown
 import unittest
 
@@ -5873,6 +5874,477 @@ def test_float16_min_max():
     assert np.finfo('float16').max == mx.nd.max(a).asscalar()
 
 
+@with_seed()
+def test_foreach():
+    v3 = mx.sym.var("v0")
+    v4 = mx.sym.var("v1")
+    v5 = mx.sym.var("v2")
+    v6 = mx.sym.var("v3")
+    v7 = mx.sym.var("v4")
+    v8 = mx.sym.var("v5")
+
+    def verify_foreach(step, in_syms, state_syms, free_syms,
+            in_arrs, init_states, frees, out_grads, is_train=True,
+            free_vars_func=None, num_iters=1):
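+        """Run `step` through mx.sym.contrib.foreach and through an
+        imperative re-implementation, then compare the outputs (and, when
+        is_train is True, the input gradients) of the two versions."""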
+        step_sym = lambda in_syms, state_syms : step(in_syms, state_syms, free_syms)
+        res, states = mx.sym.contrib.foreach(step_sym, in_syms, state_syms)
+        out = _as_list(res)
+        num_outputs = len(out)
+        for i in range(num_outputs):
+            out[i] = out[i] * 2
+        out.extend(states)
+        out = mx.sym.Group(out)
+        js_1 = out.tojson()
+        out = mx.sym.load_json(js_1)
+        js_2 = out.tojson()
+        assert js_1 == js_2
+        arr_grads = []
+        arg_dict = {}
+        arg_grad_dict = {}
+        i = 0
+        for arr in _as_list(in_arrs):
+            arr_grad = mx.nd.empty(arr.shape)
+            arr_grads.append(arr_grad)
+            arg_dict['v'+str(i)] = arr
+            arg_grad_dict['v'+str(i)] = arr_grad
+            i = i + 1
+        for arr in init_states:
+            arr_grad = mx.nd.empty(arr.shape)
+            arr_grads.append(arr_grad)
+            arg_dict['v'+str(i)] = arr
+            arg_grad_dict['v'+str(i)] = arr_grad
+            i = i + 1
+        for arr in frees:
+            arr_grad = mx.nd.empty(arr.shape)
+            arr_grads.append(arr_grad)
+            arg_dict['v'+str(i)] = arr
+            arg_grad_dict['v'+str(i)] = arr_grad
+            i = i + 1
+
+        if is_train:
+            e = out.bind(ctx=default_context(), args=arg_dict, args_grad=arg_grad_dict)
+        else:
+            e = out.bind(ctx=default_context(), args=arg_dict)
+        # The inputs to forward and backward are the same, so forward and
+        # backward should always return the same outputs.
+        for i in range(num_iters):
+            e.forward(is_train=is_train)
+            if (is_train):
+                # backward
+                tmp_grads = out_grads[0][:]
+                tmp_grads.extend(out_grads[1])
+                e.backward(tmp_grads)
+
+        # Below we use imperative to reimplement foreach and compute its gradients.
+        res = []
+        for i in range(len(_as_list(out_grads[0]))):
+            res.append([])
+        for arr in _as_list(in_arrs):
+            arr.attach_grad()
+        for arr in init_states:
+            arr.attach_grad()
+        for arr in frees:
+            arr.attach_grad()
+        with mx.autograd.record():
+            frees_imp = frees if free_vars_func is None else free_vars_func(frees)
+            step_imp = lambda in_arrs, state_arrs : step(in_arrs, state_arrs, frees_imp)
+            states = [mx.nd.expand_dims(s, 0) for s in init_states]
+            res, states = mx.nd.contrib.foreach(step_imp, in_arrs, init_states)
+
+            res2 = _as_list(res)
+            for i in range(len(res2)):
+                res2[i] = res2[i] * 2
+            outs = []
+            outs[:] = res2[:]
+            if isinstance(states, list):
+                outs.extend(states)
+                states = [mx.nd.expand_dims(s, 0) for s in states]
+                res2.extend(states)
+            else:
+                outs.append(states)
+                states = mx.nd.expand_dims(states, 0)
+                res2.append(states)
+            if is_train:
+                res = mx.nd.concat(*res2, dim=0)
+
+        tmp_grads = out_grads[0][:]
+        tmp_grads1 = [mx.nd.expand_dims(grad, 0) for grad in out_grads[1]]
+        tmp_grads.extend(tmp_grads1)
+        if is_train:
+            res.backward(mx.nd.concat(*tmp_grads, dim=0))
+        for i in range(len(outs)):
+            assert e.outputs[i].shape == outs[i].shape
+            assert_almost_equal(e.outputs[i].asnumpy(), outs[i].asnumpy(),
+                    rtol=0.001, atol=0.0001)
+        if (is_train):
+            all_ins = _as_list(in_arrs)[:]
+            all_ins.extend(init_states)
+            all_ins.extend(frees)
+            size = min(len(all_ins), len(e.grad_arrays))
+            for i in range(size):
+                assert_almost_equal(all_ins[i].grad.asnumpy(),
+                        e.grad_arrays[i].asnumpy(),
+                        rtol=0.001, atol=0.0001)
+
+    # Test cases:
+    # * graph inputs are stored in different orders.
+    #   This is to test if foreach finds the data arrays and weight arrays
+    #   in the right location.
+    # * the number of iterations: odd or even.
+    # * multiple inputs and multiple outputs.
+    # * inference.
+    def step1(in1, states, free):
+        out = in1 * 2 + states[0] + free[0]
+        return (out, [out])
+    frees1 = [mx.nd.arange(2), mx.nd.arange(2) + 1]
+    arrs = mx.nd.arange(6).reshape(shape=(3, 2))
+    states = [mx.nd.arange(2)]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+            [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True,
+            lambda frees : [frees[0] + frees[1]])
+    verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False,
+            lambda frees : [frees[0] + frees[1]])
+    verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True,
+            lambda frees : [frees[0] + frees[1]], 5)
+    verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False,
+            lambda frees : [frees[0] + frees[1]], 5)
+
+    # Test the even number of iterations.
+    frees = [mx.nd.random.uniform(shape=(2))]
+    arrs = mx.nd.random.uniform(shape=(2, 2))
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+            [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+    # Test the odd number of iterations
+    arrs = mx.nd.random.uniform(shape=(3, 2))
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+            [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+
+    # Reorder the input and state in the subgraph inputs.
+    def step2(in1, states, free):
+        out = states[0] + in1 * 2 + free[0]
+        return (out, [out])
+    # Test the even number of iterations.
+    arrs = mx.nd.random.uniform(shape=(2, 2))
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+            [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+    # Test the odd number of iterations.
+    arrs = mx.nd.random.uniform(shape=(3, 2))
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+            [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+
+    # Test multiple inputs and outputs.
+    def step3(in1, states, free):
+        out = in1[0] + in1[1] * 2 + states[0] + states[1] * 2 + free[0]
+        return ([out, out], [out * 2, out * 3])
+    arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))]
+    states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)],
+            [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
+    verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads)
+    verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False)
+
+    # Test multiple inputs and outputs.
+    # The order of subgraph inputs doesn't match the operator inputs
+    def step4(in1, states, free):
+        out = in1[1] * 2 + states[0] + free[0] + states[1] * 2 + in1[0]
+        return ([out, out * 2], [out * 2, out * 3])
+    arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))]
+    states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)],
+            [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
+    verify_foreach(step4, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads)
+    verify_foreach(step4, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False)
+
+    # Test multiple inputs and outputs.
+    # The data inputs and states have different shapes.
+    def step5(in1, states, free):
+        if isinstance(in1[0], mx.nd.NDArray):
+            out1 = mx.nd.broadcast_add(states[0] + free[1], in1[1] * 2)
+            out2 = mx.nd.broadcast_add(in1[0], free[0] + states[1] * 2)
+        else:
+            out1 = mx.sym.broadcast_add(states[0] + free[1], in1[1] * 2)
+            out2 = mx.sym.broadcast_add(in1[0], free[0] + states[1] * 2)
+        return ([out1, out2 * 2], [states[0] * 2, states[1] * 3])
+    frees = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2, 2))]
+    arrs = [mx.nd.random.uniform(shape=(3, 2, 2)), mx.nd.random.uniform(shape=(3, 2))]
+    states = [mx.nd.random.uniform(shape=(2, 2)), mx.nd.random.uniform(shape=(2))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)],
+            [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
+    verify_foreach(step5, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False)
+
+    # Test multiple inputs and outputs.
+    # The data inputs and states have different shapes and data types.
+    def step6(in1, states, free):
+        if isinstance(in1[0], mx.nd.NDArray):
+            out1 = mx.nd.broadcast_add(states[0] + mx.nd.cast(free[1], 'float32'),
+                    mx.nd.cast(in1[1], 'float32') * 2)
+            out2 = mx.nd.broadcast_add(in1[0],
+                    free[0] + mx.nd.cast(states[1], 'float32') * 2)
+        else:
+            out1 = mx.sym.broadcast_add(states[0] + mx.sym.cast(free[1], 'float32'),
+                    mx.sym.cast(in1[1], 'float32') * 2)
+            out2 = mx.sym.broadcast_add(in1[0],
+                    free[0] + mx.sym.cast(states[1], 'float32') * 2)
+        return ([out1, out2 * 2], [states[0] * 2, states[1] * 3])
+    frees = [mx.nd.random.uniform(shape=(2)),
+            mx.nd.cast(mx.nd.random.uniform(shape=(2, 2)), 'float64')]
+    arrs = [mx.nd.random.uniform(shape=(3, 2, 2)),
+            mx.nd.cast(mx.nd.random.uniform(shape=(3, 2)), dtype='float16')]
+    states = [mx.nd.random.uniform(shape=(2, 2)),
+            mx.nd.cast(mx.nd.random.uniform(shape=(2)), dtype='int32')]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)],
+            [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
+    verify_foreach(step6, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False)
+
+    # Test multiple inputs and outputs.
+    # some of the inputs are used twice.
+    def step7(in1, states, free):
+        out1 = states[0] + in1[0] + free[1] + in1[1] * 2 + free[0]
+        out2 = in1[0] + free[0] + states[1] * 2 + in1[1]
+        return ([out1, out2 * 2], [states[0] * 2, states[1] * 3])
+    frees = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
+    arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))]
+    states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)],
+            [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
+    verify_foreach(step7, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False)
+
+    # Test the case that the output is the input.
+    arrs = mx.nd.random.uniform(shape=(3, 2))
+    states = [mx.nd.arange(2)]
+    frees = [mx.nd.random.uniform(shape=(2))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+            [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    def step8(in1, states, free):
+        return (in1, [states[0] * free[0]])
+    verify_foreach(step8, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step8, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+    def step9(in1, states, free):
+        return (in1 * free[0], states)
+    verify_foreach(step9, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step9, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+
+    # Test the case that not all inputs are used.
+    def step10(in1, states, free):
+        return (in1, states)
+    verify_foreach(step10, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step10, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+    def step11(in1, states, free):
+        return (in1, free)
+    try:
+        verify_foreach(step11, v3, [v4], [v5], arrs, states, frees, out_grads)
+        verify_foreach(step11, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+    except AssertionError:
+        print("the states have to be used")
+    def step12(in1, states, free):
+        return (in1, [states[0] + 1, states[0] + 2])
+    states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
+    frees = []
+    try:
+        verify_foreach(step12, v3, [v4, v5], [], arrs, states, frees, out_grads)
+        verify_foreach(step12, v3, [v4, v5], [], arrs, states, frees, out_grads, False)
+    except AssertionError:
+        print("the states have to be used")
+
+    # test without free variables.
+    def step13(in1, states, free):
+        return (in1, states)
+    states = [mx.nd.random.uniform(shape=(2))]
+    verify_foreach(step13, v3, [v4], [], arrs, states, [], out_grads)
+    verify_foreach(step13, v3, [v4], [], arrs, states, [], out_grads, False)
+
+    # test when there isn't output data or output states.
+    def step14(in1, states, free):
+        return (in1 + free[0], [])
+    frees = [mx.nd.random.uniform(shape=(2))]
+    verify_foreach(step14, v3, [], [v4], arrs, [], frees, out_grads)
+    verify_foreach(step14, v3, [], [v4], arrs, [], frees, out_grads, False)
+    def step15(in1, states, free):
+        return ([], [in1 * states[0] * free[0]])
+    out_grads = [[], [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    verify_foreach(step15, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step15, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+
+    # Test the case of iterating on a 1D data array.
+    def step16(in1, states, free):
+        return ([in1[0] * states[0]], [states[0] * 2])
+    arrs = [mx.nd.arange(3)]
+    states = [mx.nd.random.uniform(shape=(1))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, (3, 1))],
+            [mx.nd.random.uniform(-10, 10, (1))]]
+    verify_foreach(step16, [v3], [v4], [], arrs, states, [], out_grads)
+    verify_foreach(step16, [v3], [v4], [], arrs, states, [], out_grads, False)
+    def step17(in1, states, free):
+        return ([in1[1] * in1[0] * states[0]], [states[0] * 2])
+    arrs = [mx.nd.random.uniform(shape=(3, 1)), mx.nd.arange(3)]
+    states = [mx.nd.random.uniform(shape=(1))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, (3, 1))],
+            [mx.nd.random.uniform(-10, 10, (1))]]
+    verify_foreach(step17, [v3, v4], [v5], [], arrs, states, [], out_grads)
+    verify_foreach(step17, [v3, v4], [v5], [], arrs, states, [], out_grads, False)
+
+
+@with_seed()
+def test_foreach_nested():
+    # Test nested foreach.
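+    # The step of the outer foreach itself runs a foreach over its input
+    # slice, so the subgraph contains another foreach operator.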
+    def step_in(in1, states):
+        out = in1 * 2 + states[0]
+        return (out, [out])
+
+    def step_sym(in1, states):
+        out1 = mx.sym.contrib.foreach(step_in, in1, states)
+        out = mx.sym.broadcast_add(out1[0], states[0])
+        return (out, [mx.sym.squeeze(mx.sym.slice(out, begin=(0, 0), end=(1, 2)))])
+    def step_nd(in1, states):
+        out1 = mx.nd.contrib.foreach(step_in, in1, states)
+        out = mx.nd.broadcast_add(out1[0], states[0])
+        return (out, [mx.nd.squeeze(mx.nd.slice(out, begin=(0, 0), end=(1, 2)))])
+
+    data_sym = mx.sym.var("v1")
+    state_sym = mx.sym.var("v2")
+    out, states = mx.sym.contrib.foreach(step_sym, data_sym, [state_sym])
+    assert isinstance(states, list)
+    assert len(states) == 1
+    out = mx.sym.broadcast_add(out, states[0])
+
+    js_1 = out.tojson()
+    out = mx.sym.load_json(js_1)
+    js_2 = out.tojson()
+    assert js_1 == js_2
+
+    data = mx.nd.arange(8).reshape((2, 2, 2))
+    state = mx.nd.arange(2)
+    data_grad = mx.nd.empty(data.shape)
+    state_grad = mx.nd.empty(state.shape)
+    e = out.bind(ctx=default_context(), args={'v1':data, 'v2':state},
+            args_grad={'v1':data_grad, 'v2':state_grad})
+    e.forward(is_train=True)
+    out_grads = []
+    for out in e.outputs:
+        out_grads.append(mx.nd.random.uniform(shape=out.shape))
+    e.backward(out_grads)
+
+    data.attach_grad()
+    state.attach_grad()
+    with mx.autograd.record():
+        out, states = mx.nd.contrib.foreach(step_nd, data, [state])
+        assert isinstance(states, list)
+        assert len(states) == 1
+        res = mx.nd.broadcast_add(out, states[0])
+    assert_almost_equal(res.asnumpy(), e.outputs[0].asnumpy(), rtol=0.001, atol=0.0001)
+
+    res.backward(out_grads[0])
+    assert_almost_equal(data.grad.asnumpy(), data_grad.asnumpy())
+    assert_almost_equal(state.grad.asnumpy(), state_grad.asnumpy())
+
+
+def check_foreach_rnn(cell_type, num_states):
+    data = mx.sym.var("data")
+    params = mx.rnn.RNNParams()
+    hidden_dim = 4
+    input_dim = 5
+    seq_len = 2
+    batch_size = 2
+
+    # This tests foreach with gradient accumulation: the RNN parameters are
+    # shared across iterations, so their gradients are summed over iterations.
+    def step(in1, states):
+        rnn = cell_type(hidden_dim, prefix='', params=params)
+        next_h, states = rnn(in1, states)
+        return (next_h, states)
+
+    def sym_group(out):
+        if (isinstance(out[0], mx.sym.Symbol)):
+            ret = [out[0]]
+        else:
+            ret = out[0]
+        ret.extend(out[1])
+        return mx.sym.Group(ret)
+
+    rnn = cell_type(hidden_dim, prefix='', params=params)
+    if num_states == 2:
+        init_states = [mx.sym.var("h"), mx.sym.var("c")]
+    else:
+        init_states = [mx.sym.var("h")]
+    out = mx.sym.contrib.foreach(step, data, init_states)
+    out = sym_group(out)
+    arg_shapes, out_shapes, aux_shapes = out.infer_shape(data=(seq_len, batch_size, input_dim),
+            h=(batch_size, hidden_dim))
+    rnn_inputs = out.list_inputs()
+
+    # Inputs
+    args1 = {name:mx.nd.random.uniform(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
+    args2 = copy.deepcopy(args1)
+    # gradients for the backward of the foreach symbol
+    args_grad1 = {name:mx.nd.empty(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
+    # gradients for the backward of the unrolled symbol.
+    args_grad2 = {name:mx.nd.empty(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
+
+    # Symbol of running the RNN cell with foreach.
+    out = mx.sym.contrib.foreach(step, data, init_states)
+    out = sym_group(out)
+    js_1 = out.tojson()
+    out = mx.sym.load_json(js_1)
+    js_2 = out.tojson()
+    assert js_1 == js_2
+    e1 = out.bind(ctx=default_context(), args=args1, args_grad=args_grad1)
+
+    # Symbol of running the unrolled RNN cell.
+    lstm = cell_type(hidden_dim, prefix='')
+    unroll_outs = []
+    states = init_states
+    for inputs in mx.sym.split(data, num_outputs=seq_len, axis=0, squeeze_axis=True):
+        h, states = lstm(inputs, states)
+        unroll_outs.append(mx.sym.expand_dims(h, axis=0))
+    unroll_outs = _as_list(mx.sym.concat(*unroll_outs, dim=0))
+    unroll_outs.extend(states)
+    out = mx.sym.Group(unroll_outs)
+    js_1 = out.tojson()
+    out = mx.sym.load_json(js_1)
+    js_2 = out.tojson()
+    assert js_1 == js_2
+    e2 = out.bind(ctx=default_context(), args=args2, args_grad=args_grad2)
+
+    for i in range(5):
+        out_grads = []
+        for arr in e1.outputs:
+            out_grads.append(mx.nd.random.uniform(-10, 10, arr.shape))
+
+        args = {name:mx.nd.random.uniform(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
+
+        e1.forward(is_train=True, **args)
+        outputs1 = e1.outputs
+        e1.backward(out_grads)
+
+        e2.forward(is_train=True, **args)
+        outputs2 = e2.outputs
+        e2.backward(out_grads)
+
+        for i in range(len(outputs2)):
+            assert_almost_equal(outputs1[i].asnumpy(), outputs2[i].asnumpy(),
+                    rtol=0.001, atol=0.0001)
+        input_names = out.list_inputs()
+        for i in range(len(e1.grad_arrays)):
+            name = input_names[i]
+            assert_almost_equal(args_grad1[name].asnumpy(), args_grad2[name].asnumpy(),
+                    rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_foreach_rnn():
+    cell_types = [(mx.rnn.LSTMCell, 2), (mx.rnn.RNNCell, 1), (mx.rnn.GRUCell, 1)]
+    for cell_type, num_states in cell_types:
+        check_foreach_rnn(cell_type, num_states)
+
+
 @with_seed()
 def test_squeeze_op():
     def check_squeeze_op(shape, axis=None):


 
