Posted to commits@tvm.apache.org by ke...@apache.org on 2020/04/13 06:12:07 UTC
[incubator-tvm] branch master updated: [Torch] Support Python list, more realistic recurrent networks (#5306)
This is an automated email from the ASF dual-hosted git repository.
kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 0145cd5 [Torch] Support Python list, more realistic recurrent networks (#5306)
0145cd5 is described below
commit 0145cd504585e25b776bef83688d10ff0ca44082
Author: masahi <ma...@gmail.com>
AuthorDate: Mon Apr 13 15:11:57 2020 +0900
[Torch] Support Python list, more realistic recurrent networks (#5306)
* use funcs from prelude, pass around convert_map
* get relay input type from user ishape
* handle tuple unpack
* experimenting with static tensor array
* use prelude concat instead of cons + rev
* minor clean up
* fix layer norm conversion bug, unwrap tensor array
* add infer shape on tensor array
* pass around prelude for now
* compile worked but runtime error
* fix tensor array wrapping
* begin list dynamic test
* is_list_dynamic first version
* finish dynamic list test
* a few fixes
* use shape_of function if Any is found
* improve size conversion
* working on adding free vars to loop block
* fixed inlined inner loop issue
* clean up free var handling
* add support for tensor array concat
* adding tensor array concat on last axis
* fix concat, but got runtime error
* disable concat on axis -1 for now
* add lstm tests
* revert unrelated change
* fix stacked bidir test
* minor fix to test
* relax tol a bit, revert dnnl change to avoid conflict
* simplify infer type, use input tensor shape rather than concat shape
* more shape fixes
---
python/tvm/relay/frontend/pytorch.py | 618 ++++++++++++++++++--------
tests/python/frontend/pytorch/lstm_test.py | 335 ++++++++++++++
tests/python/frontend/pytorch/test_forward.py | 33 +-
3 files changed, 787 insertions(+), 199 deletions(-)
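In short, the change teaches the PyTorch frontend to turn TorchScript graphs that grow, index, or stack Python lists at runtime into Relay's Prelude List ADT (plus static tensor arrays for aten::stack / aten::cat), which is what makes the recurrent models in the new lstm_test.py convertible. A minimal, hedged usage sketch follows; the Accumulate module and the "input" name are hypothetical and only illustrate the kind of list-manipulating TorchScript this commit targets, mirroring the executor pattern used by lstm_test.py below.

# Hypothetical sketch (not from the commit): convert a scripted model whose
# forward() grows a Python list in a loop, then run it on the Relay VM.
import torch
import tvm
from tvm import relay
from tvm.relay.frontend.pytorch import from_pytorch

class Accumulate(torch.nn.Module):
    def forward(self, x):
        outs = []
        for i in range(x.size(0)):
            outs += [x[i] * 2.0]       # list concat -> Prelude List ADT
        return torch.stack(outs)       # aten::stack -> static tensor array stack

script_module = torch.jit.script(Accumulate().eval())
inp = torch.randn(4, 3)

# input_shapes is a list of (name, shape) pairs, as in lstm_test.py below
mod, params = from_pytorch(script_module, [("input", (4, 3))])

executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
params["input"] = inp.numpy()
result = executor.evaluate()(**params)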
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index a542ccc..506f6ba 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -25,20 +25,95 @@ import sys
import numpy as np
import tvm
-from tvm.ir import module as _module
from .. import analysis as _analysis
from .. import expr as _expr
from .. import op as _op
+from ..ty import TupleType, TensorType, Any
from ..loops import while_loop
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
+from .common import infer_type as _infer_type
+from ..prelude import Prelude, StaticTensorArrayOps
from . import qnn_torch
__all__ = ["from_pytorch"]
+
+# List ADT utilities
+def _infer_type_with_prelude(val, prelude):
+ body = _infer_type(val, prelude.mod)
+ return body.checked_type
+
+
+def _convert_to_list_adt(py_lst, prelude):
+ elem_tys = [_infer_type_with_prelude(elem, prelude) for elem in py_lst]
+ msg = "List elements should have identical types"
+ assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
+
+ adt_lst = prelude.nil()
+ for elem in reversed(py_lst):
+ adt_lst = prelude.cons(elem, adt_lst)
+ return adt_lst
+
+
+def _map_tensor_array_constructor(adt_lst, prelude, shape):
+ static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
+ static_tensor_array_ops.register()
+ tensor_create = prelude.get_var_static('tensor_constructor', "float32", shape)
+ return prelude.map(tensor_create, adt_lst)
+
+
+def _convert_to_tensor_array(adt_lst, prelude):
+ if prelude.length(adt_lst) == 0:
+ return prelude.nil()
+
+ checked_type = _infer_type_with_prelude(prelude.hd(adt_lst), prelude)
+ shape = checked_type.shape
+ tensor_array = _map_tensor_array_constructor(adt_lst, prelude, shape)
+ return tensor_array, tuple(shape)
+
+
+def _should_construct_dynamic_list(list_construct_node):
+ # if this list is element-accessed or modified at runtime, generate List ADT
+ def is_used_by_list_add(uses):
+ for use in uses:
+ op_name = use.user.kind()
+ output_type = _get_node_type(use.user)
+ if op_name in ["aten::add", "aten::add_"] and output_type == "ListType":
+ return True
+ return False
+
+ def inplace_add_to_add(op_name):
+ if op_name == "aten::add_":
+ return "aten::add"
+ else:
+ return op_name
+
+ uses = _get_uses(list_construct_node)
+
+ for loop_use in filter(lambda use: use.user.kind() == "prim::Loop", uses):
+ block_input_index = loop_use.offset - 1
+ block = list(loop_use.user.blocks())[0]
+ list_loop_var = list(block.inputs())[block_input_index]
+ uses += _get_uses(list_loop_var.node())
+
+ op_names = map(inplace_add_to_add, set(use.user.kind() for use in uses))
+
+ list_ops = set(["aten::add", "aten::__getitem__", "aten::stack"])
+ intersect = list_ops.intersection(op_names)
+
+ if len(intersect) > 0 and intersect != set(["aten::add"]):
+ return True
+
+ if is_used_by_list_add(filter(lambda use: use.user.kind() != "prim::Loop", uses)):
+ return True
+
+ return False
+
+
# operator implementation
def _elemwise(name):
def _impl(inputs, input_types):
@@ -103,11 +178,27 @@ def _unsqueeze():
return _op.transform.expand_dims(data, int(axis), 1)
return _impl
-def _concatenate():
+
+def _concatenate(prelude):
+ def tensor_array_concat(lst, axis):
+ assert axis == 0, "Tensor array concat supported only for axis 0"
+ tensor_array, shape = _convert_to_tensor_array(lst, prelude)
+ concat_shape = (Any(),) + shape[1:]
+ static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
+ static_tensor_array_ops.define_tensor_get_data(concat_shape)
+
+ concat = prelude.get_var_static('tensor_array_concat', "float32", shape)
+ concatenated = concat(tensor_array)
+ get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)
+ return get_tensor(concatenated)
+
def _impl(inputs, input_types):
data = inputs[0]
axis = inputs[1]
+ if not isinstance(data, list):
+ return tensor_array_concat(data, axis)
+
if isinstance(data, _expr.Expr):
data = [data]
@@ -130,7 +221,7 @@ def _slice():
else:
end = data.shape
- begin = [0]*len(end)
+ begin = [0] * len(end)
dim = int(inputs[1])
begin[dim] = int(inputs[2])
@@ -371,7 +462,7 @@ def _maxpool_2d():
ceil_mode = int(inputs[5])
if dilation != (1, 1):
- msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation), )
+ msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation))
raise NotImplementedError(msg)
return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode)
@@ -388,7 +479,7 @@ def _maxpool_1d():
ceil_mode = int(inputs[5])
if dilation != (1,):
- msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation), )
+ msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation))
raise NotImplementedError(msg)
return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode)
@@ -404,7 +495,7 @@ def _maxpool_3d():
dilation = _infer_shape(inputs[4])
ceil_mode = int(inputs[5])
if dilation != (1, 1, 1):
- msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation), )
+ msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation))
raise NotImplementedError(msg)
return _op.nn.max_pool3d(data,
@@ -618,13 +709,13 @@ def _layer_norm():
scale=True)
return _impl
-def _transpose():
+def _transpose(prelude):
def _impl(inputs, input_types):
data = inputs[0]
import torch
if isinstance(data, _expr.Expr):
- ndims = len(_infer_shape(data))
+ ndims = len(_infer_shape(data, prelude.mod))
elif isinstance(data, list):
ndims = data
elif isinstance(data, (torch.Tensor, np.ndarray)):
@@ -693,15 +784,30 @@ def _dense():
return dense_out
return _impl
-def _size():
+
+def _size(prelude):
+ def _impl_dynamic(inp, axis):
+ shape_dynamic = _op.shape_of(inp)
+ if axis is not None:
+ return _op.take(shape_dynamic, _expr.const(axis), 0)
+ return shape_dynamic
+
def _impl(inputs, input_types):
- shape = _infer_shape(inputs[0])
+ shape = _infer_shape(inputs[0], prelude.mod)
+ axis = None
if len(inputs) > 1:
axis = int(inputs[1])
+
+ if any(map(lambda s: isinstance(s, tvm.tir.expr.Any), shape)):
+ if axis is None or isinstance(shape[axis], tvm.tir.expr.Any):
+ return _impl_dynamic(inputs[0], axis)
+
+ if axis is not None:
return shape[axis]
return shape
return _impl
+
def _numtotensor():
def _impl(inputs, input_types):
val = inputs[0]
@@ -862,7 +968,7 @@ def _mean():
return _impl
-def _chunk():
+def _chunk(prelude):
def _impl(inputs, input_types):
data = inputs[0]
@@ -870,7 +976,7 @@ def _chunk():
axis = int(inputs[2])
if isinstance(data, _expr.Expr):
- inferred_shape = _infer_shape(data)
+ inferred_shape = _infer_shape(data, prelude.mod)
shape = []
for infer in inferred_shape:
@@ -894,7 +1000,6 @@ def _chunk():
chunk_out = _op.transform.strided_slice(data, begin, end, stride)
chunks.append(chunk_out)
-
if dim % num_chunks:
begin = [0] * len(shape)
end = shape[:]
@@ -1077,6 +1182,49 @@ def _Float():
return _op.cast(inputs[0], "float32")
return _impl
+
+def _mm():
+ def _impl(inputs, input_types):
+ return _op.nn.dense(inputs[0], inputs[1])
+ return _impl
+
+
+def _list_getitem(prelude):
+ def _impl(inputs, input_types):
+ return prelude.nth(inputs[0], _wrap_const(inputs[1]))
+ return _impl
+
+
+def _list_len(prelude):
+ def _impl(inputs, input_types):
+ return prelude.length(inputs[0])
+ return _impl
+
+
+def _add(prelude):
+ # add_ is overloaded for tensor add and list concat
+ def _impl(inputs, input_types):
+ if input_types[0] == "ListType":
+ return prelude.concat(inputs[0], inputs[1])
+ return _elemwise("add")(inputs, input_types)
+ return _impl
+
+
+def _tensor_array_stack(prelude):
+ def _impl(inputs, input_types):
+ tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude)
+ stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
+ stacked = stack(tensor_array)
+
+ stacked_shape = (Any(),) + shape
+ static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape)
+ static_tensor_array_ops.define_tensor_get_data(stacked_shape)
+ # passing stacked_shape below gives "'Prelude' object has no attribute" error
+ get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape)
+ return get_tensor(stacked)
+ return _impl
+
+
# Helper functions for operator implementation
def _convert_dtype_value(val):
convert_torch_dtype_map = {7:"torch.float64",
@@ -1148,112 +1296,117 @@ def _convert_elemwise_input(data, input_type):
return data
def _wrap_const(c):
- if not isinstance(c, _expr.Expr) and not isinstance(c, list):
+ if not isinstance(c, (_expr.Expr, list, tvm.tir.expr.Any)):
return _expr.const(c)
return c
# Operator mappings
-
-_convert_map = {
- "aten::device" : _none(),
- "aten::add" : _elemwise("add"),
- "aten::add_" : _elemwise("add"),
- "aten::sub" : _elemwise("subtract"),
- "aten::sub_" : _elemwise("subtract"),
- "aten::max" : _elemwise("maximum"),
- "aten::min" : _elemwise("minimum"),
- "aten::mul" : _elemwise("multiply"),
- "aten::mul_" : _elemwise("multiply"),
- "aten::pow" : _elemwise("power"),
- "aten::div" : _elemwise("divide"),
- "aten::div_" : _elemwise("divide"),
- "aten::abs" : _abs(),
- "aten::arange" : _arange(),
- "aten::ones" : _ones(),
- "aten::zeros" : _zeros(),
- "aten::reciprocal" : _reciprocal(),
- "aten::repeat" : _repeat(),
- "aten::repeat_interleave" : _repeat_interleave(),
- "aten::to" : _to(),
- "aten::squeeze" : _squeeze(),
- "aten::unsqueeze" : _unsqueeze(),
- "aten::cat" : _concatenate(),
- "aten::slice" : _slice(),
- "aten::split" : _split(),
- "aten::split_with_sizes" : _split_with_sizes(),
- "aten::select" : _select(),
- "aten::relu" : _relu(),
- "aten::relu_" : _relu(),
- "aten::prelu" : _prelu(),
- "aten::leaky_relu" : _leaky_relu(),
- "aten::elu" : _elu(),
- "aten::celu" : _celu(),
- "aten::gelu" : _gelu(),
- "aten::selu" : _selu(),
- "aten::log_sigmoid" : _log_sigmoid(),
- "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(),
- "aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(),
- "aten::max_pool2d" : _maxpool_2d(),
- "aten::max_pool2d_with_indices" : _maxpool_2d(),
- "aten::max_pool1d" : _maxpool_1d(),
- "aten::max_pool3d" : _maxpool_3d(),
- "aten::hardtanh" : _hardtanh(),
- "aten::hardtanh_" : _hardtanh(),
- "aten::_convolution" : _convolution(),
- "aten::softmax" : _softmax(),
- "aten::threshold" : _threshold(),
- "aten::threshold_" : _threshold(),
- "aten::contiguous" : _contiguous(),
- "aten::batch_norm" : _batch_norm(),
- "aten::instance_norm" : _instance_norm(),
- "aten::layer_norm" : _layer_norm(),
- "aten::transpose" : _transpose(),
- "aten::transpose_" : _transpose(),
- "aten::t" : _transpose(),
- "aten::flatten" : _flatten(),
- "aten::addmm" : _dense(),
- "aten::size" : _size(),
- "aten::view" : _view(),
- "aten::reshape" : _reshape(),
- "aten::clone" : _clone(),
- "aten::log_softmax" : _log_softmax(),
- "aten::sigmoid" : _sigmoid(),
- "aten::softplus" : _softplus(),
- "aten::avg_pool2d" : _avg_pool2d(),
- "aten::avg_pool3d" : _avg_pool3d(),
- "aten::dropout" : _dropout(),
- "aten::dropout_" : _dropout(),
- "aten::feature_dropout" : _dropout(),
- "aten::alpha_dropout" : _dropout(),
- "aten::mean" : _mean(),
- "aten::chunk" : _chunk(),
- "aten::matmul" : _matmul(),
- "aten::expand" : _expand(),
- "aten::Int" : _int(),
- "prim::NumToTensor" : _numtotensor(),
- "prim::ListUnpack" : _identity(),
- "aten::constant_pad_nd" : _pad(),
- "aten::permute" : _transpose(),
- "aten::sum" : _reduce("sum"),
- "aten::prod" : _reduce("prod"),
- "aten::sqrt" : _sqrt(),
- 'aten::floor' : _floor(),
- "aten::detach" : _identity(),
- "aten::upsample_bilinear2d" : _upsample("bilinear"),
- "aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
- "aten::expand_as" : _expand_as(),
- "aten::lt" : _elemwise("less"),
- "aten::gt" : _elemwise("greater"),
- "aten::le" : _elemwise("less_equal"),
- "aten::ge" : _elemwise("greater_equal"),
- "aten::ne" : _elemwise("not_equal"),
- "aten::Bool" : _Bool(),
- "aten::Float" : _Float(),
- "aten::neg" : _neg(),
- "aten::tanh" : _tanh(),
- "aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(),
- "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d()
-}
+def _get_convert_map(prelude):
+ convert_map = {
+ "aten::device" : _none(),
+ "aten::sub" : _elemwise("subtract"),
+ "aten::sub_" : _elemwise("subtract"),
+ "aten::max" : _elemwise("maximum"),
+ "aten::min" : _elemwise("minimum"),
+ "aten::mul" : _elemwise("multiply"),
+ "aten::mul_" : _elemwise("multiply"),
+ "aten::pow" : _elemwise("power"),
+ "aten::abs" : _abs(),
+ "aten::arange" : _arange(),
+ "aten::div" : _elemwise("divide"),
+ "aten::div_" : _elemwise("divide"),
+ "aten::ones" : _ones(),
+ "aten::zeros" : _zeros(),
+ "aten::reciprocal" : _reciprocal(),
+ "aten::repeat" : _repeat(),
+ "aten::repeat_interleave" : _repeat_interleave(),
+ "aten::to" : _to(),
+ "aten::squeeze" : _squeeze(),
+ "aten::unsqueeze" : _unsqueeze(),
+ "aten::cat" : _concatenate(prelude),
+ "aten::slice" : _slice(),
+ "aten::split" : _split(),
+ "aten::split_with_sizes" : _split_with_sizes(),
+ "aten::select" : _select(),
+ "aten::relu" : _relu(),
+ "aten::relu_" : _relu(),
+ "aten::prelu" : _prelu(),
+ "aten::leaky_relu" : _leaky_relu(),
+ "aten::elu" : _elu(),
+ "aten::celu" : _celu(),
+ "aten::gelu" : _gelu(),
+ "aten::selu" : _selu(),
+ "aten::log_sigmoid" : _log_sigmoid(),
+ "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(),
+ "aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(),
+ "aten::max_pool2d" : _maxpool_2d(),
+ "aten::max_pool2d_with_indices" : _maxpool_2d(),
+ "aten::max_pool1d" : _maxpool_1d(),
+ "aten::max_pool3d" : _maxpool_3d(),
+ "aten::hardtanh" : _hardtanh(),
+ "aten::hardtanh_" : _hardtanh(),
+ "aten::_convolution" : _convolution(),
+ "aten::softmax" : _softmax(),
+ "aten::threshold" : _threshold(),
+ "aten::threshold_" : _threshold(),
+ "aten::contiguous" : _contiguous(),
+ "aten::batch_norm" : _batch_norm(),
+ "aten::instance_norm" : _instance_norm(),
+ "aten::layer_norm" : _layer_norm(),
+ "aten::transpose" : _transpose(prelude),
+ "aten::transpose_" : _transpose(prelude),
+ "aten::t" : _transpose(prelude),
+ "aten::flatten" : _flatten(),
+ "aten::addmm" : _dense(),
+ "aten::size" : _size(prelude),
+ "aten::view" : _view(),
+ "aten::reshape" : _reshape(),
+ "aten::clone" : _clone(),
+ "aten::log_softmax" : _log_softmax(),
+ "aten::sigmoid" : _sigmoid(),
+ "aten::softplus" : _softplus(),
+ "aten::avg_pool2d" : _avg_pool2d(),
+ "aten::avg_pool3d" : _avg_pool3d(),
+ "aten::dropout" : _dropout(),
+ "aten::dropout_" : _dropout(),
+ "aten::feature_dropout" : _dropout(),
+ "aten::alpha_dropout" : _dropout(),
+ "aten::mean" : _mean(),
+ "aten::chunk" : _chunk(prelude),
+ "aten::matmul" : _matmul(),
+ "aten::expand" : _expand(),
+ "aten::Int" : _int(),
+ "prim::NumToTensor" : _numtotensor(),
+ "aten::constant_pad_nd" : _pad(),
+ "aten::permute" : _transpose(prelude),
+ "aten::sum" : _reduce("sum"),
+ "aten::prod" : _reduce("prod"),
+ "aten::sqrt" : _sqrt(),
+ 'aten::floor' : _floor(),
+ "aten::detach" : _identity(),
+ "aten::upsample_bilinear2d" : _upsample("bilinear"),
+ "aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
+ "aten::expand_as" : _expand_as(),
+ "aten::lt" : _elemwise("less"),
+ "aten::gt" : _elemwise("greater"),
+ "aten::le" : _elemwise("less_equal"),
+ "aten::ge" : _elemwise("greater_equal"),
+ "aten::ne" : _elemwise("not_equal"),
+ "aten::Bool" : _Bool(),
+ "aten::Float" : _Float(),
+ "aten::neg" : _neg(),
+ "aten::tanh" : _tanh(),
+ "aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(),
+ "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(),
+ "aten::mm" : _matmul(),
+ "relay::tensor_array_stack" : _tensor_array_stack(prelude),
+ "aten::add" : _add(prelude),
+ "aten::add_" : _add(prelude),
+ "aten::stack" : _tensor_array_stack(prelude),
+ "aten::__getitem__" : _list_getitem(prelude),
+ "aten::len" : _list_len(prelude),
+ }
+ return convert_map
def _run_jit_passes(graph):
@@ -1289,13 +1442,29 @@ def _get_op_inputs(op_node, outputs):
return [outputs[name] for name in _get_input_names(op_node)]
-def _report_missing_conversion(op_names):
+def _get_node_type(node):
+ assert node.outputsSize() == 1
+ return node.output().type().kind()
+
+
+def _get_uses(node):
+ uses = []
+ for output in node.outputs():
+ uses += output.uses()
+ return uses
+
+
+def _get_users(node):
+ return [use.user for use in _get_uses(node)]
+
+
+def _report_missing_conversion(op_names, convert_map):
""" Check if all ops in an input graph are supported by TVM """
known_ops = ["prim::Constant", "prim::GetAttr",
"prim::ListConstruct", "prim::ListUnpack",
"prim::TupleConstruct", "prim::TupleUnpack",
"prim::If", "prim::Loop"]
- known_ops += list(_convert_map.keys())
+ known_ops += list(convert_map.keys())
known_ops += list(qnn_torch.convert_map.keys())
missing = [op_name for op_name in op_names
@@ -1361,7 +1530,7 @@ def _get_input_types(op_node):
input_list_types.append(in_ty.scalarType().lower())
elif input_node_kind == 'ListType':
- input_list_types.append(str(in_ty.getElementType()).lower())
+ input_list_types.append("ListType")
elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
'StringType', 'OptionalType']:
input_list_types.append(str(in_ty).lower())
@@ -1422,21 +1591,69 @@ def _get_operator_nodes(nodes):
return ops
-def _get_relay_input_vars(graph, input_shapes):
+def _get_graph_input_names(graph):
+ """ Get the graph input names (use after graph copy and run jit passes) """
+ # Variable names could change the first time a copy is made and after
+ # _run_jit_passes is called, expected that those functions already invoked
+ ir_inputs = _get_input_names(graph)
+ return ir_inputs[1:] # remove self at the 0th arg
+
+
+def _get_relay_input_vars(graph, input_shapes, prelude):
"""
Return Relay vars from input shapes and create entries based on
expected graph inputs - to allow translation
"""
+ def get_relay_ty(ishape):
+ if _is_int_seq(ishape) or len(ishape) == 0:
+ return TensorType(ishape)
+ elif isinstance(ishape, tuple):
+ return TupleType([get_relay_ty(elem) for elem in ishape])
+ elif isinstance(ishape, list):
+ assert len(ishape) > 0
+ elem_tys = [get_relay_ty(s) for s in ishape]
+ msg = "List elements should have identical types"
+ assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
+ return prelude.l(elem_tys[0])
+ raise NotImplementedError("unsupported input type")
+
+ input_types = [(tup[0], get_relay_ty(tup[1])) for tup in input_shapes]
input_vars = {}
ir_inputs = _get_graph_input_names(graph)
- for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
- inp = _expr.var(name, shape=shape)
+ for ir_input, (name, itype) in zip(ir_inputs, input_types):
+ inp = _expr.var(name, type_annotation=itype)
# Translate from graph input to user input name
input_vars[ir_input] = inp
return input_vars
+def _unpack_tuple(tup):
+ def unpack(tup, num_fields):
+ return [_expr.TupleGetItem(tup, i) for i in range(num_fields)]
+
+ if isinstance(tup, _expr.Tuple):
+ return unpack(tup, len(tup.fields))
+ elif isinstance(tup.type_annotation, TupleType):
+ return unpack(tup, len(tup.type_annotation.fields))
+ # shouldn't happen
+ assert False
+
+
+def _get_free_vars_from_block(block):
+ block_inp_names = _get_input_names(block)
+ bound_names = block_inp_names
+ free_vars = set()
+
+ for node in block.nodes():
+ inp_names = _get_input_names(node)
+ list_diff = [name for name in inp_names if name not in bound_names]
+ free_vars.update(list_diff)
+ bound_names += _get_output_names(node)
+
+ return free_vars
+
+
def get_use_chains(root_node, terminate=lambda _: False):
"""
Track a chain of users of this node forward, returning a list of chains
@@ -1446,9 +1663,7 @@ def get_use_chains(root_node, terminate=lambda _: False):
return itertools.chain.from_iterable(lists)
def inner(current, accum):
- users = []
- for output in current.outputs():
- users += [use.user for use in output.uses()]
+ users = _get_users(current)
if not users or terminate(users):
return [accum]
@@ -1512,24 +1727,24 @@ def convert_params(graph, state_dict):
return params, param_tensors, packed_param_map
-def convert_block(block, outputs):
+def convert_block(block, outputs, convert_map, prelude):
""" Translate Torch "Block", used for prim::If and prim::Loop """
ops = _get_operator_nodes(block.nodes())
ret_names = _get_input_names(block.returnNode())
- return convert_operators(ops, outputs, ret_names)
+ return convert_operators(ops, outputs, ret_names, convert_map, prelude)
-def convert_if(if_node, outputs):
+def convert_if(if_node, outputs, convert_map, prelude):
""" Translate Torch prim::If to Relay If """
cond = outputs[if_node.inputsAt(0).debugName()]
blocks = list(if_node.blocks())
- true_branch = convert_block(blocks[0], outputs)
- false_branch = convert_block(blocks[1], outputs)
+ true_branch = convert_block(blocks[0], outputs, convert_map, prelude)
+ false_branch = convert_block(blocks[1], outputs, convert_map, prelude)
assert len(true_branch) == 1 and len(false_branch) == 1
return _expr.If(cond, true_branch[0], false_branch[0])
-def convert_loop(loop_node, outputs):
+def convert_loop(loop_node, outputs, convert_map, prelude):
""" Translate Torch prim::Loop to Relay while_loop """
def get_input(index):
ivalue = loop_node.inputsAt(index)
@@ -1555,8 +1770,54 @@ def convert_loop(loop_node, outputs):
is_while_loop = (isinstance(max_loop_count, _expr.Constant) and
_get_constant(loop_node.inputsAt(0).node()) == sys.maxsize)
+ if is_while_loop:
+ loop_iter_dtype = "bool"
+ # while loop with non input dependent condition such as while i < 10:
+ # init_cond is int, need to cast to bool to type check
+ if isinstance(init_cond, _expr.Constant):
+ init_cond = _op.cast(init_cond, "bool")
+ init_loop_iter_val = init_cond
+ else:
+ loop_iter_dtype = "int32"
+ # always count from 0
+ init_loop_iter_val = _expr.const(0, dtype="int32")
+
body_block = list(loop_node.blocks())[0]
block_input_names = _get_input_names(body_block)
+ num_block_inputs = len(block_input_names)
+ name_val_pairs = list(zip(block_input_names,
+ [init_loop_iter_val] + init_vals))
+ outputs.update(name_val_pairs)
+
+ def get_var(name, val):
+ if val:
+ checked_type = _infer_type_with_prelude(val, prelude)
+ return _expr.var(name, type_annotation=checked_type)
+ return _expr.var(name)
+
+ loop_iter_var = _expr.var(block_input_names[0], shape=(),
+ dtype=loop_iter_dtype)
+ loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
+
+ # Add non constant free variables to loop variables to prevent code blow up
+ # Without this, if there are two for loops in a row, which often happens
+ # if the outer loop is unrolled, the computation corresponding to the first for loop
+ # is inlined inside loop body, turning O(N) + O(N) computation into O(N^2).
+ # This issue was found when converting the stacked LSTM test. Torch does not add the output
+ # of the earlier loop into loop variables of the next loop.
+ # So the variable corresponding to the first loop output appears free in the second loop body.
+ free_vars = [var for var in _get_free_vars_from_block(body_block)
+ if var in outputs and not isinstance(outputs[var], (_expr.Constant, int, float))
+ and outputs[var]]
+
+ prev_outputs = {}
+ for name in free_vars:
+ prev_output = outputs[name]
+ new_loop_var = get_var(name, prev_output)
+ prev_outputs[name] = prev_output
+ outputs[name] = new_loop_var
+ loop_vars.append(new_loop_var)
+ init_vals.append(prev_output)
def cond(*current_vals):
i = current_vals[0]
@@ -1568,11 +1829,16 @@ def convert_loop(loop_node, outputs):
def body(*current_vals):
# Update loop variables using the prev iteration outputs
- assert len(current_vals) == len(block_input_names)
- for (i, iname) in enumerate(block_input_names):
- outputs[iname] = current_vals[i]
+ assert len(current_vals) == num_block_inputs + len(free_vars)
- block_outputs = convert_block(body_block, outputs)
+ for (i, val) in enumerate(current_vals):
+ if i < num_block_inputs:
+ outputs[block_input_names[i]] = val
+ else:
+ outputs[free_vars[i-num_block_inputs]] = val
+
+ block_outputs = convert_block(body_block, outputs, convert_map, prelude)
+ block_outputs += [outputs[name] for name in free_vars]
if not is_while_loop:
# iter var increment implicit in torch, so do it manually
@@ -1583,38 +1849,17 @@ def convert_loop(loop_node, outputs):
return block_outputs
- def get_var(name, val):
- if isinstance(val, _expr.Constant):
- return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype)
- return _expr.var(name)
-
- if is_while_loop:
- loop_iter_dtype = "bool"
- # while loop with non input dependent condition such as while i < 10:
- # init_cond is int, need to cast to bool to type check
- if isinstance(init_cond, _expr.Constant):
- init_cond = _op.cast(init_cond, "bool")
- init_loop_iter_val = init_cond
- else:
- loop_iter_dtype = "int32"
- # always count from 0
- init_loop_iter_val = _expr.const(0, dtype="int32")
-
- name_val_pairs = list(zip(block_input_names,
- [init_loop_iter_val] + init_vals))
- outputs.update(name_val_pairs)
-
- loop_iter_var = _expr.var(block_input_names[0], shape=(),
- dtype=loop_iter_dtype)
- loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
loop = while_loop(cond, [loop_iter_var] + loop_vars, body)
loop_val = loop(init_loop_iter_val, *init_vals)
+ # restore original output values for free vars
+ outputs.update(prev_outputs)
+
# The first element is a loop counter or boolean condition, ignore it
return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
-def convert_operators(operators, outputs, ret_names):
+def convert_operators(operators, outputs, ret_names, convert_map, prelude):
""" Convert each Torch IR operators to Relay equivalent """
for node_name, op_node in operators:
operator = op_node.kind()
@@ -1622,24 +1867,33 @@ def convert_operators(operators, outputs, ret_names):
if operator == "prim::Constant":
outputs[node_name] = _get_constant(op_node)
- elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
+ elif operator == "prim::ListConstruct" and _is_int_seq(inputs):
outputs[node_name] = _expr.var(node_name, shape=inputs)
- elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
+ elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node):
+ outputs[node_name] = _convert_to_list_adt(inputs, prelude)
+ elif operator == "prim::ListConstruct":
+ # This assumes that no more elements will be appended to this list
+ # In this case, we keep the Python list
outputs[node_name] = inputs
- elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
+ elif operator == "prim::TupleConstruct":
+ outputs[node_name] = _expr.Tuple(inputs)
+ elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]:
assert len(inputs) == 1
- unpacked_names = _get_output_names(op_node)
- outputs.update(zip(unpacked_names, inputs[0]))
+ if isinstance(inputs[0], (list, _expr.TupleWrapper)):
+ unpacked = inputs[0]
+ else:
+ unpacked = _unpack_tuple(inputs[0])
+ outputs.update(zip(_get_output_names(op_node), unpacked))
elif operator == "prim::If":
- if_out = convert_if(op_node, outputs)
+ if_out = convert_if(op_node, outputs, convert_map, prelude)
outputs[node_name] = if_out
elif operator == "prim::Loop":
- loop_out = convert_loop(op_node, outputs)
+ loop_out = convert_loop(op_node, outputs, convert_map, prelude)
unpacked_names = _get_output_names(op_node)
assert len(loop_out) == len(unpacked_names)
outputs.update(zip(unpacked_names, loop_out))
else:
- relay_op = _convert_map[operator]
+ relay_op = convert_map[operator]
relay_out = relay_op(inputs, _get_input_types(op_node))
if isinstance(relay_out, tuple):
@@ -1666,14 +1920,6 @@ def get_all_op_names(graph):
return set(node.kind() for node in nodes)
-def _get_graph_input_names(graph):
- """ Get the graph input names (use after graph copy and run jit passes) """
- # Variable names could change the first time a copy is made and after
- # _run_jit_passes is called, expected that those functions already invoked
- ir_inputs = _get_input_names(graph)
- return ir_inputs[1:] # remove self at the 0th arg
-
-
def from_pytorch(script_module, input_shapes, custom_convert_map=None):
""" Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.
@@ -1700,18 +1946,23 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
params : dict of str to tvm.runtime.NDArray
Dict of converted parameters stored in tvm.runtime.ndarray format
"""
+ mod = tvm.IRModule()
+ prelude = Prelude(mod)
+
+ convert_map = _get_convert_map(prelude)
+
graph = script_module.graph.copy()
_run_jit_passes(graph)
if custom_convert_map:
- _convert_map.update(custom_convert_map)
+ convert_map.update(custom_convert_map)
op_names = get_all_op_names(graph)
- _report_missing_conversion(op_names)
+ _report_missing_conversion(op_names, convert_map)
_check_inputs(graph, input_shapes)
params = script_module.state_dict()
- outputs = _get_relay_input_vars(graph, input_shapes)
+ outputs = _get_relay_input_vars(graph, input_shapes, prelude)
param_vars, tensors, packed_param_map = convert_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
@@ -1726,14 +1977,11 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
packed_param_map,
weight_quant_params)
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
- _convert_map.update(qnn_torch.convert_map)
+ convert_map.update(qnn_torch.convert_map)
ret = convert_operators(_get_operator_nodes(graph.nodes()),
- outputs, ret_name)
-
- if isinstance(ret[0], list):
- ret[0] = _expr.Tuple(ret[0])
+ outputs, ret_name, convert_map, prelude)
- func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
+ mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
- return _module.IRModule.from_expr(func), tvm_params
+ return mod, tvm_params
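For reference, a hedged sketch of what the frontend changes above construct at the Relay level: a runtime-manipulated Python list becomes a Prelude List ADT built from nil/cons and manipulated with concat/nth, and aten::stack over such a list goes through the static tensor array ops registered per dtype and shape. This only builds the expressions and follows the names used in the frontend code above; the exact registration flow inside the converter may differ.

# Hedged illustration (not part of the commit) of the Relay constructs emitted
# for dynamic lists and aten::stack.
import tvm
from tvm import relay
from tvm.relay.ty import Any
from tvm.relay.prelude import Prelude, StaticTensorArrayOps

mod = tvm.IRModule()
p = Prelude(mod)

shape = (2, 3)
x = relay.var("x", shape=shape)
y = relay.var("y", shape=shape)

# outputs += [out]      ->  concat(list, cons(out, nil()))
lst = p.concat(p.cons(x, p.nil()), p.cons(y, p.nil()))
# outputs[1]            ->  nth(list, 1)
elem = p.nth(lst, relay.const(1))

# torch.stack(outputs)  ->  map tensor_constructor, tensor_array_stack, tensor_get_data
ops = StaticTensorArrayOps(p, "float32", shape)
ops.register()
tensor_ctor = p.get_var_static('tensor_constructor', "float32", shape)
ta = p.map(tensor_ctor, lst)
stack = p.get_var_static('tensor_array_stack', "float32", shape)
ops.define_tensor_get_data((Any(),) + shape)
get_data = p.get_var_static('tensor_get_data', "float32", shape)
stacked = get_data(stack(ta))    # Tensor[(?, 2, 3), float32]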
diff --git a/tests/python/frontend/pytorch/lstm_test.py b/tests/python/frontend/pytorch/lstm_test.py
new file mode 100644
index 0000000..4616698
--- /dev/null
+++ b/tests/python/frontend/pytorch/lstm_test.py
@@ -0,0 +1,335 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Tests on torch lstm model conversion """
+# originally from https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
+# described in https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import Parameter
+import torch.jit as jit
+from typing import List, Tuple
+from torch import Tensor
+
+import tvm
+from tvm import relay
+from tvm.relay.frontend.pytorch import from_pytorch
+from tvm.relay.prelude import Prelude
+from tvm.runtime.container import ADT, tuple_object
+
+
+class LayerNormLSTMCell(jit.ScriptModule):
+ def __init__(self, input_size, hidden_size):
+ super().__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
+ self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
+
+ ln = nn.LayerNorm
+
+ self.layernorm_i = ln(4 * hidden_size)
+ self.layernorm_h = ln(4 * hidden_size)
+ self.layernorm_c = ln(hidden_size)
+
+ @jit.script_method
+ def forward(self, input, state):
+ # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+ hx, cx = state
+ igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
+ hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
+ gates = igates + hgates
+ ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+ ingate = torch.sigmoid(ingate)
+ forgetgate = torch.sigmoid(forgetgate)
+ cellgate = torch.tanh(cellgate)
+ outgate = torch.sigmoid(outgate)
+
+ cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
+ hy = outgate * torch.tanh(cy)
+
+ return hy, (hy, cy)
+
+
+class LSTMLayer(jit.ScriptModule):
+ def __init__(self, cell, *cell_args):
+ super().__init__()
+ self.cell = cell(*cell_args)
+
+ @jit.script_method
+ def forward(self, input, state):
+ # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+ outputs = []
+ for i in range(input.size(0)):
+ out, state = self.cell(input[i], state)
+ outputs += [out]
+ return torch.stack(outputs), state
+
+
+class ReverseLSTMLayer(jit.ScriptModule):
+ def __init__(self, cell, *cell_args):
+ super(ReverseLSTMLayer, self).__init__()
+ self.cell = cell(*cell_args)
+
+ @jit.script_method
+ def forward(self, inputs, state):
+ # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+ outputs = jit.annotate(List[Tensor], [])
+ seq_len = inputs.size(0)
+ for i in range(seq_len):
+ out, state = self.cell(inputs[seq_len - i - 1], state)
+ # workaround for the lack of list rev support
+ outputs = [out] + outputs
+ return torch.stack(outputs), state
+
+
+class BidirLSTMLayer(jit.ScriptModule):
+ __constants__ = ['directions']
+
+ def __init__(self, cell, *cell_args):
+ super(BidirLSTMLayer, self).__init__()
+ self.directions = nn.ModuleList([
+ LSTMLayer(cell, *cell_args),
+ ReverseLSTMLayer(cell, *cell_args),
+ ])
+
+ @jit.script_method
+ def forward(self, input, states):
+ # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+ # List[LSTMState]: [forward LSTMState, backward LSTMState]
+ outputs = jit.annotate(List[Tensor], [])
+ output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+ for (i, direction) in enumerate(self.directions):
+ state = states[i]
+ out, out_state = direction(input, state)
+ outputs += [out]
+ output_states += [out_state]
+ # tensor array concat assumes axis == 0 for now
+ # return torch.cat(outputs, -1), output_states
+ return torch.cat(outputs, 0), output_states
+
+
+def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
+ layers = [layer(*first_layer_args)] + [layer(*other_layer_args)
+ for _ in range(num_layers - 1)]
+ return nn.ModuleList(layers)
+
+
+class StackedLSTM(jit.ScriptModule):
+ __constants__ = ['layers'] # Necessary for iterating through self.layers
+
+ def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
+ super().__init__()
+ self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
+ other_layer_args)
+
+ @jit.script_method
+ def forward(self, input, states):
+ # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+ # List[LSTMState]: One state per layer
+ output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+ output = input
+ for (i, rnn_layer) in enumerate(self.layers):
+ state = states[i]
+ output, out_state = rnn_layer(output, state)
+ output_states += [out_state]
+ return output, output_states
+
+
+class StackedBidirLSTM(jit.ScriptModule):
+ __constants__ = ['layers'] # Necessary for iterating through self.layers
+
+ def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
+ super(StackedBidirLSTM, self).__init__()
+ self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
+ other_layer_args)
+
+ @jit.script_method
+ def forward(self, input, states):
+ # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
+ # List[List[LSTMState]]: The outer list is for layers,
+ # inner list is for directions.
+ output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
+ output = input
+ for (i, rnn_layer) in enumerate(self.layers):
+ state = states[i]
+ output, out_state = rnn_layer(output, state)
+ output_states += [out_state]
+ return output, output_states
+
+
+def lstm(input_size, hidden_size):
+ return LSTMLayer(LayerNormLSTMCell, input_size, hidden_size)
+
+
+def stacked_lstm(input_size, hidden_size, num_layers):
+ return StackedLSTM(num_layers, LSTMLayer,
+ first_layer_args=[LayerNormLSTMCell, input_size, hidden_size],
+ other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size])
+
+
+def bidir_lstm(input_size, hidden_size):
+ return BidirLSTMLayer(LayerNormLSTMCell, input_size, hidden_size)
+
+
+def stacked_bidir_lstm(input_size, hidden_size, num_layers):
+ return StackedBidirLSTM(num_layers, BidirLSTMLayer,
+ first_layer_args=[LayerNormLSTMCell, input_size, hidden_size],
+ other_layer_args=[LayerNormLSTMCell, hidden_size, hidden_size])
+
+
+def vmobj_to_list(o, dtype="float32"):
+ if isinstance(o, tvm.nd.NDArray):
+ return [o]
+ elif isinstance(o, tvm.runtime.container.ADT):
+ result = []
+ for f in o:
+ result.extend(vmobj_to_list(f, dtype))
+ return result
+ else:
+ raise RuntimeError("Unknown object type: %s" % type(o))
+
+
+def assert_equal(tvm_result, torch_result):
+ if isinstance(torch_result, (tuple, list)):
+ assert isinstance(tvm_result, list)
+ for tvm_res, pt_res in zip(tvm_result, torch_result):
+ assert_equal(tvm_res, pt_res)
+ elif isinstance(torch_result, torch.Tensor):
+ tvm.testing.assert_allclose(tvm_result.asnumpy(), torch_result.numpy(),
+ rtol=1e-4, atol=1e-4)
+
+
+def run_and_compare(mod, params, pt_result):
+ executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
+ evaluator = executor.evaluate()
+ exec_res = evaluator(**params)
+
+ def flatten(nested):
+ res = []
+ for r in nested:
+ if isinstance(r, torch.Tensor):
+ res.append(r)
+ else:
+ res.extend(flatten(r))
+ return res
+
+ if isinstance(exec_res, tvm.runtime.container.ADT):
+ assert not isinstance(pt_result, torch.Tensor)
+ tvm_res = vmobj_to_list(exec_res)
+ torch_res = flatten(pt_result)
+ else:
+ tvm_res = exec_res
+ torch_res = pt_result
+
+ assert_equal(tvm_res, torch_res)
+
+
+def convert_list_to_vmobj(py_lst):
+ def wrap_nd_array(arr):
+ return tvm.nd.array(arr, ctx=tvm.cpu(0))
+
+ mod = tvm.IRModule()
+ prelude = Prelude(mod)
+ adt_lst = ADT(prelude.nil.tag, [])
+ for elem in reversed(py_lst):
+ if isinstance(elem, np.ndarray):
+ vmobj = wrap_nd_array(elem)
+ elif isinstance(elem, tuple):
+ vmobj = tuple_object([wrap_nd_array(e) for e in elem])
+ elif isinstance(elem, list):
+ vmobj = convert_list_to_vmobj(elem)
+ adt_lst = ADT(prelude.cons.tag, [vmobj, adt_lst])
+ return adt_lst
+
+
+def custom_lstm_test():
+ input_name = "input"
+ states_name = "states"
+ seq_len = 5
+ batch = 2
+ input_size = 3
+ hidden_size = 4
+ num_layers = 3
+ state_tensor_shape = (batch, hidden_size)
+
+ inp = torch.randn(seq_len, batch, input_size)
+
+ input_shapes = [(input_name, (seq_len, batch, input_size)),
+ (states_name, (state_tensor_shape, state_tensor_shape))]
+
+ input_shapes_stacked = [(input_name, (seq_len, batch, input_size)),
+ (states_name, [(state_tensor_shape, state_tensor_shape),
+ (state_tensor_shape, state_tensor_shape)])]
+
+ input_shapes_stacked_bidir = [(input_name, (seq_len, batch, input_size)),
+ (states_name, [[(state_tensor_shape,
+ state_tensor_shape)
+ for _ in range(2)]
+ for _ in range(num_layers)])]
+
+ states = [(torch.randn(state_tensor_shape),
+ torch.randn(state_tensor_shape))
+ for _ in range(num_layers)]
+
+ bidir_states = [(torch.randn(state_tensor_shape),
+ torch.randn(state_tensor_shape))
+ for _ in range(2)]
+
+ stacked_bidir_states = [[(torch.randn(state_tensor_shape),
+ torch.randn(state_tensor_shape))
+ for _ in range(2)]
+ for _ in range(num_layers)]
+
+ models = [
+ (lstm(input_size, hidden_size).eval(), states[0], input_shapes),
+ (stacked_lstm(input_size, hidden_size, num_layers).eval(), states, input_shapes_stacked),
+ (bidir_lstm(input_size, hidden_size).eval(), bidir_states, input_shapes_stacked),
+ (stacked_bidir_lstm(input_size, hidden_size, num_layers).eval(),
+ stacked_bidir_states, input_shapes_stacked_bidir)
+ ]
+
+ for (raw_model, states, input_shapes) in models:
+ script_module = torch.jit.script(raw_model)
+ mod, params = from_pytorch(script_module, input_shapes)
+
+ with torch.no_grad():
+ pt_result = raw_model(inp.clone(), states)
+
+ params[input_name] = inp.numpy()
+
+ if isinstance(states, tuple):
+ states_np = tuple(st.numpy() for st in states)
+ elif isinstance(states, list) and isinstance(states[0], torch.Tensor):
+ states_np = [st.numpy() for st in states]
+ elif isinstance(states, list) and isinstance(states[0], tuple):
+ states_np = [tuple(st.numpy() for st in states[i])
+ for i in range(len(states))]
+ elif isinstance(states, list) and isinstance(states[0], list):
+ states_np = [[tuple(st.numpy() for st in states)
+ for states in states[layer]]
+ for layer in range(num_layers)]
+ else:
+ assert False
+
+ if isinstance(states_np, list):
+ params[states_name] = convert_list_to_vmobj(states_np)
+ else:
+ params[states_name] = states_np
+
+ run_and_compare(mod, params, pt_result)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index d60ab9e..8e99285 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -526,13 +526,13 @@ def test_forward_maxpool2d():
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.MaxPool2d(kernel_size=[1, 1]).eval(),
- input_data)
+ input_data)
verify_model(torch.nn.MaxPool2d(kernel_size=[10, 10]).eval(),
- input_data)
+ input_data)
verify_model(torch.nn.MaxPool2d(kernel_size=[4, 4],
padding=2,
stride=2).eval(),
- input_data)
+ input_data)
def test_forward_maxpool1d():
torch.set_grad_enabled(False)
@@ -540,13 +540,13 @@ def test_forward_maxpool1d():
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.MaxPool1d(kernel_size=1).eval(),
- input_data)
+ input_data)
verify_model(torch.nn.MaxPool1d(kernel_size=10).eval(),
- input_data)
- verify_model( torch.nn.MaxPool1d(kernel_size=4,
+ input_data)
+ verify_model(torch.nn.MaxPool1d(kernel_size=4,
padding=2,
stride=2).eval(),
- input_data)
+ input_data)
def test_forward_maxpool3d():
torch.set_grad_enabled(False)
@@ -554,13 +554,13 @@ def test_forward_maxpool3d():
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(),
- input_data)
+ input_data)
verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(),
- input_data)
+ input_data)
verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4],
padding=2,
stride=2).eval(),
- input_data)
+ input_data)
def test_forward_split():
torch.set_grad_enabled(False)
@@ -577,13 +577,13 @@ def test_forward_split():
input_data = torch.rand(input_shape).float()
verify_model(Split(2, 0).float().eval(),
- input_data=input_data)
+ input_data=input_data)
verify_model(Split(3, 1).float().eval(),
- input_data=input_data)
+ input_data=input_data)
verify_model(Split(4, 1).float().eval(),
- input_data=input_data)
+ input_data=input_data)
verify_model(Split([2, 3, 5], 1).float().eval(),
- input_data=input_data)
+ input_data=input_data)
def test_forward_avgpool():
torch.set_grad_enabled(False)
@@ -1363,3 +1363,8 @@ if __name__ == "__main__":
# Test simple conditionals and loop
test_control_flow()
test_simple_rnn()
+
+ # More complex recurrent models
+ from lstm_test import custom_lstm_test
+
+ custom_lstm_test()