You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/12/15 06:08:21 UTC

[tvm] branch main updated: [Frontend] [ONNX] Support sequence_lens of GRU (#13587)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 06be0b3b26 [Frontend] [ONNX] Support sequence_lens of GRU (#13587)
06be0b3b26 is described below

commit 06be0b3b2681b86e0dfe7458f16de424df71230e
Author: Jianjian Guan <ja...@me.com>
AuthorDate: Thu Dec 15 14:08:15 2022 +0800

    [Frontend] [ONNX] Support sequence_lens of GRU (#13587)
    
    [Frontend] [ONNX] Support sequence_lens of GRU.
    
    Support converting the sequence_lens input of GRU.
---
 python/tvm/relay/frontend/common.py        | 57 ++++++++++++++++++++++++++++--
 python/tvm/relay/frontend/onnx.py          | 18 ++++++----
 tests/python/frontend/onnx/test_forward.py | 40 +++++++++++++++++++--
 3 files changed, 104 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 5f961f1ae0..660426fb4a 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -737,6 +737,7 @@ def gru_cell(
     n_act=_op.tanh,
     backwards=False,
     linear_before_reset=True,
+    sequence_lens=None,
 ):
     """
     Common implementation of GRU cell for all frontends of TVM
@@ -765,7 +766,12 @@ def gru_cell(
         activation function for new gate. it is tanh by default
     backwards : bool
         Flag for reverse pass of GRU
-
+    linear_before_reset : bool
+        Flag for applying the linear transformation before multiplying by the output of the reset
+        gate.
+    sequence_lens : relay.op
+        Tensor specifying lengths of the sequences in a batch.
+        Shape = (batch_size)
     Returns
     -------
     result : List[relay.Expr], relay.Expr, relay.Expr
@@ -773,7 +779,40 @@ def gru_cell(
     """
 
     outputs_list = []
-    for x_t in input_seqs if not backwards else reversed(input_seqs):
+
+    seq_len = len(input_seqs)
+    input_dtype = infer_type(input_seqs[0]).checked_type.dtype
+
+    if sequence_lens is not None:
+        shape = infer_shape(sequence_lens)
+        dtype = infer_type(sequence_lens).checked_type.dtype
+
+        arange = _op.arange(_op.const(0), _op.const(seq_len), dtype=dtype)
+        arange = _op.expand_dims(arange, 1)
+        sequence_lens = _op.broadcast_to(sequence_lens, [seq_len, shape[0]])
+
+        # cast to data dtype
+        mask = _op.less(arange, sequence_lens)
+        mask = _op.cast(mask, dtype=input_dtype)
+        mask = _op.expand_dims(mask, 2)
+        mask_seqs = unbind(mask)
+
+        res_mask = _op.greater_equal(arange, sequence_lens)
+        res_mask = _op.cast(res_mask, dtype=input_dtype)
+        res_mask = _op.expand_dims(res_mask, 2)
+        res_mask_seqs = unbind(res_mask)
+
+        if backwards:
+            # need a mask to keep initial_h_B correct
+            initial_h = hidden_state
+            initial_h_mask = _op.equal(arange, sequence_lens)
+            initial_h_mask = _op.cast(initial_h_mask, dtype=input_dtype)
+            initial_h_mask = _op.expand_dims(initial_h_mask, 2)
+            initial_h_mask_seqs = unbind(initial_h_mask)
+
+    output = _op.zeros(infer_shape(hidden_state), input_dtype)
+    for i in range(seq_len) if not backwards else reversed(range(seq_len)):
+        x_t = input_seqs[i]
         xwt = _op.nn.dense(x_t, w_inp)
         if linear_before_reset:
             hwt = _op.nn.dense(hidden_state, w_hid)
@@ -806,9 +845,21 @@ def gru_cell(
 
         hidden_state = (hidden_state - n_gate) * z_gate + n_gate
 
+        if sequence_lens is not None:
+            hidden_state = hidden_state * mask_seqs[i]
+
         outputs_list.append(hidden_state)  # [seq_num, (batch, hidden_size)]
 
-    return outputs_list, hidden_state
+        if sequence_lens is not None:
+            output = output * res_mask_seqs[i] + hidden_state
+        else:
+            output = hidden_state
+
+        # make sure initial_h_B is the starting state of each reversed sequence
+        if backwards and sequence_lens is not None:
+            hidden_state = hidden_state + initial_h * initial_h_mask_seqs[i]
+
+    return outputs_list, output
 
 
 def lstm_cell(
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 3470099100..a8ab626025 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3126,8 +3126,7 @@ class RNN(OnnxOpConverter):
         Wp = inputs[1]
         Rp = inputs[2]
         Bp = inputs[3]
-        # Sequence length currently unused as it can be inferred from shapes.
-        # sequence_lens = inputs['sequence_lens']
+        sequence_lens = inputs[4]
         Hp_0 = inputs[5]
 
         num_directions = infer_shape(Wp)[0]
@@ -3158,11 +3157,11 @@ class RNN(OnnxOpConverter):
         Bs = None
         if Bp is not None:
             Bs = _op.split(Bp, num_directions)
-        return X_steps, H_ts, Ws, Rs, Bs, num_directions
+        return X_steps, H_ts, Ws, Rs, Bs, num_directions, sequence_lens
 
     @classmethod
     def _impl_common(cls, inputs, attr, layout):
-        X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+        X_steps, H_ts, Ws, Rs, Bs, num_directions, _ = cls._inputs_helper(inputs, layout)
         acts = cls._get_activations(attr, 1, num_directions, "RNN")
 
         weights_dicts = []
@@ -3261,7 +3260,7 @@ class LSTM(RNN):
 
     @classmethod
     def _impl_common(cls, inputs, attr, layout):
-        X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+        X_steps, H_ts, Ws, Rs, Bs, num_directions, _ = cls._inputs_helper(inputs, layout)
         acts = cls._get_activations(attr, 3, num_directions, "LSTM")
 
         # cell state
@@ -3346,6 +3345,7 @@ class GRU(RNN):
         input_seqs,
         weight_dicts,
         acts,
+        sequence_lens=None,
     ):
         """
         Bidirectional GRU cell
@@ -3356,6 +3356,7 @@ class GRU(RNN):
             **weight_dicts[0],
             rz_act=acts[0],
             n_act=acts[1],
+            sequence_lens=sequence_lens,
         )
 
         reverse_outputs, rev_H_t = gru_cell(
@@ -3364,6 +3365,7 @@ class GRU(RNN):
             rz_act=acts[2],
             n_act=acts[3],
             backwards=True,
+            sequence_lens=sequence_lens,
         )
 
         final_outputs = []
@@ -3383,7 +3385,9 @@ class GRU(RNN):
 
     @classmethod
     def _impl_common(cls, inputs, attr, layout):
-        X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+        X_steps, H_ts, Ws, Rs, Bs, num_directions, sequence_lens = cls._inputs_helper(
+            inputs, layout
+        )
         acts = cls._get_activations(attr, 2, num_directions, "GRU")
         linear_before_reset = attr.get("linear_before_reset", 0)
 
@@ -3412,6 +3416,7 @@ class GRU(RNN):
                 input_seqs=X_steps,
                 weight_dicts=weights_dicts,
                 acts=acts,
+                sequence_lens=sequence_lens,
             )
         else:
             # outputs shape = [seqs_num, (batch_size, hidden_size)]
@@ -3420,6 +3425,7 @@ class GRU(RNN):
                 **weights_dicts[0],
                 rz_act=acts[0],
                 n_act=acts[1],
+                sequence_lens=sequence_lens,
             )
 
             # output shape = (seqs_num, num_directions, batch_size, hidden_size)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index dcd4f2defb..92a87ff6a7 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3897,6 +3897,7 @@ def verify_rnn(
     atol=1e-5,
     target=None,
     dev=None,
+    use_sequence_lens=False,
 ):
     """verify_rnn"""
     if rnn_type == "RNN":
@@ -3954,10 +3955,16 @@ def verify_rnn(
             )
             register(b_np, "B")
 
+        if use_sequence_lens:
+            sequence_np = np.random.uniform(0, seq_length, size=(batch_size)).astype("int32")
+            register(sequence_np, "sequence_lens")
+
         if use_initial_state:
             assert use_bias is True, "Initial states must have bias specified."
-            sequence_np = np.repeat(seq_length, batch_size).astype("int32")
-            register(sequence_np, "sequence_lens")
+
+            if not use_sequence_lens:
+                sequence_np = np.repeat(seq_length, batch_size).astype("int32")
+                register(sequence_np, "sequence_lens")
 
             if layout == 1:
                 initial_h_np = np.random.uniform(size=(batch_size, directions, hidden_size)).astype(
@@ -4211,6 +4218,35 @@ def verify_rnn_helper(target, dev, rnn_type):
         #     dev=dev,
         # )
 
+        # Testing with initial state and sequence_lens
+        if rnn_type == "GRU":
+            verify_rnn(
+                seq_length=2,
+                batch_size=1,
+                input_size=16,
+                hidden_size=32,
+                use_bias=True,
+                use_initial_state=True,
+                rnn_type=rnn_type,
+                directions=directions,
+                target=target,
+                dev=dev,
+                use_sequence_lens=True,
+            )
+            verify_rnn(
+                seq_length=8,
+                batch_size=8,
+                input_size=16,
+                hidden_size=32,
+                use_bias=True,
+                use_initial_state=True,
+                rnn_type=rnn_type,
+                directions=directions,
+                target=target,
+                dev=dev,
+                use_sequence_lens=True,
+            )
+
         # Testing with peepholes
         if rnn_type == "LSTM":
             verify_rnn(