You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/07 06:37:54 UTC
[tvm] branch main updated: [Pytorch] add aten::rnn_tanh, aten::rnn_relu (#12017)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 40d242a3c8 [Pytorch] add aten::rnn_tanh, aten::rnn_relu (#12017)
40d242a3c8 is described below
commit 40d242a3c8f9630223e5775c1f1bf23362c8850e
Author: yuanfz <42...@users.noreply.github.com>
AuthorDate: Thu Jul 7 08:37:48 2022 +0200
[Pytorch] add aten::rnn_tanh, aten::rnn_relu (#12017)
* emptycommit 2nd try
* dev
* comments
* format
* format
Co-authored-by: yuanfz <42...@users.noreply.github.com>
---
python/tvm/relay/frontend/common.py | 40 ++++++
python/tvm/relay/frontend/pytorch.py | 189 ++++++++++++++++++++++++++++-
tests/python/frontend/pytorch/test_rnns.py | 79 ++++++++++++
3 files changed, 307 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 7a1e984029..5f961f1ae0 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -686,6 +686,46 @@ def unbind(data, axis=0):
return _expr.TupleWrapper(_expr.Tuple(ret), selections)
+def rnn_cell(
+ input_seqs, hidden_state, w_inp, w_hid, b_inp=None, b_hid=None, backwards=False, act=_op.tanh
+):
+ """
+ Common implementation of RNN cell for all frontends of TVM
+
+ Parameters
+ ----------
+ input_seqs : List[relay.Expr]
+ The sequence of input tensors
+ Input tensor should be 2d while issue #8412 is not resolved
+ Shape = (batch, feature_size)
+ hidden_state : relay.Expr
+ Hidden state. shape = (batch_size, hidden_size)
+ w_inp, w_hid: relay.Expr
+ weight matrices. shape = (hidden_size, feature_size), (hidden_size, hidden_size)
+ b_inp, b_hid : relay.Expr
+ bias matrices. The same order of internal parts as for weights. shape = (1 * hidden_size)
+ backwards : bool
+ Flag for reverse pass of RNN
+ act : relay.op
+ activation function. It is tanh by default.
+
+ Returns
+ -------
+ result : List[relay.Expr], relay.Expr
+ The sequence of computed results and the final hidden state
+ """
+ outputs_list = []
+ for x_t in input_seqs if not backwards else reversed(input_seqs):
+ xwt = _op.nn.dense(x_t, w_inp)
+ hwt = _op.nn.dense(hidden_state, w_hid)
+ if b_inp is not None and b_hid is not None:
+ xwt += b_inp
+ hwt += b_hid
+ hidden_state = act(xwt + hwt)
+ outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)]
+ return outputs_list, hidden_state
+
+
def gru_cell(
input_seqs,
hidden_state,
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b1a7608860..d7e1a5dd1d 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -40,7 +40,7 @@ from ..loops import while_loop
from ..prelude import Prelude, StaticTensorArrayOps
from ..ty import Any, TensorType, TupleType
from . import qnn_torch
-from .common import AttrCvt, get_relay_op, gru_cell, logger
+from .common import AttrCvt, get_relay_op, gru_cell, logger, rnn_cell
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
@@ -2630,6 +2630,191 @@ class PyTorchOpConverter:
axis = inputs[1]
return _op.transform.reverse(data, axis=axis[0])
+ def bidir_rnn_cell(self, input_seqs, weights_dicts, act=_op.tanh):
+ """
+ Bidirectional RNN cell
+ """
+ seq_len = len(input_seqs)
+ forward_outputs, fw_H_t = rnn_cell(input_seqs, **weights_dicts[0], backwards=False, act=act)
+
+ reverse_outputs, rev_H_t = rnn_cell(input_seqs, **weights_dicts[1], backwards=True, act=act)
+
+ final_outputs = []
+ for i in range(seq_len):
+ final_outputs.append(
+ _op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1)
+ )
+
+ return final_outputs, _op.stack([fw_H_t, rev_H_t], axis=0)
+
+ def rnn_layers(self, input_data, layer_weights_dicts, bidirectional, act, dropout_p=0.0):
+ """
+ Method that iterates over layers for stacked RNN
+ """
+ layers_num = len(layer_weights_dicts)
+ # split input sequence to samples set
+ input_seqs = unbind(input_data, 0) # [seq_num, (batch, feature_size)]
+ output_hiddens = []
+ for i in range(layers_num):
+ weights_dicts = layer_weights_dicts[i]
+ # input_seqs shape = [seq_num, (batch, feature_size)] or
+ # [seq_num, (batch, 2*feature_size)] for bidirectional
+ if bidirectional:
+ input_seqs, H_t = self.bidir_rnn_cell(input_seqs, weights_dicts, act=act)
+ else:
+ input_seqs, H_t = rnn_cell(input_seqs, **weights_dicts[0], act=act)
+
+ output_hiddens.append(H_t)
+
+ # TODO (yuanfz98): in pytorch implementation train is also checked
+ # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
+ # /aten/src/ATen/native/RNN.cpp#L1054
+ if dropout_p != 0 and i < layers_num - 1:
+ # for input in input_seqs:
+ # input = _op.dropout(input, dropout_p)
+ raise NotImplementedError("Dropout for RNN has not been supported yet!")
+ output_hiddens = (
+ _op.concatenate(output_hiddens, 0) if bidirectional else _op.stack(output_hiddens, 0)
+ )
+ return _op.stack(input_seqs, 0), output_hiddens
+
+ def rnn(self, inputs, input_types, nonlinearity):
+ """
+ Description of RNN in pytorch:
+ https://pytorch.org/docs/stable/generated/torch.nn.RNN.html#torch.nn.RNN
+ Description of inputs:
+ https://github.com/pytorch/pytorch/blob/736fb7d22cc948b739db2c35aeb5ad4d19aea4f4/torch/overrides.py#L937
+ """
+ # TODO (yuanfz98): support dropout
+ assert len(inputs) == 9, "Input of size 9 is expected"
+ # Unpack inputs, note that if optional and not provided then value will be None.
+ _X = inputs[0]
+ # _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)
+
+ hidden_state = inputs[1]
+ # Hidden state shape (hidden_layers_num, batch, hidden_size)
+
+ _weights = inputs[2]
+ # Wi layer[0] shape (hidden_size, feature_size)
+ # Wh layer[0] shape (hidden_size, hidden_size)
+ # Bi layer[0] shape (hidden_size)
+ # Bh layer[0] shape (hidden_size)
+
+ # Wi layer[>0] shape (hidden_size, hidden_size * num_directions)
+ # Wh layer[>0] shape (hidden_size, hidden_size)
+ # Bi layer[>0] shape (hidden_size)
+ # Bh layer[>0] shape (hidden_size)
+
+ # Scalar inputs
+ has_biases = inputs[3]
+ num_layers = inputs[4]
+ dropout_p = inputs[5] # dropout probability, if 0.0 it means there is no dropout
+ # train = inputs[6]
+ bidirectional = inputs[7]
+ batch_first = inputs[8]
+
+ num_directions = 1
+ if bidirectional:
+ num_directions = 2
+
+ rsd = len(_weights) % num_layers
+ assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
+ rsd = (len(_weights) / num_layers) % num_directions
+ assert (
+ rsd == 0
+ ), "The number of weights in layer must be a multiple of the number of directions!"
+
+ weights_num = int(len(_weights) / num_layers / num_directions)
+ if has_biases:
+ assert weights_num == 4, "The weights number in layer is expected equal to 4"
+ else:
+ assert weights_num == 2, "The weights number in layer is expected equal to 2"
+ if nonlinearity == "tanh":
+ act = _op.tanh
+ elif nonlinearity == "relu":
+ act = _op.nn.relu
+ assert act, "The nonlinearity is unknown"
+ X = (
+ _op.transpose(_X, (1, 0, 2)) if batch_first else _X
+ ) # always (seq_num, batch, feature_size)
+ # TODO (yuanfz98): Which data type should be used? from input or weights?
+ # Instead of it _infer_type(X).checked_type.dtype can be used
+ X_dtype = input_types[0]
+ X_shape = _infer_shape(X) # (seq_num, batch, feature_size)
+
+ hidden_size = int(_infer_shape(_weights[0])[0])
+ batch_size = X_shape[1]
+
+ # Initialize hidden states if not provided.
+ layers_h = []
+ hidden_layers_num = num_directions * num_layers
+ if hidden_state is None:
+ h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
+ for i in range(hidden_layers_num):
+ layers_h.append(h_0)
+ else:
+ layers_h = unbind(hidden_state, 0)
+
+ layer_weights_dicts = []
+ k = 0 # layer counter
+ if has_biases:
+ names = ["hidden_state", "w_inp", "w_hid", "b_inp", "b_hid"]
+ if bidirectional:
+ rsd = len(_weights) % (2 * weights_num)
+ assert rsd == 0, "got an incorrect number of RNN weights"
+ for i in range(0, len(_weights), 2 * weights_num):
+ fw_tensors = [layers_h[2 * k], *_weights[i : i + 4]]
+ fw_weights_dict = dict(zip(names, fw_tensors))
+ j = i + weights_num
+ rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 4]]
+ rev_weights_dict = dict(zip(names, rev_tensors))
+ layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
+ k += 1
+ else:
+ assert len(_weights) % weights_num == 0, "got an incorrect number of RNN weights"
+ for i in range(0, len(_weights), weights_num):
+ fw_tensors = [layers_h[k], *_weights[i : i + 4]]
+ fw_weights_dict = dict(zip(names, fw_tensors))
+ layer_weights_dicts.append([fw_weights_dict])
+ k += 1
+ else:
+ names = ["hidden_state", "w_inp", "w_hid"]
+ if bidirectional:
+ rsd = len(_weights) % (2 * weights_num)
+ assert rsd == 0, "got an incorrect number of RNN weights"
+ for i in range(0, len(_weights), 2 * weights_num):
+ fw_tensors = [layers_h[2 * k], *_weights[i : i + 2]]
+ fw_weights_dict = dict(zip(names, fw_tensors))
+ j = i + weights_num
+ rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 2]]
+ rev_weights_dict = dict(zip(names, rev_tensors))
+ layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
+ k += 1
+ else:
+ assert len(_weights) % weights_num == 0, "got an incorrect number of RNN weights"
+ for i in range(0, len(_weights), weights_num):
+ fw_tensors = [layers_h[k], *_weights[i : i + 2]]
+ fw_weights_dict = dict(zip(names, fw_tensors))
+ layer_weights_dicts.append([fw_weights_dict])
+ k += 1
+ assert (
+ len(layer_weights_dicts) == num_layers and k == num_layers
+ ), "For stacked RNN number of weights sets should be the same as number of layers!"
+ output, out_hidden_state = self.rnn_layers(
+ X,
+ layer_weights_dicts,
+ bidirectional,
+ act,
+ dropout_p=dropout_p,
+ )
+
+ # output shape = (seq_num, batch, hidden_size) or
+ # (seq_num, batch, 2*feature_size) for bidirectional
+ if batch_first:
+ output = _op.transpose(output, (1, 0, 2))
+
+ return (output, out_hidden_state)
+
def bidir_gru_cell(
self,
input_seqs,
@@ -3442,6 +3627,8 @@ class PyTorchOpConverter:
"aten::l1_loss": self.l1_loss,
"aten::mse_loss": self.mse_loss,
"aten::flip": self.flip,
+ "aten::rnn_tanh": functools.partial(self.rnn, nonlinearity="tanh"),
+ "aten::rnn_relu": functools.partial(self.rnn, nonlinearity="relu"),
"aten::gru": self.gru,
"aten::lstm": self.lstm,
"aten::all": functools.partial(self.all_any_common, _op.all),
diff --git a/tests/python/frontend/pytorch/test_rnns.py b/tests/python/frontend/pytorch/test_rnns.py
index b0180a7a99..fba55b9c4c 100644
--- a/tests/python/frontend/pytorch/test_rnns.py
+++ b/tests/python/frontend/pytorch/test_rnns.py
@@ -40,6 +40,10 @@ num_layers = 2
seqs_length = 2
batch_size = 2
+##RNN parameters
+rnn_feature_size = 8
+rnn_hidden_size = 16
+
class RNN_Model(nn.Module):
"""
@@ -93,6 +97,72 @@ class RNN_Model(nn.Module):
raise NotImplementedError("subclasses must override get_tvm_inputs(dtype)!")
+class RNN_Model_Impl(RNN_Model):
+ def __init__(
+ self,
+ seq_len=seqs_length,
+ batch_size=batch_size,
+ feature_size=rnn_feature_size,
+ hidden_size=rnn_hidden_size,
+ batch_first=False,
+ layer_num=1,
+ bidirectional=False,
+ use_bias=True,
+ rnd_weights_init=False,
+ nonlinearity="tanh",
+ dropout=0.0,
+ ):
+ super().__init__()
+ # Shapes
+ self.shape = [seq_len, batch_size, feature_size]
+ if batch_first:
+ self.shape = [batch_size, seq_len, feature_size]
+ layers_num = 2 * layer_num if bidirectional else layer_num
+ self.h0_shape = [layers_num, batch_size, hidden_size]
+ # Dummy inputs
+ self.dummy_inputs = (torch.rand(self.shape), torch.zeros(self.h0_shape))
+
+ self.model = nn.RNN(
+ input_size=feature_size,
+ hidden_size=hidden_size,
+ num_layers=layer_num,
+ nonlinearity=nonlinearity,
+ bias=use_bias,
+ batch_first=batch_first,
+ dropout=dropout,
+ bidirectional=bidirectional,
+ )
+
+ if rnd_weights_init:
+ self.gen_rnd_weights()
+
+ def gen_rnd_weights(self):
+ super().gen_rnd_weights()
+
+ def get_dummy_inputs(self):
+ return self.dummy_inputs
+
+ def get_input_names(self):
+ return ["input", "h0"]
+
+ def get_shape_desc(self, frontend_type):
+ shape_desc = None
+ if frontend_type == "pt": # PyTorch
+ shape_desc = [("input", self.shape)]
+ elif frontend_type == "onnx": # ONNX
+ shape_desc = {
+ "input": self.shape,
+ "h0": self.h0_shape,
+ }
+ return shape_desc
+
+ def get_tvm_inputs(self, dtype):
+ return {
+ "input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)),
+ "h0": tvm.nd.array(self.dummy_inputs[1].numpy().astype(dtype)),
+ }
+
+
class GRU_Model(RNN_Model):
def __init__(
self,
@@ -331,6 +401,10 @@ def check_rnn(rnn_type, rnn_mod, target=tvm.target.Target("llvm -mcpu=core-avx2"
args["bidirectional"] = True
if "s" in rnn_mod:
args["layer_num"] = num_layers
+ if "tanh" in rnn_mod:
+ args["nonlinearity"] = "tanh"
+ if "relu" in rnn_mod:
+ args["nonlinearity"] = "relu"
if rnn_type == "GRU":
RNN_Model_selector = GRU_Model
@@ -338,6 +412,8 @@ def check_rnn(rnn_type, rnn_mod, target=tvm.target.Target("llvm -mcpu=core-avx2"
RNN_Model_selector = LSTM_Model
if "p" in rnn_mod:
args["proj_size"] = lstm_projection_size
+ elif rnn_type == "RNN":
+ RNN_Model_selector = RNN_Model_Impl
return RNN_Model_selector(**args)
@@ -425,6 +501,9 @@ def test_rnns():
for mod_type in ["uni", "s", "b", "sb"]:
check_rnn("LSTM", mod_type, target, dev)
+ for mod_type in ["uni", "s", "b", "sb", "tanh", "relu"]:
+ check_rnn("RNN", mod_type, target, dev)
+
if __name__ == "__main__":
test_rnns()