You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/03/10 17:59:22 UTC

[incubator-tvm] branch master updated: [Torch] Add initial control flow support (#4964)

This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 06e9542  [Torch] Add initial control flow support  (#4964)
06e9542 is described below

commit 06e9542ee0bfd014bd06a4dd4fdb3af9d2d29eb0
Author: masahi <ma...@gmail.com>
AuthorDate: Wed Mar 11 02:59:09 2020 +0900

    [Torch] Add initial control flow support  (#4964)
    
    * Add support for prim::If and prim::Loop with test cases
    
    * rebase and fix tests
    
    * add some comments
    
    * simplifying, fix float cast
    
    * parse -> convert
    
    * recursively retrieve ops in get_all_op_names
    
    * use multiple return values from block correctly, simplify loop convert
    
    * choose dtype properly for zeros and ones
    
    * simplifying, replace convert_inputs with _get_relay_input_vars
    
    * fix for while loop with non-input-dependent init cond
    
    * add assert on loop var update
    
    * move the condition around
    
    * better testing for seg models
    
    * rebase fix, disable inception v3 in quant test as it is too slow to
    load with torch-1.4 + torchvision 0.5
    
    * simplify and add more comparison op converter
---
 python/tvm/relay/frontend/pytorch.py          | 223 ++++++++++++++++++++++----
 tests/python/frontend/pytorch/qnn_test.py     |   3 +-
 tests/python/frontend/pytorch/test_forward.py | 197 ++++++++++++++++++++++-
 3 files changed, 385 insertions(+), 38 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index ff37f82..6da91c1 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -20,6 +20,7 @@
 """PT: PyTorch frontend."""
 import itertools
 import logging
+import sys
 
 import numpy as np
 
@@ -29,6 +30,7 @@ from tvm.ir import module as _module
 from .. import analysis as _analysis
 from .. import expr as _expr
 from .. import op as _op
+from ..loops import while_loop
 from .common import get_relay_op
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
@@ -107,9 +109,8 @@ def _select():
     def _impl(inputs, input_types):
         data = inputs[0]
         dim = int(inputs[1])
-        index = int(inputs[2])
-
-        return _op.transform.take(data, _expr.const(index, dtype="int32"), axis=dim)
+        index = _wrap_const(inputs[2])
+        return _op.transform.take(data, index, axis=dim)
     return _impl
 
 def _ones():
@@ -126,7 +127,10 @@ def _ones():
         else:
             assert "data type {} could not be parsed in ones op" % (type(data))
 
-        return _op.full(_expr.const(1), shape, dtype=_convert_data_type(input_types[0]))
+        dtype_map = {6: "float32", 3: "int32"}
+        dtype_id = inputs[1]
+        assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id
+        return _op.full(_expr.const(1), shape, dtype=dtype_map[dtype_id])
     return _impl
 
 def _zeros():
@@ -143,7 +147,10 @@ def _zeros():
         else:
             assert "data type {} could not be parsed in zeros op" % (type(data))
 
-        return _op.full(_expr.const(0), shape, dtype=_convert_data_type(input_types[0]))
+        dtype_map = {6: "float32", 3: "int32"}
+        dtype_id = inputs[1]
+        assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id
+        return _op.full(_expr.const(0), shape, dtype=dtype_map[dtype_id])
     return _impl
 
 def _relu():
@@ -222,12 +229,10 @@ def _convolution():
         else:
             assert "data type {} could not be parsed in conv op" % (type(weight))
 
-        # TODO: Add reshape when channel multiplier > 1. Pending PR #4644
         channels = weight_shape[0]
         groups = int(inputs[8])
 
         if groups > 1:
-            # in torch, groups == in_channels for depth wise conv
             channel_multiplier = channels // groups
             new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3])
             weight = _op.transform.reshape(weight, new_weight_shape)
@@ -496,7 +501,7 @@ def _dropout():
     return _impl
 
 def _reduce(name):
-    def _impl(inputs, attrs, params):
+    def _impl(inputs, input_types):
         data = inputs[0]
         return get_relay_op(name)(data)
     return _impl
@@ -714,7 +719,6 @@ def _upsample(method):
 
     return _impl
 
-
 def _expand_as():
     def _impl(inputs, input_types):
         # TODO: maybe fix this
@@ -724,6 +728,29 @@ def _expand_as():
         return inputs[0]
     return _impl
 
+def _neg():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.tensor.negative(data)
+    return _impl
+
+def _tanh():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.tensor.tanh(data)
+    return _impl
+
+def _Bool():
+    def _impl(inputs, input_types):
+        assert len(inputs) == 1
+        return inputs[0]
+    return _impl
+
+def _Float():
+    def _impl(inputs, input_types):
+        assert len(inputs) == 1
+        return _op.cast(inputs[0], "float32")
+    return _impl
 
 # Helper functions for operator implementation
 
@@ -780,6 +807,11 @@ def _convert_elemwise_input(data, input_type):
     else:
         return data
 
+def _wrap_const(c):
+    if not isinstance(c, _expr.Expr) and not isinstance(c, list):
+        return _expr.const(c)
+    return c
+
 # Operator mappings
 
 _convert_map = {
@@ -845,7 +877,16 @@ _convert_map = {
     "aten::detach"                          : _identity(),
     "aten::upsample_bilinear2d"             : _upsample("bilinear"),
     "aten::upsample_nearest2d"              : _upsample("nearest_neighbor"),
-    "aten::expand_as"                       : _expand_as()
+    "aten::expand_as"                       : _expand_as(),
+    "aten::lt"                              : _elemwise("less"),
+    "aten::gt"                              : _elemwise("greater"),
+    "aten::le"                              : _elemwise("less_equal"),
+    "aten::ge"                              : _elemwise("greater_equal"),
+    "aten::ne"                              : _elemwise("not_equal"),
+    "aten::Bool"                            : _Bool(),
+    "aten::Float"                           : _Float(),
+    "aten::neg"                             : _neg(),
+    "aten::tanh"                            : _tanh(),
 }
 
 
@@ -894,7 +935,8 @@ def _report_missing_conversion(op_names):
     """ Check if all ops in an input graph are supported by TVM """
     known_ops = ["prim::Constant", "prim::GetAttr",
                  "prim::ListConstruct", "prim::ListUnpack",
-                 "prim::TupleConstruct", "prim::TupleUnpack"]
+                 "prim::TupleConstruct", "prim::TupleUnpack",
+                 "prim::If", "prim::Loop"]
     known_ops += list(_convert_map.keys())
     known_ops += list(qnn_torch.convert_map.keys())
 
@@ -939,9 +981,13 @@ def _get_input_types(op_node):
         input_node_kind = in_ty.kind()
         if input_node_kind == 'TensorType':
             if in_ty.scalarType() is None:
-                input_list_types.append(None)
+                # Tensor's type can be unknown if we use torch.jit.script(...)
+                # Defaults to float for now
+                logging.warning("Untyped Tensor found, assume it is float")
+                input_list_types.append("float")
             else:
                 input_list_types.append(in_ty.scalarType().lower())
+
         elif input_node_kind == 'ListType':
             input_list_types.append(str(in_ty.getElementType()).lower())
         elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
@@ -1004,15 +1050,10 @@ def _get_operator_nodes(nodes):
     return ops
 
 
-def parse_inputs(graph_inputs, input_shapes):
-    """ Return Relay vars from torch input vars """
-    ir_inputs = list(graph_inputs)
-    input_vars = {}
-
-    for input_name, ir_input in zip(input_shapes, ir_inputs[1:]):
-        input_vars[input_name] = _expr.var(input_name,
-                                           shape=input_shapes[input_name])
-    return input_vars
+def _get_relay_input_vars(input_shapes):
+    """ Return Relay vars from input shapes """
+    return {iname: _expr.var(iname, shape=ishape)
+            for iname, ishape in input_shapes.items()}
 
 
 def get_use_chains(root_node, terminate=lambda _: False):
@@ -1055,7 +1096,7 @@ def get_attr_chains(root_getattr_node):
     return get_use_chains(root_getattr_node, terminate)
 
 
-def parse_params(graph, state_dict):
+def convert_params(graph, state_dict):
     """
     Return Relay vars and TVM NDArrays for input parameters
     A chain of prim::GetAttr nodes is processed one at a time
@@ -1090,7 +1131,109 @@ def parse_params(graph, state_dict):
     return params, param_tensors, packed_param_map
 
 
-def parse_operators(operators, outputs, output_index_map, ret_name):
+def convert_block(block, outputs, output_index_map):
+    """ Translate Torch "Block", used for prim::If and prim::Loop """
+    ops = _get_operator_nodes(block.nodes())
+    ret_names = _get_input_names(block.returnNode())
+    return convert_operators(ops, outputs, output_index_map, ret_names)
+
+
+def convert_if(if_node, outputs, output_index_map):
+    """ Translate Torch prim::If to Relay If """
+    cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]]
+    blocks = list(if_node.blocks())
+    true_branch = convert_block(blocks[0], outputs, output_index_map)
+    false_branch = convert_block(blocks[1], outputs, output_index_map)
+    assert len(true_branch) == 1 and len(false_branch) == 1
+    return _expr.If(cond, true_branch[0], false_branch[0])
+
+
+def convert_loop(loop_node, outputs, output_index_map):
+    """ Translate Torch prim::Loop to Relay while_loop """
+    def get_input(index):
+        ivalue = loop_node.inputsAt(index)
+        inode = ivalue.node()
+        if inode.kind() == "prim::Constant":
+            return _expr.const(_get_constant(inode))
+        var_name = ivalue.debugName()
+        assert var_name in output_index_map
+        return _wrap_const(outputs[output_index_map[var_name]])
+
+    # Refer to the spec for prim::Loop below
+    # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
+    # The first input: %max_trip_count
+    # The second input: %initial_condition
+    # The rest of input: loop variables
+    max_loop_count = get_input(0)
+    init_cond = get_input(1)
+    num_loop_var = len(list(loop_node.inputs())) - 2
+    init_vals = [get_input(i + 2) for i in range(num_loop_var)]
+
+    # for a while loop, max_loop_count is always int64 max, but
+    # max_loop_count.data (tvm.runtime.NDArray) is -1, so call _get_constant again
+    is_while_loop = (isinstance(max_loop_count, _expr.Constant) and
+                     _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize)
+
+    body_block = list(loop_node.blocks())[0]
+    block_input_names = _get_input_names(body_block)
+
+    def cond(*current_vals):
+        i = current_vals[0]
+
+        if is_while_loop:
+            return _op.equal(i, _expr.const(True, 'bool'))
+
+        return _op.less(i, max_loop_count)
+
+    def body(*current_vals):
+        # Update loop variables using the prev iteration outputs
+        assert len(current_vals) == len(block_input_names)
+        for (i, iname) in enumerate(block_input_names):
+            outputs[output_index_map[iname]] = current_vals[i]
+
+        block_outputs = convert_block(body_block, outputs, output_index_map)
+
+        if not is_while_loop:
+            # the iteration-variable increment is implicit in torch, so do it manually;
+            # for a while loop, block_outputs[0] is already a boolean,
+            # the result of the termination check
+            incr = _expr.const(1, dtype="int32")
+            block_outputs[0] = current_vals[0] + incr
+
+        return block_outputs
+
+    def get_var(name, val):
+        if isinstance(val, _expr.Constant):
+            return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype)
+        return _expr.var(name)
+
+    if is_while_loop:
+        loop_iter_dtype = "bool"
+        # for a while loop whose condition does not depend on loop inputs (e.g. while i < 10:),
+        # init_cond is an int, so it must be cast to bool to type check
+        if isinstance(init_cond, _expr.Constant):
+            init_cond = _op.cast(init_cond, "bool")
+        init_loop_iter_val = init_cond
+    else:
+        loop_iter_dtype = "int32"
+        # always count from 0
+        init_loop_iter_val = _expr.const(0, dtype="int32")
+
+    name_val_pairs = list(zip(block_input_names,
+                              [init_loop_iter_val] + init_vals))
+    _update_outputs_from_pairs(name_val_pairs, outputs, output_index_map)
+
+    loop_iter_var = _expr.var(block_input_names[0], shape=(),
+                              dtype=loop_iter_dtype)
+    loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
+    loop = while_loop(cond, [loop_iter_var] + loop_vars, body)
+    loop_val = loop(init_loop_iter_val, *init_vals)
+
+    # The first element is a loop counter or boolean condition, ignore it
+    return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
+
+
+def convert_operators(operators, outputs, output_index_map, ret_names):
     """ Convert each Torch IR operators to Relay equivalent """
     for node_name, op_node in operators:
         operator = op_node.kind()
@@ -1110,17 +1253,35 @@ def parse_operators(operators, outputs, output_index_map, ret_name):
             unpacked_names = _get_output_names(op_node)
             _update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
                                        outputs, output_index_map)
+        elif operator == "prim::If":
+            if_out = convert_if(op_node, outputs, output_index_map)
+            output_index_map[node_name] = len(outputs)
+            outputs.append(if_out)
+        elif operator == "prim::Loop":
+            loop_out = convert_loop(op_node, outputs, output_index_map)
+            unpacked_names = _get_output_names(op_node)
+            assert len(loop_out) == len(unpacked_names)
+            _update_outputs_from_pairs(zip(unpacked_names, loop_out),
+                                       outputs, output_index_map)
         else:
             output_index_map[node_name] = len(outputs)
             relay_op = _convert_map[operator]
             outputs.append(relay_op(inputs, _get_input_types(op_node)))
 
-    return outputs[output_index_map[ret_name]]
+    return [_wrap_const(outputs[output_index_map[ret_name]])
+            for ret_name in ret_names]
 
 
 def get_all_op_names(graph):
     """ Return all operator names in the input graph """
-    return set(node.kind() for node in graph.nodes())
+    nodes = list(graph.nodes())
+    prim_with_blocks = ["prim::If", "prim::Loop"]
+    for prim in prim_with_blocks:
+        prim_nodes = graph.findAllNodes(prim, recurse=True)
+        for prim_node in prim_nodes:
+            for block in prim_node.blocks():
+                nodes += block.nodes()
+    return set(node.kind() for node in nodes)
 
 
 def get_graph_input_names(script_module):
@@ -1167,14 +1328,14 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
     _check_input_names(script_module, input_shapes)
 
     params = script_module.state_dict()
-    input_vars = parse_inputs(graph.inputs(), input_shapes)
-    param_vars, tensors, packed_param_map = parse_params(graph, params)
+    input_vars = _get_relay_input_vars(input_shapes)
+    param_vars, tensors, packed_param_map = convert_params(graph, params)
     tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
 
     input_vars.update(param_vars)
     outputs = list(input_vars.values())
     output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
-    ret_name = _get_input_names(graph.return_node())[0]
+    ret_name = _get_input_names(graph.return_node())
 
     # For quantized models
     if "aten::quantize_per_tensor" in op_names:
@@ -1186,8 +1347,8 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
         qnn_torch.add_quant_params(tvm_params, weight_quant_params)
         _convert_map.update(qnn_torch.convert_map)
 
-    body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
-                           output_index_map, ret_name)
-    func = tvm.relay.Function(_analysis.free_vars(body), body)
+    ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
+                            output_index_map, ret_name)
+    func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
 
     return _module.IRModule.from_expr(func), tvm_params
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index e3a876c..23fcb7c 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -347,7 +347,8 @@ def test_quantized_imagenet():
         qmodels += [
             ("resnet18", qresnet.resnet18(pretrained=True), per_channel),
             ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
-            ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
+            # disable inception test for now, since loading it takes ~5min on torchvision-0.5
+            #("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
             ("googlenet", qgooglenet(pretrained=True), per_channel),
         ]
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index eed47ea..59f93b4 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -756,7 +756,6 @@ def test_vgg11_bn():
     verify_model("vgg11_bn")
 """
 
-
 def test_custom_conversion_map():
     def get_roi_align():
         pool_size = 5
@@ -801,11 +800,193 @@ def test_segmentaton_models():
 
     inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)]
 
