Posted to commits@tvm.apache.org by ke...@apache.org on 2020/04/13 06:12:07 UTC

[incubator-tvm] branch master updated: [Torch] Support Python list, more realistic recurrent networks (#5306)

This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 0145cd5  [Torch] Support Python list, more realistic recurrent networks (#5306)
0145cd5 is described below

commit 0145cd504585e25b776bef83688d10ff0ca44082
Author: masahi <ma...@gmail.com>
AuthorDate: Mon Apr 13 15:11:57 2020 +0900

    [Torch] Support Python list, more realistic recurrent networks (#5306)
    
    * use funcs from prelude, pass around convert_map
    
    * get relay input type from user ishape
    
    * handle tuple unpack
    
    * experimenting with static tensor array
    
    * use prelude concat instead of cons + rev
    
    * minor clean up
    
    * fix layer norm conversion bug, unwrap tensor array
    
    * add infer shape on tensor array
    
    * pass around prelude for now
    
    * compile worked but runtime error
    
    * fix tensor array wrapping
    
    * begin list dynamic test
    
    * is_list_dynamic first version
    
    * finish dynamic list test
    
    * a few fixes
    
    * use shape_of function if Any is found
    
    * improve size conversion
    
    * working on adding free vars to loop block
    
    * fixed inlined inner loop issue
    
    * clean up free var handling
    
    * add support for tensor array concat
    
    * adding ta concat on last axis
    
    * fix concat, but got runtime error
    
    * disable concat on axis -1 for now
    
    * add lstm tests
    
    * revert unrelated change
    
    * fix stacked bidir test
    
    * minor fix to test
    
    * relax tol a bit, revert dnnl change to avoid conflict
    
    * simplify infer type, use input tensor shape rather than concat shape
    
    * more shape fixes
---
 python/tvm/relay/frontend/pytorch.py          | 618 ++++++++++++++++++--------
 tests/python/frontend/pytorch/lstm_test.py    | 335 ++++++++++++++
 tests/python/frontend/pytorch/test_forward.py |  33 +-
 3 files changed, 787 insertions(+), 199 deletions(-)
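
For context, here is a minimal usage sketch of the converter interface this change targets, modeled on the new lstm_test.py below. The Cell module, the "input"/"states" names, and the shapes are illustrative and not part of the commit; the point is that from_pytorch now accepts nested tuple shapes (mapped to Relay TupleType) and returns an IRModule that can be run on the Relay VM.

    from typing import Tuple

    import torch
    from torch import Tensor

    import tvm
    from tvm import relay
    from tvm.relay.frontend.pytorch import from_pytorch


    class Cell(torch.jit.ScriptModule):
        # A toy recurrent-style cell; any jit.ScriptModule works.
        @torch.jit.script_method
        def forward(self, x, state):
            # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor
            h, c = state
            return torch.tanh(x + h * c)


    batch, hidden = 2, 4
    inp = torch.randn(batch, hidden)
    states = (torch.randn(batch, hidden), torch.randn(batch, hidden))

    # Nested tuples in input_shapes become Relay TupleType annotations.
    input_shapes = [("input", (batch, hidden)),
                    ("states", ((batch, hidden), (batch, hidden)))]

    script_module = torch.jit.script(Cell())
    mod, params = from_pytorch(script_module, input_shapes)

    # Run on the Relay VM, as the new lstm_test.py does.
    params["input"] = inp.numpy()
    params["states"] = tuple(st.numpy() for st in states)
    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
    res = executor.evaluate()(**params)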

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index a542ccc..506f6ba 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -25,20 +25,95 @@ import sys
 import numpy as np
 
 import tvm
-from tvm.ir import module as _module
 
 from .. import analysis as _analysis
 from .. import expr as _expr
 from .. import op as _op
+from ..ty import TupleType, TensorType, Any
 from ..loops import while_loop
 from .common import get_relay_op
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
+from .common import infer_type as _infer_type
+from ..prelude import Prelude, StaticTensorArrayOps
 
 from . import qnn_torch
 
 __all__ = ["from_pytorch"]
 
+
+# List ADT utilities
+def _infer_type_with_prelude(val, prelude):
+    body = _infer_type(val, prelude.mod)
+    return body.checked_type
+
+
+def _convert_to_list_adt(py_lst, prelude):
+    elem_tys = [_infer_type_with_prelude(elem, prelude) for elem in py_lst]
+    msg = "List elements should have identical types"
+    assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
+
+    adt_lst = prelude.nil()
+    for elem in reversed(py_lst):
+        adt_lst = prelude.cons(elem, adt_lst)
+    return adt_lst
+
+
+def _map_tensor_array_constructor(adt_lst, prelude, shape):
+    static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
+    static_tensor_array_ops.register()
+    tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape)
+    return prelude.map(tensor_create, adt_lst)
+
+
+def _convert_to_tensor_array(adt_lst, prelude):
+    if prelude.length(adt_lst) == 0:
+        return prelude.nil()
+
+    checked_type = _infer_type_with_prelude(prelude.hd(adt_lst), prelude)
+    shape = checked_type.shape
+    tensor_array = _map_tensor_array_constructor(adt_lst, prelude, shape)
+    return tensor_array, tuple(shape)
+
+
+def _should_construct_dynamic_list(list_construct_node):
+    # if this list is element-accessed or modified at runtime, generate List ADT
+    def is_used_by_list_add(uses):
+        for use in uses:
+            op_name = use.user.kind()
+            output_type = _get_node_type(use.user)
+            if op_name in ["aten::add", "aten::add_"] and output_type == "ListType":
+                return True
+        return False
+
+    def inplace_add_to_add(op_name):
+        if op_name == "aten::add_":
+            return "aten::add"
+        else:
+            return op_name
+
+    uses = _get_uses(list_construct_node)
+
+    for loop_use in filter(lambda use: use.user.kind() == "prim::Loop", uses):
+        block_input_index = loop_use.offset - 1
+        block = list(loop_use.user.blocks())[0]
+        list_loop_var = list(block.inputs())[block_input_index]
+        uses += _get_uses(list_loop_var.node())
+
+    op_names = map(inplace_add_to_add, set(use.user.kind() for use in uses))
+
+    list_ops = set(["aten::add", "aten::__getitem__", "aten::stack"])
+    intersect = list_ops.intersection(op_names)
+
+    if len(intersect) > 0 and intersect != set(["aten::add"]):
+        return True
+
+    if is_used_by_list_add(filter(lambda use: use.user.kind() != "prim::Loop", uses)):
+        return True
+
+    return False
+
+
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
@@ -103,11 +178,27 @@ def _unsqueeze():
         return _op.transform.expand_dims(data, int(axis), 1)
     return _impl
 
-def _concatenate():
+
+def _concatenate(prelude):
+    def tensor_array_concat(lst, axis):
+        assert axis == 0, "Tensor array concat supported only for axis 0"
+        tensor_array, shape = _convert_to_tensor_array(lst, prelude)
+        concat_shape = (Any(),) + shape[1:]
+        static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
+        static_tensor_array_ops.define_tensor_get_data(concat_shape)
+
+        concat = prelude.get_var_static('tensor_array_concat', "float32", shape)
+        concatenated = concat(tensor_array)
+        get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)
+        return get_tensor(concatenated)
+
     def _impl(inputs, input_types):
         data = inputs[0]
         axis = inputs[1]
 
+        if not isinstance(data, list):
+            return tensor_array_concat(data, axis)
+
         if isinstance(data, _expr.Expr):
             data = [data]
 
@@ -130,7 +221,7 @@ def _slice():
         else:
             end = data.shape
 
-        begin = [0]*len(end)
+        begin = [0] * len(end)
         dim = int(inputs[1])
         begin[dim] = int(inputs[2])
 
@@ -371,7 +462,7 @@ def _maxpool_2d():
         ceil_mode = int(inputs[5])
 
         if dilation != (1, 1):
-            msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation), )
+            msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation))
             raise NotImplementedError(msg)
 
         return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode)
@@ -388,7 +479,7 @@ def _maxpool_1d():
         ceil_mode = int(inputs[5])
 
         if dilation != (1,):
-            msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation), )
+            msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation))
             raise NotImplementedError(msg)
 
         return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode)
@@ -404,7 +495,7 @@ def _maxpool_3d():
         dilation = _infer_shape(inputs[4])
         ceil_mode = int(inputs[5])
         if dilation != (1, 1, 1):
-            msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation), )
+            msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation))
             raise NotImplementedError(msg)
 
         return _op.nn.max_pool3d(data,
@@ -618,13 +709,13 @@ def _layer_norm():
                                  scale=True)
     return _impl
 
-def _transpose():
+def _transpose(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
 
         import torch
         if isinstance(data, _expr.Expr):
-            ndims = len(_infer_shape(data))
+            ndims = len(_infer_shape(data, prelude.mod))
         elif isinstance(data, list):
             ndims = data
         elif isinstance(data, (torch.Tensor, np.ndarray)):
@@ -693,15 +784,30 @@ def _dense():
             return dense_out
     return _impl
 
-def _size():
+
+def _size(prelude):
+    def _impl_dynamic(inp, axis):
+        shape_dynamic = _op.shape_of(inp)
+        if axis is not None:
+            return _op.take(shape_dynamic, _expr.const(axis), 0)
+        return shape_dynamic
+
     def _impl(inputs, input_types):
-        shape = _infer_shape(inputs[0])
+        shape = _infer_shape(inputs[0], prelude.mod)
+        axis = None
         if len(inputs) > 1:
             axis = int(inputs[1])
+
+        if any(map(lambda s: isinstance(s, tvm.tir.expr.Any), shape)):
+            if axis is None or isinstance(shape[axis], tvm.tir.expr.Any):
+                return _impl_dynamic(inputs[0], axis)
+
+        if axis is not None:
             return shape[axis]
         return shape
     return _impl
 
+
 def _numtotensor():
     def _impl(inputs, input_types):
         val = inputs[0]
@@ -862,7 +968,7 @@ def _mean():
 
     return _impl
 
-def _chunk():
+def _chunk(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
 
@@ -870,7 +976,7 @@ def _chunk():
         axis = int(inputs[2])
 
         if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
+            inferred_shape = _infer_shape(data, prelude.mod)
 
         shape = []
         for infer in inferred_shape:
@@ -894,7 +1000,6 @@ def _chunk():
             chunk_out = _op.transform.strided_slice(data, begin, end, stride)
             chunks.append(chunk_out)
 
-
         if dim % num_chunks:
             begin = [0] * len(shape)
             end = shape[:]
@@ -1077,6 +1182,49 @@ def _Float():
         return _op.cast(inputs[0], "float32")
     return _impl
 
+
+def _mm():
+    def _impl(inputs, input_types):
+        return _op.nn.dense(inputs[0], inputs[1])
+    return _impl
+
+
+def _list_getitem(prelude):
+    def _impl(inputs, input_types):
+        return prelude.nth(inputs[0], _wrap_const(inputs[1]))
+    return _impl
+
+
+def _list_len(prelude):
+    def _impl(inputs, input_types):
+        return prelude.length(inputs[0])
+    return _impl
+
+
+def _add(prelude):
+    # add_ is overloaded for tensor add and list concat
+    def _impl(inputs, input_types):
+        if input_types[0] == "ListType":
+            return prelude.concat(inputs[0], inputs[1])
+        return _elemwise("add")(inputs, input_types)
+    return _impl
+
+
+def _tensor_array_stack(prelude):
+    def _impl(inputs, input_types):
+        tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude)
+        stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
+        stacked = stack(tensor_array)
+
+        stacked_shape = (Any(),) + shape
+        static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
+        static_tensor_array_ops.define_tensor_get_data(stacked_shape)
+        # passing stacked_shape below gives "'Prelude' object has no attribute" error
+        get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)
+        return get_tensor(stacked)
+    return _impl
+
+
 # Helper functions for operator implementation
 def _convert_dtype_value(val):
     convert_torch_dtype_map = {7:"torch.float64",
@@ -1148,112 +1296,117 @@ def _convert_elemwise_input(data, input_type):
         return data
 
 def _wrap_const(c):
-    if not isinstance(c, _expr.Expr) and not isinstance(c, list):
+    if not isinstance(c, (_expr.Expr, list, tvm.tir.expr.Any)):
         return _expr.const(c)
     return c
 
 # Operator mappings
-
-_convert_map = {
-    "aten::device"                          : _none(),
-    "aten::add"                             : _elemwise("add"),
-    "aten::add_"                            : _elemwise("add"),
-    "aten::sub"                             : _elemwise("subtract"),
-    "aten::sub_"                            : _elemwise("subtract"),
-    "aten::max"                             : _elemwise("maximum"),
-    "aten::min"                             : _elemwise("minimum"),
-    "aten::mul"                             : _elemwise("multiply"),
-    "aten::mul_"                            : _elemwise("multiply"),
-    "aten::pow"                             : _elemwise("power"),
-    "aten::div"                             : _elemwise("divide"),
-    "aten::div_"                            : _elemwise("divide"),
-    "aten::abs"                             : _abs(),
-    "aten::arange"                          : _arange(),
-    "aten::ones"                            : _ones(),
-    "aten::zeros"                           : _zeros(),
-    "aten::reciprocal"                      : _reciprocal(),
-    "aten::repeat"                          : _repeat(),
-    "aten::repeat_interleave"               : _repeat_interleave(),
-    "aten::to"                              : _to(),
-    "aten::squeeze"                         : _squeeze(),
-    "aten::unsqueeze"                       : _unsqueeze(),
-    "aten::cat"                             : _concatenate(),
-    "aten::slice"                           : _slice(),
-    "aten::split"                           : _split(),
-    "aten::split_with_sizes"                : _split_with_sizes(),
-    "aten::select"                          : _select(),
-    "aten::relu"                            : _relu(),
-    "aten::relu_"                           : _relu(),
-    "aten::prelu"                           : _prelu(),
-    "aten::leaky_relu"                      : _leaky_relu(),
-    "aten::elu"                             : _elu(),
-    "aten::celu"                            : _celu(),
-    "aten::gelu"                            : _gelu(),
-    "aten::selu"                            : _selu(),
-    "aten::log_sigmoid"                     : _log_sigmoid(),
-    "aten::adaptive_avg_pool2d"             : _adaptive_avg_pool_2d(),
-    "aten::adaptive_max_pool2d"             : _adaptive_max_pool_2d(),
-    "aten::max_pool2d"                      : _maxpool_2d(),
-    "aten::max_pool2d_with_indices"         : _maxpool_2d(),
-    "aten::max_pool1d"                      : _maxpool_1d(),
-    "aten::max_pool3d"                      : _maxpool_3d(),
-    "aten::hardtanh"                        : _hardtanh(),
-    "aten::hardtanh_"                       : _hardtanh(),
-    "aten::_convolution"                    : _convolution(),
-    "aten::softmax"                         : _softmax(),
-    "aten::threshold"                       : _threshold(),
-    "aten::threshold_"                      : _threshold(),
-    "aten::contiguous"                      : _contiguous(),
-    "aten::batch_norm"                      : _batch_norm(),
-    "aten::instance_norm"                   : _instance_norm(),
-    "aten::layer_norm"                      : _layer_norm(),
-    "aten::transpose"                       : _transpose(),
-    "aten::transpose_"                      : _transpose(),
-    "aten::t"                               : _transpose(),
-    "aten::flatten"                         : _flatten(),
-    "aten::addmm"                           : _dense(),
-    "aten::size"                            : _size(),
-    "aten::view"                            : _view(),
-    "aten::reshape"                         : _reshape(),
-    "aten::clone"                           : _clone(),
-    "aten::log_softmax"                     : _log_softmax(),
-    "aten::sigmoid"                         : _sigmoid(),
-    "aten::softplus"                        : _softplus(),
-    "aten::avg_pool2d"                      : _avg_pool2d(),
-    "aten::avg_pool3d"                      : _avg_pool3d(),
-    "aten::dropout"                         : _dropout(),
-    "aten::dropout_"                        : _dropout(),
-    "aten::feature_dropout"                 : _dropout(),
-    "aten::alpha_dropout"                   : _dropout(),
-    "aten::mean"                            : _mean(),
-    "aten::chunk"                           : _chunk(),
-    "aten::matmul"                          : _matmul(),
-    "aten::expand"                          : _expand(),
-    "aten::Int"                             : _int(),
-    "prim::NumToTensor"                     : _numtotensor(),
-    "prim::ListUnpack"                      : _identity(),
-    "aten::constant_pad_nd"                 : _pad(),
-    "aten::permute"                         : _transpose(),
-    "aten::sum"                             : _reduce("sum"),
-    "aten::prod"                            : _reduce("prod"),
-    "aten::sqrt"                            : _sqrt(),
-    'aten::floor'                           : _floor(),
-    "aten::detach"                          : _identity(),
-    "aten::upsample_bilinear2d"             : _upsample("bilinear"),
-    "aten::upsample_nearest2d"              : _upsample("nearest_neighbor"),
-    "aten::expand_as"                       : _expand_as(),
-    "aten::lt"                              : _elemwise("less"),
-    "aten::gt"                              : _elemwise("greater"),
-    "aten::le"                              : _elemwise("less_equal"),
-    "aten::ge"                              : _elemwise("greater_equal"),
-    "aten::ne"                              : _elemwise("not_equal"),
-    "aten::Bool"                            : _Bool(),
-    "aten::Float"                           : _Float(),
-    "aten::neg"                             : _neg(),
-    "aten::tanh"                            : _tanh(),
-    "aten::adaptive_avg_pool3d"             : _adaptive_avg_pool_3d(),
-    "aten::adaptive_max_pool3d"             : _adaptive_max_pool_3d()
-}
+def _get_convert_map(prelude):
+    convert_map = {
+        "aten::device"                          : _none(),
+        "aten::sub"                             : _elemwise("subtract"),
+        "aten::sub_"                            : _elemwise("subtract"),
+        "aten::max"                             : _elemwise("maximum"),
+        "aten::min"                             : _elemwise("minimum"),
+        "aten::mul"                             : _elemwise("multiply"),
+        "aten::mul_"                            : _elemwise("multiply"),
+        "aten::pow"                             : _elemwise("power"),
+        "aten::abs"                             : _abs(),
+        "aten::arange"                          : _arange(),
+        "aten::div"                             : _elemwise("divide"),
+        "aten::div_"                            : _elemwise("divide"),
+        "aten::ones"                            : _ones(),
+        "aten::zeros"                           : _zeros(),
+        "aten::reciprocal"                      : _reciprocal(),
+        "aten::repeat"                          : _repeat(),
+        "aten::repeat_interleave"               : _repeat_interleave(),
+        "aten::to"                              : _to(),
+        "aten::squeeze"                         : _squeeze(),
+        "aten::unsqueeze"                       : _unsqueeze(),
+        "aten::cat"                             : _concatenate(prelude),
+        "aten::slice"                           : _slice(),
+        "aten::split"                           : _split(),
+        "aten::split_with_sizes"                : _split_with_sizes(),
+        "aten::select"                          : _select(),
+        "aten::relu"                            : _relu(),
+        "aten::relu_"                           : _relu(),
+        "aten::prelu"                           : _prelu(),
+        "aten::leaky_relu"                      : _leaky_relu(),
+        "aten::elu"                             : _elu(),
+        "aten::celu"                            : _celu(),
+        "aten::gelu"                            : _gelu(),
+        "aten::selu"                            : _selu(),
+        "aten::log_sigmoid"                     : _log_sigmoid(),
+        "aten::adaptive_avg_pool2d"             : _adaptive_avg_pool_2d(),
+        "aten::adaptive_max_pool2d"             : _adaptive_max_pool_2d(),
+        "aten::max_pool2d"                      : _maxpool_2d(),
+        "aten::max_pool2d_with_indices"         : _maxpool_2d(),
+        "aten::max_pool1d"                      : _maxpool_1d(),
+        "aten::max_pool3d"                      : _maxpool_3d(),
+        "aten::hardtanh"                        : _hardtanh(),
+        "aten::hardtanh_"                       : _hardtanh(),
+        "aten::_convolution"                    : _convolution(),
+        "aten::softmax"                         : _softmax(),
+        "aten::threshold"                       : _threshold(),
+        "aten::threshold_"                      : _threshold(),
+        "aten::contiguous"                      : _contiguous(),
+        "aten::batch_norm"                      : _batch_norm(),
+        "aten::instance_norm"                   : _instance_norm(),
+        "aten::layer_norm"                      : _layer_norm(),
+        "aten::transpose"                       : _transpose(prelude),
+        "aten::transpose_"                      : _transpose(prelude),
+        "aten::t"                               : _transpose(prelude),
+        "aten::flatten"                         : _flatten(),
+        "aten::addmm"                           : _dense(),
+        "aten::size"                            : _size(prelude),
+        "aten::view"                            : _view(),
+        "aten::reshape"                         : _reshape(),
+        "aten::clone"                           : _clone(),
+        "aten::log_softmax"                     : _log_softmax(),
+        "aten::sigmoid"                         : _sigmoid(),
+        "aten::softplus"                        : _softplus(),
+        "aten::avg_pool2d"                      : _avg_pool2d(),
+        "aten::avg_pool3d"                      : _avg_pool3d(),
+        "aten::dropout"                         : _dropout(),
+        "aten::dropout_"                        : _dropout(),
+        "aten::feature_dropout"                 : _dropout(),
+        "aten::alpha_dropout"                   : _dropout(),
+        "aten::mean"                            : _mean(),
+        "aten::chunk"                           : _chunk(prelude),
+        "aten::matmul"                          : _matmul(),
+        "aten::expand"                          : _expand(),
+        "aten::Int"                             : _int(),
+        "prim::NumToTensor"                     : _numtotensor(),
+        "aten::constant_pad_nd"                 : _pad(),
+        "aten::permute"                         : _transpose(prelude),
+        "aten::sum"                             : _reduce("sum"),
+        "aten::prod"                            : _reduce("prod"),
+        "aten::sqrt"                            : _sqrt(),
+        'aten::floor'                           : _floor(),
+        "aten::detach"                          : _identity(),
+        "aten::upsample_bilinear2d"             : _upsample("bilinear"),
+        "aten::upsample_nearest2d"              : _upsample("nearest_neighbor"),
+        "aten::expand_as"                       : _expand_as(),
+        "aten::lt"                              : _elemwise("less"),
+        "aten::gt"                              : _elemwise("greater"),
+        "aten::le"                              : _elemwise("less_equal"),
+        "aten::ge"                              : _elemwise("greater_equal"),
+        "aten::ne"                              : _elemwise("not_equal"),
+        "aten::Bool"                            : _Bool(),
+        "aten::Float"                           : _Float(),
+        "aten::neg"                             : _neg(),
+        "aten::tanh"                            : _tanh(),
+        "aten::adaptive_avg_pool3d"             : _adaptive_avg_pool_3d(),
+        "aten::adaptive_max_pool3d"             : _adaptive_max_pool_3d(),
+        "aten::mm"                              : _matmul(),
+        "relay::tensor_array_stack"             : _tensor_array_stack(prelude),
+        "aten::add"                             : _add(prelude),
+        "aten::add_"                            : _add(prelude),
+        "aten::stack"                           : _tensor_array_stack(prelude),
+        "aten::__getitem__"                     : _list_getitem(prelude),
+        "aten::len"                             : _list_len(prelude),
+    }
+    return convert_map
 
 
 def _run_jit_passes(graph):
@@ -1289,13 +1442,29 @@ def _get_op_inputs(op_node, outputs):
     return [outputs[name] for name in _get_input_names(op_node)]
 
 
-def _report_missing_conversion(op_names):
+def _get_node_type(node):
+    assert node.outputsSize() == 1
+    return node.output().type().kind()
+
+
+def _get_uses(node):
+    uses = []
+    for output in node.outputs():
+        uses += output.uses()
+    return uses
+
+
+def _get_users(node):
+    return [use.user for use in _get_uses(node)]
+
+
+def _report_missing_conversion(op_names, convert_map):
     """ Check if all ops in an input graph are supported by TVM """
     known_ops = ["prim::Constant", "prim::GetAttr",
                  "prim::ListConstruct", "prim::ListUnpack",
                  "prim::TupleConstruct", "prim::TupleUnpack",
                  "prim::If", "prim::Loop"]
-    known_ops += list(_convert_map.keys())
+    known_ops += list(convert_map.keys())
     known_ops += list(qnn_torch.convert_map.keys())
 
     missing = [op_name for op_name in op_names
@@ -1361,7 +1530,7 @@ def _get_input_types(op_node):
                 input_list_types.append(in_ty.scalarType().lower())
 
         elif input_node_kind == 'ListType':
-            input_list_types.append(str(in_ty.getElementType()).lower())
+            input_list_types.append("ListType")
         elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
                                  'StringType', 'OptionalType']:
             input_list_types.append(str(in_ty).lower())
@@ -1422,21 +1591,69 @@ def _get_operator_nodes(nodes):
     return ops
 
 
-def _get_relay_input_vars(graph, input_shapes):
+def _get_graph_input_names(graph):
+    """ Get the graph input names (use after graph copy and run jit passes) """
+    # Variable names could change the first time a copy is made and after
+    # _run_jit_passes is called, expected that those functions already invoked
+    ir_inputs = _get_input_names(graph)
+    return ir_inputs[1:]  # remove self at the 0th arg
+
+
+def _get_relay_input_vars(graph, input_shapes, prelude):
     """
     Return Relay vars from input shapes and create entries based on
     expected graph inputs - to allow translation
     """
+    def get_relay_ty(ishape):
+        if _is_int_seq(ishape) or len(ishape) == 0:
+            return TensorType(ishape)
+        elif isinstance(ishape, tuple):
+            return TupleType([get_relay_ty(elem) for elem in ishape])
+        elif isinstance(ishape, list):
+            assert len(ishape) > 0
+            elem_tys = [get_relay_ty(s) for s in ishape]
+            msg = "List elements should have identical types"
+            assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
+            return prelude.l(elem_tys[0])
+        raise NotImplementedError("unsupported input type")
+
+    input_types = [(tup[0], get_relay_ty(tup[1])) for tup in input_shapes]
     input_vars = {}
     ir_inputs = _get_graph_input_names(graph)
-    for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
-        inp = _expr.var(name, shape=shape)
+    for ir_input, (name, itype) in zip(ir_inputs, input_types):
+        inp = _expr.var(name, type_annotation=itype)
         # Translate from graph input to user input name
         input_vars[ir_input] = inp
 
     return input_vars
 
 
+def _unpack_tuple(tup):
+    def unpack(tup, num_fields):
+        return [_expr.TupleGetItem(tup, i) for i in range(num_fields)]
+
+    if isinstance(tup, _expr.Tuple):
+        return unpack(tup, len(tup.fields))
+    elif isinstance(tup.type_annotation, TupleType):
+        return unpack(tup, len(tup.type_annotation.fields))
+    # shouldn't happen
+    assert False
+
+
+def _get_free_vars_from_block(block):
+    block_inp_names = _get_input_names(block)
+    bound_names = block_inp_names
+    free_vars = set()
+
+    for node in block.nodes():
+        inp_names = _get_input_names(node)
+        list_diff = [name for name in inp_names if name not in bound_names]
+        free_vars.update(list_diff)
+        bound_names += _get_output_names(node)
+
+    return free_vars
+
+
 def get_use_chains(root_node, terminate=lambda _: False):
     """
     Track a chain of users of this node forward, returning a list of chains
@@ -1446,9 +1663,7 @@ def get_use_chains(root_node, terminate=lambda _: False):
         return itertools.chain.from_iterable(lists)
 
     def inner(current, accum):
-        users = []
-        for output in current.outputs():
-            users += [use.user for use in output.uses()]
+        users = _get_users(current)
 
         if not users or terminate(users):
             return [accum]
@@ -1512,24 +1727,24 @@ def convert_params(graph, state_dict):
     return params, param_tensors, packed_param_map
 
 
-def convert_block(block, outputs):
+def convert_block(block, outputs, convert_map, prelude):
     """ Translate Torch "Block", used for prim::If and prim::Loop """
     ops = _get_operator_nodes(block.nodes())
     ret_names = _get_input_names(block.returnNode())
-    return convert_operators(ops, outputs, ret_names)
+    return convert_operators(ops, outputs, ret_names, convert_map, prelude)
 
 
-def convert_if(if_node, outputs):
+def convert_if(if_node, outputs, convert_map, prelude):
     """ Translate Torch prim::If to Relay If """
     cond = outputs[if_node.inputsAt(0).debugName()]
     blocks = list(if_node.blocks())
-    true_branch = convert_block(blocks[0], outputs)
-    false_branch = convert_block(blocks[1], outputs)
+    true_branch = convert_block(blocks[0], outputs, convert_map, prelude)
+    false_branch = convert_block(blocks[1], outputs, convert_map, prelude)
     assert len(true_branch) == 1 and len(false_branch) == 1
     return _expr.If(cond, true_branch[0], false_branch[0])
 
 
-def convert_loop(loop_node, outputs):
+def convert_loop(loop_node, outputs, convert_map, prelude):
     """ Translate Torch prim::Loop to Relay while_loop """
     def get_input(index):
         ivalue = loop_node.inputsAt(index)
@@ -1555,8 +1770,54 @@ def convert_loop(loop_node, outputs):
     is_while_loop = (isinstance(max_loop_count, _expr.Constant) and
                      _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize)
 
+    if is_while_loop:
+        loop_iter_dtype = "bool"
+        # while loop with non input dependent condition such as while i < 10:
+        # init_cond is int, need to cast to bool to type check
+        if isinstance(init_cond, _expr.Constant):
+            init_cond = _op.cast(init_cond, "bool")
+        init_loop_iter_val = init_cond
+    else:
+        loop_iter_dtype = "int32"
+        # always count from 0
+        init_loop_iter_val = _expr.const(0, dtype="int32")
+
     body_block = list(loop_node.blocks())[0]
     block_input_names = _get_input_names(body_block)
+    num_block_inputs = len(block_input_names)
+    name_val_pairs = list(zip(block_input_names,
+                              [init_loop_iter_val] + init_vals))
+    outputs.update(name_val_pairs)
+
+    def get_var(name, val):
+        if val:
+            checked_type = _infer_type_with_prelude(val, prelude)
+            return _expr.var(name, type_annotation=checked_type)
+        return _expr.var(name)
+
+    loop_iter_var = _expr.var(block_input_names[0], shape=(),
+                              dtype=loop_iter_dtype)
+    loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
+
+    # Add non-constant free variables to the loop variables to prevent code blow-up.
+    # Without this, if there are two for loops in a row, which often happens
+    # when the outer loop is unrolled, the computation corresponding to the first loop
+    # is inlined inside the body of the second, turning O(N) + O(N) computation into O(N^2).
+    # This issue was found when converting the Stacked LSTM test: Torch does not add the output
+    # of the earlier loop to the loop variables of the next loop,
+    # so the variable corresponding to the first loop output appears free in the second loop body.
+    free_vars = [var for var in _get_free_vars_from_block(body_block)
+                 if var in outputs and not isinstance(outputs[var], (_expr.Constant, int, float))
+                 and outputs[var]]
+
+    prev_outputs = {}
+    for name in free_vars:
+        prev_output = outputs[name]
+        new_loop_var = get_var(name, prev_output)
+        prev_outputs[name] = prev_output
+        outputs[name] = new_loop_var
+        loop_vars.append(new_loop_var)
+        init_vals.append(prev_output)
 
     def cond(*current_vals):
         i = current_vals[0]
@@ -1568,11 +1829,16 @@ def convert_loop(loop_node, outputs):
 
     def body(*current_vals):
         # Update loop variables using the prev iteration outputs
-        assert len(current_vals) == len(block_input_names)
-        for (i, iname) in enumerate(block_input_names):
-            outputs[iname] = current_vals[i]
+        assert len(current_vals) == num_block_inputs + len(free_vars)
 
-        block_outputs = convert_block(body_block, outputs)
+        for (i, val) in enumerate(current_vals):
+            if i < num_block_inputs:
+                outputs[block_input_names[i]] = val
+            else:
+                outputs[free_vars[i-num_block_inputs]] = val
+
+        block_outputs = convert_block(body_block, outputs, convert_map, prelude)
+        block_outputs += [outputs[name] for name in free_vars]
 
         if not is_while_loop:
             # iter var increment implicit in torch, so do it manually
@@ -1583,38 +1849,17 @@ def convert_loop(loop_node, outputs):
 
         return block_outputs
 
-    def get_var(name, val):
-        if isinstance(val, _expr.Constant):
-            return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype)
-        return _expr.var(name)
-
-    if is_while_loop:
-        loop_iter_dtype = "bool"
-        # while loop with non input dependent condition such as while i < 10:
-        # init_cond is int, need to cast to bool to type check
-        if isinstance(init_cond, _expr.Constant):
-            init_cond = _op.cast(init_cond, "bool")
-        init_loop_iter_val = init_cond
-    else:
-        loop_iter_dtype = "int32"
-        # always count from 0
-        init_loop_iter_val = _expr.const(0, dtype="int32")
-
-    name_val_pairs = list(zip(block_input_names,
-                              [init_loop_iter_val] + init_vals))
-    outputs.update(name_val_pairs)
-
-    loop_iter_var = _expr.var(block_input_names[0], shape=(),
-                              dtype=loop_iter_dtype)
-    loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
     loop = while_loop(cond, [loop_iter_var] + loop_vars, body)
     loop_val = loop(init_loop_iter_val, *init_vals)
 
+    # restore original output values for free vars
+    outputs.update(prev_outputs)
+
     # The first element is a loop counter or boolean condition, ignore it
     return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
 
 
-def convert_operators(operators, outputs, ret_names):
+def convert_operators(operators, outputs, ret_names, convert_map, prelude):
     """ Convert each Torch IR operators to Relay equivalent """
     for node_name, op_node in operators:
         operator = op_node.kind()
@@ -1622,24 +1867,33 @@ def convert_operators(operators, outputs, ret_names):
 
         if operator == "prim::Constant":
             outputs[node_name] = _get_constant(op_node)
-        elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
+        elif operator == "prim::ListConstruct" and _is_int_seq(inputs):
             outputs[node_name] = _expr.var(node_name, shape=inputs)
-        elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
+        elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node):
+            outputs[node_name] = _convert_to_list_adt(inputs, prelude)
+        elif operator == "prim::ListConstruct":
+            # This assumes that no more elements will be appended to this list
+            # In this case, we keep the Python list
             outputs[node_name] = inputs
-        elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
+        elif operator == "prim::TupleConstruct":
+            outputs[node_name] = _expr.Tuple(inputs)
+        elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]:
             assert len(inputs) == 1
-            unpacked_names = _get_output_names(op_node)
-            outputs.update(zip(unpacked_names, inputs[0]))
+            if isinstance(inputs[0], (list, _expr.TupleWrapper)):
+                unpacked = inputs[0]
+            else:
+                unpacked = _unpack_tuple(inputs[0])
+            outputs.update(zip(_get_output_names(op_node), unpacked))
         elif operator == "prim::If":
-            if_out = convert_if(op_node, outputs)
+            if_out = convert_if(op_node, outputs, convert_map, prelude)
             outputs[node_name] = if_out
         elif operator == "prim::Loop":
-            loop_out = convert_loop(op_node, outputs)
+            loop_out = convert_loop(op_node, outputs, convert_map, prelude)
             unpacked_names = _get_output_names(op_node)
             assert len(loop_out) == len(unpacked_names)
             outputs.update(zip(unpacked_names, loop_out))
         else:
-            relay_op = _convert_map[operator]
+            relay_op = convert_map[operator]
             relay_out = relay_op(inputs, _get_input_types(op_node))
 
             if isinstance(relay_out, tuple):
@@ -1666,14 +1920,6 @@ def get_all_op_names(graph):
     return set(node.kind() for node in nodes)
 
 
-def _get_graph_input_names(graph):
-    """ Get the graph input names (use after graph copy and run jit passes) """
-    # Variable names could change the first time a copy is made and after
-    # _run_jit_passes is called, expected that those functions already invoked
-    ir_inputs = _get_input_names(graph)
-    return ir_inputs[1:]  # remove self at the 0th arg
-
-
 def from_pytorch(script_module, input_shapes, custom_convert_map=None):
     """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
@@ -1700,18 +1946,23 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
     params : dict of str to tvm.runtime.NDArray
         Dict of converted parameters stored in tvm.runtime.ndarray format
     """
+    mod = tvm.IRModule()
+    prelude = Prelude(mod)
+
+    convert_map = _get_convert_map(prelude)
+
     graph = script_module.graph.copy()
     _run_jit_passes(graph)
 
     if custom_convert_map:
-        _convert_map.update(custom_convert_map)
+        convert_map.update(custom_convert_map)
 
     op_names = get_all_op_names(graph)
-    _report_missing_conversion(op_names)
+    _report_missing_conversion(op_names, convert_map)
     _check_inputs(graph, input_shapes)
 
     params = script_module.state_dict()
-    outputs = _get_relay_input_vars(graph, input_shapes)
+    outputs = _get_relay_input_vars(graph, input_shapes, prelude)
     param_vars, tensors, packed_param_map = convert_params(graph, params)
     tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
 
@@ -1726,14 +1977,11 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
                                               packed_param_map,
                                               weight_quant_params)
         qnn_torch.add_quant_params(tvm_params, weight_quant_params)
-        _convert_map.update(qnn_torch.convert_map)
+        convert_map.update(qnn_torch.convert_map)
 
     ret = convert_operators(_get_operator_nodes(graph.nodes()),
-                            outputs, ret_name)
-
-    if isinstance(ret[0], list):
-        ret[0] = _expr.Tuple(ret[0])
+                            outputs, ret_name, convert_map, prelude)
 
-    func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
+    mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
 
-    return _module.IRModule.from_expr(func), tvm_params
+    return mod, tvm_params
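
As an illustration of the new Python-list support above, a minimal sketch (not part of the commit; the Accumulate module is hypothetical) of the TorchScript pattern it handles: the list is appended to inside a prim::Loop and later consumed by aten::stack, so _should_construct_dynamic_list selects the Relay List ADT path and the stack is lowered through the static tensor array ops.

    from typing import List

    import torch
    from torch import Tensor


    class Accumulate(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, xs):
            # type: (Tensor) -> Tensor
            # The list is grown at runtime and stacked afterwards, so the
            # converter builds a Relay List ADT rather than keeping a static
            # Python list.
            outputs = torch.jit.annotate(List[Tensor], [])
            for i in range(xs.size(0)):
                outputs += [xs[i] * 2.0]
            return torch.stack(outputs)
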
diff --git a/tests/python/frontend/pytorch/lstm_test.py b/tests/python/frontend/pytorch/lstm_test.py
new file mode 100644
index 0000000..4616698
--- /dev/null
+++ b/tests/python/frontend/pytorch/lstm_test.py
@@ -0,0 +1,335 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Tests on torch lstm model conversion """
+# originally from https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
+# described in https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import Parameter
+import torch.jit as jit
+from typing import List, Tuple
+from torch import Tensor
+
+import tvm
+from tvm import relay
+from tvm.relay.frontend.pytorch import from_pytorch
+from tvm.relay.prelude import Prelude
+from tvm.runtime.container import ADT, tuple_object
+
+
+class LayerNormLSTMCell(jit.ScriptModule):
+    def __init__(self, input_size, hidden_size):
+        super().__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
+        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
+
+        ln = nn.LayerNorm
+
+        self.layernorm_i = ln(4 * hidden_size)
+        self.layernorm_h = ln(4 * hidden_size)
+        self.layernorm_c = ln(hidden_size)
+
+    @jit.script_method
+    def forward(self, input, state):
+        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+        hx, cx = state
+        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
+        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
+        gates = igates + hgates
+        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+        ingate = torch.sigmoid(ingate)
+        forgetgate = torch.sigmoid(forgetgate)
+        cellgate = torch.tanh(cellgate)
+        outgate = torch.sigmoid(outgate)
+
+        cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
+        hy = outgate * torch.tanh(cy)
+
+        return hy, (hy, cy)
+
+
+class LSTMLayer(jit.ScriptModule):
+    def __init__(self, cell, *cell_args):
+        super().__init__()
+        self.cell = cell(*cell_args)
+
+    @jit.script_method
+    def forward(self, input, state):
+        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+        outputs = []
+        for i in range(input.size(0)):
+            out, state = self.cell(input[i], state)
+            outputs += [out]
+        return torch.stack(outputs), state
+
+
+class ReverseLSTMLayer(jit.ScriptModule):
+    def __init__(self, cell, *cell_args):
+        super(ReverseLSTMLayer, self).__init__()
+        self.cell = cell(*cell_args)
+
+    @jit.script_method
+    def forward(self, inputs, state):
+        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+        outputs = jit.annotate(List[Tensor], [])
+        seq_len = inputs.size(0)
+        for i in range(seq_len):
+            out, state = self.cell(inputs[seq_len - i - 1], state)
+            # workaround for the lack of list rev support
+            outputs = [out] + outputs
+        return torch.stack(outputs), state
+
+
+class BidirLSTMLayer(jit.ScriptModule):
+    __constants__ = ['directions']
+
+    def __init__(self, cell, *cell_args):
+        super(BidirLSTMLayer, self).__init__()
+        self.directions = nn.ModuleList([
+            LSTMLayer(cell, *cell_args),
+            ReverseLSTMLayer(cell, *cell_args),
+        ])
+
+    @jit.script_method
+    def forward(self, input, states):
+        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+        # List[LSTMState]: [forward LSTMState, backward LSTMState]
+        outputs = jit.annotate(List[Tensor], [])
+        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+        for (i, direction) in enumerate(self.directions):
+            state = states[i]
+            out, out_state = direction(input, state)
+            outputs += [out]
+            output_states += [out_state]
+        # tensor array concat assumes axis == 0 for now
+        # return torch.cat(outputs, -1), output_states
+        return torch.cat(outputs, 0), output_states
+
+
+def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
+    layers = [layer(*first_layer_args)] + [layer(*other_layer_args)
+                                           for _ in range(num_layers - 1)]
+    return nn.ModuleList(layers)
+
+
+class StackedLSTM(jit.ScriptModule):
+    __constants__ = ['layers']  # Necessary for iterating through self.layers
+
+    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
+        super().__init__()
+        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
+                                        other_layer_args)
+
+    @jit.script_method
+    def forward(self, input, states):
+        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+        # List[LSTMState]: One state per layer
+        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+        output = input
+        for (i, rnn_layer) in enumerate(self.layers):
+            state = states[i]
+            output, out_state = rnn_layer(output, state)
+            output_states += [out_state]
+        return output, output_states
+
+
+class StackedBidirLSTM(jit.ScriptModule):
+    __constants__ = ['layers']  # Necessary for iterating through self.layers
+
+    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
+        super(StackedBidirLSTM, self).__init__()
+        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
+                                        other_layer_args)
+
+    @jit.script_method
+    def forward(self, input, states):
+        # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
+        # List[List[LSTMState]]: The outer list is for layers,
+        #                        inner list is for directions.
+        output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
+        output = input
+        for (i, rnn_layer) in enumerate(self.layers):
+            state = states[i]
+            output, out_state = rnn_layer(output, state)
+            output_states += [out_state]
+        return output, output_states
+
+
+def lstm(input_size, hidden_size):
+    return LSTMLayer(LayerNormLSTMCell, input_size, hidden_size)
+
+
+def stacked_lstm(input_size, hidden_size, num_layers):
+    return StackedLSTM(num_layers, LSTMLayer,
+                       first_layer_args=[LayerNormLSTMCell, input_size, hidden_size],
+                       other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size])
+
+
+def bidir_lstm(input_size, hidden_size):
+    return BidirLSTMLayer(LayerNormLSTMCell, input_size, hidden_size)
+
+
+def stacked_bidir_lstm(input_size, hidden_size, num_layers):
+    return StackedBidirLSTM(num_layers, BidirLSTMLayer,
+                            first_layer_args=[LayerNormLSTMCell, input_size, hidden_size],
+                            other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size])
+
+
+def vmobj_to_list(o, dtype="float32"):
+    if isinstance(o, tvm.nd.NDArray):
+        return [o]
+    elif isinstance(o, tvm.runtime.container.ADT):
+        result = []
+        for f in o:
+            result.extend(vmobj_to_list(f, dtype))
+        return result
+    else:
+        raise RuntimeError("Unknown object type: %s" % type(o))
+
+
+def assert_equal(tvm_result, torch_result):
+    if isinstance(torch_result, (tuple, list)):
+        assert isinstance(tvm_result, list)
+        for tvm_res, pt_res in zip(tvm_result, torch_result):
+            assert_equal(tvm_res, pt_res)
+    elif isinstance(torch_result, torch.Tensor):
+        tvm.testing.assert_allclose(tvm_result.asnumpy(), torch_result.numpy(),
+                                    rtol=1e-4, atol=1e-4)
+
+
+def run_and_compare(mod, params, pt_result):
+    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
+    evaluator = executor.evaluate()
+    exec_res = evaluator(**params)
+
+    def flatten(nested):
+        res = []
+        for r in nested:
+            if isinstance(r, torch.Tensor):
+                res.append(r)
+            else:
+                res.extend(flatten(r))
+        return res
+
+    if isinstance(exec_res, tvm.runtime.container.ADT):
+        assert not isinstance(pt_result, torch.Tensor)
+        tvm_res = vmobj_to_list(exec_res)
+        torch_res = flatten(pt_result)
+    else:
+        tvm_res = exec_res
+        torch_res = pt_result
+
+    assert_equal(tvm_res, torch_res)
+
+
+def convert_list_to_vmobj(py_lst):
+    def wrap_nd_array(arr):
+        return tvm.nd.array(arr, ctx=tvm.cpu(0))
+
+    mod = tvm.IRModule()
+    prelude = Prelude(mod)
+    adt_lst = ADT(prelude.nil.tag, [])
+    for elem in reversed(py_lst):
+        if isinstance(elem, np.ndarray):
+            vmobj = wrap_nd_array(elem)
+        elif isinstance(elem, tuple):
+            vmobj = tuple_object([wrap_nd_array(e) for e in elem])
+        elif isinstance(elem, list):
+            vmobj = convert_list_to_vmobj(elem)
+        adt_lst = ADT(prelude.cons.tag, [vmobj, adt_lst])
+    return adt_lst
+
+
+def custom_lstm_test():
+    input_name = "input"
+    states_name = "states"
+    seq_len = 5
+    batch = 2
+    input_size = 3
+    hidden_size = 4
+    num_layers = 3
+    state_tensor_shape = (batch, hidden_size)
+
+    inp = torch.randn(seq_len, batch, input_size)
+
+    input_shapes = [(input_name, (seq_len, batch, input_size)),
+                    (states_name, (state_tensor_shape, state_tensor_shape))]
+
+    input_shapes_stacked = [(input_name, (seq_len, batch, input_size)),
+                            (states_name, [(state_tensor_shape, state_tensor_shape),
+                                           (state_tensor_shape, state_tensor_shape)])]
+
+    input_shapes_stacked_bidir = [(input_name, (seq_len, batch, input_size)),
+                                  (states_name, [[(state_tensor_shape,
+                                                   state_tensor_shape)
+                                                  for _ in range(2)]
+                                                 for _ in range(num_layers)])]
+
+    states = [(torch.randn(state_tensor_shape),
+               torch.randn(state_tensor_shape))
+              for _ in range(num_layers)]
+
+    bidir_states = [(torch.randn(state_tensor_shape),
+                     torch.randn(state_tensor_shape))
+                    for _ in range(2)]
+
+    stacked_bidir_states = [[(torch.randn(state_tensor_shape),
+                              torch.randn(state_tensor_shape))
+                             for _ in range(2)]
+                            for _ in range(num_layers)]
+
+    models = [
+      (lstm(input_size, hidden_size).eval(), states[0], input_shapes),
+      (stacked_lstm(input_size, hidden_size, num_layers).eval(), states, input_shapes_stacked),
+      (bidir_lstm(input_size, hidden_size).eval(), bidir_states, input_shapes_stacked),
+      (stacked_bidir_lstm(input_size, hidden_size, num_layers).eval(),
+       stacked_bidir_states, input_shapes_stacked_bidir)
+    ]
+
+    for (raw_model, states, input_shapes) in models:
+        script_module = torch.jit.script(raw_model)
+        mod, params = from_pytorch(script_module, input_shapes)
+
+        with torch.no_grad():
+            pt_result = raw_model(inp.clone(), states)
+
+        params[input_name] = inp.numpy()
+
+        if isinstance(states, tuple):
+            states_np = tuple(st.numpy() for st in states)
+        elif isinstance(states, list) and isinstance(states[0], torch.Tensor):
+            states_np = [st.numpy() for st in states]
+        elif isinstance(states, list) and isinstance(states[0], tuple):
+            states_np = [tuple(st.numpy() for st in states[i])
+                         for i in range(len(states))]
+        elif isinstance(states, list) and isinstance(states[0], list):
+            states_np = [[tuple(st.numpy() for st in states)
+                         for states in states[layer]]
+                         for layer in range(num_layers)]
+        else:
+            assert False
+
+        if isinstance(states_np, list):
+            params[states_name] = convert_list_to_vmobj(states_np)
+        else:
+            params[states_name] = states_np
+
+        run_and_compare(mod, params, pt_result)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index d60ab9e..8e99285 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -526,13 +526,13 @@ def test_forward_maxpool2d():
     input_data = torch.rand(input_shape).float()
 
     verify_model(torch.nn.MaxPool2d(kernel_size=[1, 1]).eval(),
-                input_data)
+                 input_data)
     verify_model(torch.nn.MaxPool2d(kernel_size=[10, 10]).eval(),
-                input_data)
+                 input_data)
     verify_model(torch.nn.MaxPool2d(kernel_size=[4, 4],
                                     padding=2,
                                     stride=2).eval(),
-                input_data)
+                 input_data)
 
 def test_forward_maxpool1d():
     torch.set_grad_enabled(False)
@@ -540,13 +540,13 @@ def test_forward_maxpool1d():
     input_data = torch.rand(input_shape).float()
 
     verify_model(torch.nn.MaxPool1d(kernel_size=1).eval(),
-                input_data)
+                 input_data)
     verify_model(torch.nn.MaxPool1d(kernel_size=10).eval(),
-                input_data)
-    verify_model( torch.nn.MaxPool1d(kernel_size=4,
+                 input_data)
+    verify_model(torch.nn.MaxPool1d(kernel_size=4,
                                     padding=2,
                                     stride=2).eval(),
-                input_data)
+                 input_data)
 
 def test_forward_maxpool3d():
     torch.set_grad_enabled(False)
@@ -554,13 +554,13 @@ def test_forward_maxpool3d():
     input_data = torch.rand(input_shape).float()
 
     verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(),
-                input_data)
+                 input_data)
     verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(),
-                input_data)
+                 input_data)
     verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4],
                                     padding=2,
                                     stride=2).eval(),
-                input_data)
+                 input_data)
 
 def test_forward_split():
     torch.set_grad_enabled(False)
@@ -577,13 +577,13 @@ def test_forward_split():
 
     input_data = torch.rand(input_shape).float()
     verify_model(Split(2, 0).float().eval(),
-                input_data=input_data)
+                 input_data=input_data)
     verify_model(Split(3, 1).float().eval(),
-                input_data=input_data)
+                 input_data=input_data)
     verify_model(Split(4, 1).float().eval(),
-                input_data=input_data)
+                 input_data=input_data)
     verify_model(Split([2, 3, 5], 1).float().eval(),
-                input_data=input_data)
+                 input_data=input_data)
 
 def test_forward_avgpool():
     torch.set_grad_enabled(False)
@@ -1363,3 +1363,8 @@ if __name__ == "__main__":
     # Test simple conditionals and loop
     test_control_flow()
     test_simple_rnn()
+
+    # More complex recurrent models
+    from lstm_test import custom_lstm_test
+
+    custom_lstm_test()