You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/07 06:37:54 UTC

[tvm] branch main updated: [Pytorch] add aten::rnn_tanh, aten::rnn_relu (#12017)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 40d242a3c8 [Pytorch] add aten::rnn_tanh, aten::rnn_relu (#12017)
40d242a3c8 is described below

commit 40d242a3c8f9630223e5775c1f1bf23362c8850e
Author: yuanfz <42...@users.noreply.github.com>
AuthorDate: Thu Jul 7 08:37:48 2022 +0200

    [Pytorch] add aten::rnn_tanh, aten::rnn_relu (#12017)
    
    * emptycommit 2nd try
    
    * dev
    
    * comments
    
    * format
    
    * format
    
    Co-authored-by: yuanfz <42...@users.noreply.github.com>
---
 python/tvm/relay/frontend/common.py        |  40 ++++++
 python/tvm/relay/frontend/pytorch.py       | 189 ++++++++++++++++++++++++++++-
 tests/python/frontend/pytorch/test_rnns.py |  79 ++++++++++++
 3 files changed, 307 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 7a1e984029..5f961f1ae0 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -686,6 +686,46 @@ def unbind(data, axis=0):
     return _expr.TupleWrapper(_expr.Tuple(ret), selections)
 
 
+def rnn_cell(
+    input_seqs, hidden_state, w_inp, w_hid, b_inp=None, b_hid=None, backwards=False, act=_op.tanh
+):
+    """
+    Common implementation of RNN cell for all frontends of TVM
+
+    Parameters
+    ----------
+    input_seqs : List[relay.Expr]
+        The sequence of input tensors
+        Input tensor should be 2d while issue #8412 is not resolved
+        Shape = (batch, feature_size)
+    hidden_state : relay.Expr
+        Hidden state. shape = (batch_size, hidden_size)
+    w_inp, w_hid: relay.Expr
+        weight matrices. shape = (hidden_size, feature_size), (hidden_size, hidden_size)
+    b_inp, b_hid : relay.Expr
+        bias matrices. The same order of internal parts as for weights. shape = (1 * hidden_size)
+    backwards : bool
+        Flag for reverse pass of RNN
+    act : relay.op
+        activation function. It is tanh by default.
+
+    Returns
+    -------
+    result : List[relay.Expr], relay.Expr
+        The sequence of computed results and the final hidden state
+    """
+    outputs_list = []
+    for x_t in input_seqs if not backwards else reversed(input_seqs):
+        xwt = _op.nn.dense(x_t, w_inp)
+        hwt = _op.nn.dense(hidden_state, w_hid)
+        if b_inp is not None and b_hid is not None:
+            xwt += b_inp
+            hwt += b_hid
+        hidden_state = act(xwt + hwt)
+        outputs_list.append(hidden_state)  # [seq_num, (batch, hidden_size)]
+    return outputs_list, hidden_state
+
+
 def gru_cell(
     input_seqs,
     hidden_state,
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b1a7608860..d7e1a5dd1d 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -40,7 +40,7 @@ from ..loops import while_loop
 from ..prelude import Prelude, StaticTensorArrayOps
 from ..ty import Any, TensorType, TupleType
 from . import qnn_torch
-from .common import AttrCvt, get_relay_op, gru_cell, logger
+from .common import AttrCvt, get_relay_op, gru_cell, logger, rnn_cell
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
 from .common import infer_value_simulated as _infer_value_simulated
@@ -2630,6 +2630,191 @@ class PyTorchOpConverter:
         axis = inputs[1]
         return _op.transform.reverse(data, axis=axis[0])
 
+    def bidir_rnn_cell(self, input_seqs, weights_dicts, act=_op.tanh):
+        """
+        Bidirectional RNN cell
+        """
+        seq_len = len(input_seqs)
+        forward_outputs, fw_H_t = rnn_cell(input_seqs, **weights_dicts[0], backwards=False, act=act)
+
+        reverse_outputs, rev_H_t = rnn_cell(input_seqs, **weights_dicts[1], backwards=True, act=act)
+
+        final_outputs = []
+        for i in range(seq_len):
+            final_outputs.append(
+                _op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1)
+            )
+
+        return final_outputs, _op.stack([fw_H_t, rev_H_t], axis=0)
+
+    def rnn_layers(self, input_data, layer_weights_dicts, bidirectional, act, dropout_p=0.0):
+        """
+        Iterates over the layers of a stacked RNN
+        """
+        layers_num = len(layer_weights_dicts)
+        # split input sequence to samples set
+        input_seqs = unbind(input_data, 0)  # [seq_num, (batch, feature_size)]
+        output_hiddens = []
+        for i in range(layers_num):
+            weights_dicts = layer_weights_dicts[i]
+            # input_seqs shape = [seq_num, (batch, feature_size)] for the first layer, then
+            # [seq_num, (batch, hidden_size * num_directions)] for subsequent layers
+            if bidirectional:
+                input_seqs, H_t = self.bidir_rnn_cell(input_seqs, weights_dicts, act=act)
+            else:
+                input_seqs, H_t = rnn_cell(input_seqs, **weights_dicts[0], act=act)
+
+            output_hiddens.append(H_t)
+
+            # TODO (yuanfz98): in pytorch implementation train is also checked
+            # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
+            # /aten/src/ATen/native/RNN.cpp#L1054
+            if dropout_p != 0 and i < layers_num - 1:
+                # for input in input_seqs:
+                #     input = _op.dropout(input, dropout_p)
+                raise NotImplementedError("Dropout for GRU has not been supported yet!")
+        output_hiddens = (
+            _op.concatenate(output_hiddens, 0) if bidirectional else _op.stack(output_hiddens, 0)
+        )
+        return _op.stack(input_seqs, 0), output_hiddens
+
+    def rnn(self, inputs, input_types, nonlinearity):
+        """
+        Description of RNN in pytorch:
+        https://pytorch.org/docs/stable/generated/torch.nn.RNN.html#torch.nn.RNN
+        Description of inputs:
+        https://github.com/pytorch/pytorch/blob/736fb7d22cc948b739db2c35aeb5ad4d19aea4f4/torch/overrides.py#L937
+        """
+        # TODO (yuanfz98): support dropout
+        assert len(inputs) == 9, "Input of size 9 is expected"
+        # Unpack inputs, note that if optional and not provided then value will be None.
+        _X = inputs[0]
+        # _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)
+
+        hidden_state = inputs[1]
+        # Hidden state shape (hidden_layers_num, batch, hidden_size)
+
+        _weights = inputs[2]
+        # Wi layer[0] shape (hidden_size, feature_size)
+        # Wh layer[0] shape (hidden_size, hidden_size)
+        # Bi layer[0] shape (hidden_size)
+        # Bh layer[0] shape (hidden_size)
+
+        # Wi layer[>0] shape (hidden_size, hidden_size * num_directions)
+        # Wh layer[>0] shape (hidden_size, hidden_size)
+        # Bi layer[>0] shape (hidden_size)
+        # Bh layer[>0] shape (hidden_size)
+
+        # Scalar inputs
+        has_biases = inputs[3]
+        num_layers = inputs[4]
+        dropout_p = inputs[5]  # dropout probability, if 0.0 it means there is no dropout
+        # train = inputs[6]
+        bidirectional = inputs[7]
+        batch_first = inputs[8]
+
+        num_directions = 1
+        if bidirectional:
+            num_directions = 2
+
+        rsd = len(_weights) % num_layers
+        assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
+        rsd = (len(_weights) / num_layers) % num_directions
+        assert (
+            rsd == 0
+        ), "The number of weights in layer must be a multiple of the number of directions!"
+
+        weights_num = int(len(_weights) / num_layers / num_directions)
+        if has_biases:
+            assert weights_num == 4, "The weights number in layer is expected equal to 4"
+        else:
+            assert weights_num == 2, "The weights number in layer is expected equal to 2"
+        if nonlinearity == "tanh":
+            act = _op.tanh
+        elif nonlinearity == "relu":
+            act = _op.nn.relu
+        assert act, "The nonlinearity is unknown"
+        X = (
+            _op.transpose(_X, (1, 0, 2)) if batch_first else _X
+        )  # always (seq_num, batch, feature_size)
+        # TODO (yuanfz98): Which data type should be used? from input or weights?
+        # Instead of it _infer_type(X).checked_type.dtype can be used
+        X_dtype = input_types[0]
+        X_shape = _infer_shape(X)  # (seq_num, batch, feature_size)
+
+        hidden_size = int(_infer_shape(_weights[0])[0])
+        batch_size = X_shape[1]
+
+        # Initialize hidden states if not provided.
+        layers_h = []
+        hidden_layers_num = num_directions * num_layers
+        if hidden_state is None:
+            h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
+            for i in range(hidden_layers_num):
+                layers_h.append(h_0)
+        else:
+            layers_h = unbind(hidden_state, 0)
+
+        layer_weights_dicts = []
+        k = 0  # layer counter
+        if has_biases:
+            names = ["hidden_state", "w_inp", "w_hid", "b_inp", "b_hid"]
+            if bidirectional:
+                rsd = len(_weights) % (2 * weights_num)
+                assert rsd == 0, "got an incorrect number of RNN weights"
+                for i in range(0, len(_weights), 2 * weights_num):
+                    fw_tensors = [layers_h[2 * k], *_weights[i : i + 4]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    j = i + weights_num
+                    rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 4]]
+                    rev_weights_dict = dict(zip(names, rev_tensors))
+                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
+                    k += 1
+            else:
+                assert len(_weights) % weights_num == 0, "got an incorrect number of GRU weights"
+                for i in range(0, len(_weights), weights_num):
+                    fw_tensors = [layers_h[k], *_weights[i : i + 4]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    layer_weights_dicts.append([fw_weights_dict])
+                    k += 1
+        else:
+            names = ["hidden_state", "w_inp", "w_hid"]
+            if bidirectional:
+                rsd = len(_weights) % (2 * weights_num)
+                assert rsd == 0, "got an incorrect number of RNN weights"
+                for i in range(0, len(_weights), 2 * weights_num):
+                    fw_tensors = [layers_h[2 * k], *_weights[i : i + 2]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    j = i + weights_num
+                    rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 2]]
+                    rev_weights_dict = dict(zip(names, rev_tensors))
+                    layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
+                    k += 1
+            else:
+                assert len(_weights) % weights_num == 0, "got an incorrect number of RNN weights"
+                for i in range(0, len(_weights), weights_num):
+                    fw_tensors = [layers_h[k], *_weights[i : i + 2]]
+                    fw_weights_dict = dict(zip(names, fw_tensors))
+                    layer_weights_dicts.append([fw_weights_dict])
+                    k += 1
+        assert (
+            len(layer_weights_dicts) == num_layers and k == num_layers
+        ), "For stacked RNN number of weights sets should be the same as number of layers!"
+        output, out_hidden_state = self.rnn_layers(
+            X,
+            layer_weights_dicts,
+            bidirectional,
+            act,
+            dropout_p=dropout_p,
+        )
+
+        # output shape = (seq_num, batch, hidden_size) or
+        # (seq_num, batch, 2*hidden_size) for bidirectional
+        if batch_first:
+            output = _op.transpose(output, (1, 0, 2))
+
+        return (output, out_hidden_state)
+
     def bidir_gru_cell(
         self,
         input_seqs,
@@ -3442,6 +3627,8 @@ class PyTorchOpConverter:
             "aten::l1_loss": self.l1_loss,
             "aten::mse_loss": self.mse_loss,
             "aten::flip": self.flip,
+            "aten::rnn_tanh": functools.partial(self.rnn, nonlinearity="tanh"),
+            "aten::rnn_relu": functools.partial(self.rnn, nonlinearity="relu"),
             "aten::gru": self.gru,
             "aten::lstm": self.lstm,
             "aten::all": functools.partial(self.all_any_common, _op.all),
diff --git a/tests/python/frontend/pytorch/test_rnns.py b/tests/python/frontend/pytorch/test_rnns.py
index b0180a7a99..fba55b9c4c 100644
--- a/tests/python/frontend/pytorch/test_rnns.py
+++ b/tests/python/frontend/pytorch/test_rnns.py
@@ -40,6 +40,10 @@ num_layers = 2
 seqs_length = 2
 batch_size = 2
 
+##RNN parameters
+rnn_feature_size = 8
+rnn_hidden_size = 16
+
 
 class RNN_Model(nn.Module):
     """
@@ -93,6 +97,72 @@ class RNN_Model(nn.Module):
         raise NotImplementedError("subclasses must override get_tvm_inputs(dtype)!")
 
 
+class RNN_Model_Impl(RNN_Model):
+    def __init__(
+        self,
+        seq_len=seqs_length,
+        batch_size=batch_size,
+        feature_size=rnn_feature_size,
+        hidden_size=rnn_hidden_size,
+        batch_first=False,
+        layer_num=1,
+        bidirectional=False,
+        use_bias=True,
+        rnd_weights_init=False,
+        nonlinearity="tanh",
+        dropout=0.0,
+    ):
+        super().__init__()
+        # Shapes
+        self.shape = [seq_len, batch_size, feature_size]
+        if batch_first:
+            self.shape = [batch_size, seq_len, feature_size]
+        layers_num = 2 * layer_num if bidirectional else layer_num
+        self.h0_shape = [layers_num, batch_size, hidden_size]
+        # Dummy inputs
+        self.dummy_inputs = (torch.rand(self.shape), torch.zeros(self.h0_shape))
+
+        self.model = nn.RNN(
+            input_size=feature_size,
+            hidden_size=hidden_size,
+            num_layers=layer_num,
+            nonlinearity=nonlinearity,
+            bias=use_bias,
+            batch_first=batch_first,
+            dropout=dropout,
+            bidirectional=bidirectional,
+        )
+
+        if rnd_weights_init:
+            self.gen_rnd_weights()
+
+    def gen_rnd_weights(self):
+        super().gen_rnd_weights()
+
+    def get_dummy_inputs(self):
+        return self.dummy_inputs
+
+    def get_input_names(self):
+        return ["input", "h0"]
+
+    def get_shape_desc(self, frontend_type):
+        shape_desc = None
+        if frontend_type == "pt":  # PyTorch
+            shape_desc = [("input", self.shape)]
+        elif frontend_type == "onnx":  # ONNX
+            shape_desc = {
+                "input": self.shape,
+                "h0": self.h0_shape,
+            }
+        return shape_desc
+
+    def get_tvm_inputs(self, dtype):
+        return {
+            "input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)),
+            "h0": tvm.nd.array(self.dummy_inputs[1].numpy().astype(dtype)),
+        }
+
+
 class GRU_Model(RNN_Model):
     def __init__(
         self,
@@ -331,6 +401,10 @@ def check_rnn(rnn_type, rnn_mod, target=tvm.target.Target("llvm -mcpu=core-avx2"
             args["bidirectional"] = True
         if "s" in rnn_mod:
             args["layer_num"] = num_layers
+        if "tanh" in rnn_mod:
+            args["nonlinearity"] = "tanh"
+        if "relu" in rnn_mod:
+            args["nonlinearity"] = "relu"
 
         if rnn_type == "GRU":
             RNN_Model_selector = GRU_Model
@@ -338,6 +412,8 @@ def check_rnn(rnn_type, rnn_mod, target=tvm.target.Target("llvm -mcpu=core-avx2"
             RNN_Model_selector = LSTM_Model
             if "p" in rnn_mod:
                 args["proj_size"] = lstm_projection_size
+        elif rnn_type == "RNN":
+            RNN_Model_selector = RNN_Model_Impl
 
         return RNN_Model_selector(**args)
 
@@ -425,6 +501,9 @@ def test_rnns():
         for mod_type in ["uni", "s", "b", "sb"]:
             check_rnn("LSTM", mod_type, target, dev)
 
+        for mod_type in ["uni", "s", "b", "sb", "tanh", "relu"]:
+            check_rnn("RNN", mod_type, target, dev)
+
 
 if __name__ == "__main__":
     test_rnns()