-    for model in [fcn, deeplab]:
-        # depthwise + dilated covolution not supported on x86
-        # see https://github.com/apache/incubator-tvm/issues/4962
-        verify_model(SegmentationModelWrapper(model.eval()), inp,
-                     ctx_list=[("cuda", tvm.gpu(0))])
+    verify_model(SegmentationModelWrapper(fcn.eval()), inp)
+
+    # depthwise + dilated convolution not supported on x86
+    # see https://github.com/apache/incubator-tvm/issues/4962
+    cuda_ctx = ("cuda", tvm.gpu(0))
+    if cuda_ctx[1].exist:
+        verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx])
+
+
+def verify_script_model(pt_model, ishapes):
+    script_module = torch.jit.script(pt_model)
+    input_names = get_graph_input_names(script_module)
+    input_shapes = dict(zip(input_names, ishapes))
+
+    inputs = [torch.randn(input_shapes[input_name], dtype=torch.float)
+              for input_name in input_names]
+
+    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+
+    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
+                                     target="llvm")
+    evaluator = executor.evaluate()
+
+    for name, inp in zip(input_names, inputs):
+        params[name] = inp.numpy()
+
+    op_res = evaluator(**params)
+
+    with torch.no_grad():
+        pt_result = pt_model(*inputs)
+
+    if not isinstance(pt_result, torch.Tensor):
+        tvm_res = op_res.asnumpy().item()
+        assert pt_result == tvm_res
+    else:
+        tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(),
+                                    rtol=1e-5, atol=1e-5)
+
+
+def test_control_flow():
+    class SimpleIf(torch.nn.Module):
+        def __init__(self, N, M):
+            super().__init__()
+            self.weight = torch.nn.Parameter(torch.rand(N, M))
+
+        def forward(self, inp):
+            if inp.sum() > 0.:
+                output = self.weight + inp
+            else:
+                output = self.weight - inp
+            return output
+
+    class NestedIf(torch.nn.Module):
+        def __init__(self, N, M):
+            super().__init__()
+            self.weight = torch.nn.Parameter(torch.rand(N, M))
+
+        def forward(self, inp):
+            if inp.sum() > 0.:
+                if inp.mean() > 0.:
+                    output = self.weight + inp
+                else:
+                    output = self.weight - inp
+            else:
+                if inp.mean() >= 0.:
+                    output = self.weight * inp
+                else:
+                    output = self.weight / inp
+
+            return output
+
+    class ScalarLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = 0
+            for i in range(inp.size(0)):
+                b = i * i
+                b = b + 1
+                a += b
+            if a != 0:
+                a += 1
+            else:
+                a += 2
+            return a
+
+    class SimpleLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = inp
+            for i in range(inp.size(0)):
+                b = a * 2.
+                c = a + b
+                a += c
+            return a
+
+    class LoopWithIf(torch.nn.Module):
+        def forward(self, inp):
+            a = inp
+            for i in range(inp.size(0)):
+                b = a * 2.
+                b = a + b
+                if b.sum() > 0.0:
+                    a += b
+                else:
+                    a -= b
+            return a
+
+    class NestedLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = inp
+            for i in range(inp.size(0)):
+                b = a * float(i)
+                for j in range(inp.size(1)):
+                    a += b * float(j)
+            return a
+
+    class SimpleScalarWhileLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = 1
+            i = 0
+            while i <= inp.size(0):
+                a += i
+                i += 2
+            i = 0
+            # also test constant init cond
+            while i < 10:
+                a += i
+                i += 3
+            return a
+
+    class SimpleWhileLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = inp
+            i = 0
+            while i < inp.size(0):
+                a += a * float(i) * 2.0
+                i += 1
+            return a
+
+    models = [
+        SimpleIf(10, 20),
+        NestedIf(10, 20),
+        ScalarLoop(),
+        SimpleLoop(),
+        LoopWithIf(),
+        SimpleScalarWhileLoop(),
+        SimpleWhileLoop(),
+        NestedLoop(),
+    ]
+
+    for pt_model in models:
+        verify_script_model(pt_model.eval(), [(10, 20)])
+
+
+def test_simple_rnn():
+    # The mixed tracing and scripting example from
+    # https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html#mixing-scripting-and-tracing
+    class DecisionGate(torch.nn.Module):
+        def forward(self, x):
+            if x.sum() > 0:
+                return x
+            else:
+                return -x
+
+    class Cell(torch.nn.Module):
+        def __init__(self, dg):
+            super(Cell, self).__init__()
+            self.dg = dg
+            self.linear = torch.nn.Linear(4, 4)
+
+        def forward(self, x, h):
+            new_h = torch.tanh(self.dg(self.linear(x)) + h)
+            return new_h, new_h
+
+    class RNNLoop(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            x = torch.rand(10, 4, dtype=torch.float)
+            h = torch.rand(10, 4, dtype=torch.float)
+            self.cell = torch.jit.trace(Cell(DecisionGate()), (x, h))
+
+        def forward(self, xs):
+            h = torch.zeros(10, 4, dtype=torch.float)
+            y = torch.zeros(10, 4, dtype=torch.float)
+            for i in range(xs.size(0)):
+                y, h = self.cell(xs[i], h)
+            return y
+
+    verify_script_model(RNNLoop().eval(), [(10, 10, 4)])
 
 
 if __name__ == "__main__":
@@ -860,3 +1041,7 @@ if __name__ == "__main__":
 
     test_quantized_modules()
     test_quantized_imagenet()
+
+    # Test simple conditionals and loop
+    test_control_flow()
+    test_simple_rnn()