Posted to commits@tvm.apache.org by ma...@apache.org on 2022/12/15 06:08:21 UTC
[tvm] branch main updated: [Frontend] [ONNX] Support sequence_lens of GRU (#13587)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 06be0b3b26 [Frontend] [ONNX] Support sequence_lens of GRU (#13587)
06be0b3b26 is described below
commit 06be0b3b2681b86e0dfe7458f16de424df71230e
Author: Jianjian Guan <ja...@me.com>
AuthorDate: Thu Dec 15 14:08:15 2022 +0800
[Frontend] [ONNX] Support sequence_lens of GRU (#13587)
[Frontend] [ONNX] Support sequence_lens of GRU.
Support converting the sequence_lens input of GRU.
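For context, a minimal usage sketch of what this enables (the file name and shapes below are placeholders, not from this commit): an ONNX GRU whose optional sequence_lens input is now converted instead of being ignored.

    import onnx
    from tvm import relay

    # Hypothetical model: a GRU with inputs X, W, R, B, sequence_lens, initial_h.
    model = onnx.load("gru.onnx")
    shape_dict = {
        "X": (8, 8, 16),        # (seq_length, batch_size, input_size)
        "sequence_lens": (8,),  # one int32 length per batch element
    }
    mod, params = relay.frontend.from_onnx(model, shape_dict)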
---
python/tvm/relay/frontend/common.py | 57 ++++++++++++++++++++++++++++--
python/tvm/relay/frontend/onnx.py | 18 ++++++----
tests/python/frontend/onnx/test_forward.py | 40 +++++++++++++++++++--
3 files changed, 104 insertions(+), 11 deletions(-)
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 5f961f1ae0..660426fb4a 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -737,6 +737,7 @@ def gru_cell(
n_act=_op.tanh,
backwards=False,
linear_before_reset=True,
+ sequence_lens=None,
):
"""
Common implementation of GRU cell for all frontends of TVM
@@ -765,7 +766,12 @@ def gru_cell(
activation function for new gate. it is tanh by default
backwards : bool
Flag for reverse pass of GRU
-
+ linear_before_reset : bool
+ Flag for applying the linear transformation before multiplying by the output of the reset
+ gate.
+ sequence_lens : relay.op
+ Tensor specifying lengths of the sequences in a batch.
+ Shape = (batch_size)
Returns
-------
result : List[relay.Expr], relay.Expr, relay.Expr
@@ -773,7 +779,40 @@ def gru_cell(
"""
outputs_list = []
- for x_t in input_seqs if not backwards else reversed(input_seqs):
+
+ seq_len = len(input_seqs)
+ input_dtype = infer_type(input_seqs[0]).checked_type.dtype
+
+ if sequence_lens is not None:
+ shape = infer_shape(sequence_lens)
+ dtype = infer_type(sequence_lens).checked_type.dtype
+
+ arange = _op.arange(_op.const(0), _op.const(seq_len), dtype=dtype)
+ arange = _op.expand_dims(arange, 1)
+ sequence_lens = _op.broadcast_to(sequence_lens, [seq_len, shape[0]])
+
+ # cast to data dtype
+ mask = _op.less(arange, sequence_lens)
+ mask = _op.cast(mask, dtype=input_dtype)
+ mask = _op.expand_dims(mask, 2)
+ mask_seqs = unbind(mask)
+
+ res_mask = _op.greater_equal(arange, sequence_lens)
+ res_mask = _op.cast(res_mask, dtype=input_dtype)
+ res_mask = _op.expand_dims(res_mask, 2)
+ res_mask_seqs = unbind(res_mask)
+
+ if backwards:
+ # need a mask to keep initial_h_B correct
+ initial_h = hidden_state
+ initial_h_mask = _op.equal(arange, sequence_lens)
+ initial_h_mask = _op.cast(initial_h_mask, dtype=input_dtype)
+ initial_h_mask = _op.expand_dims(initial_h_mask, 2)
+ initial_h_mask_seqs = unbind(initial_h_mask)
+
+ output = _op.zeros(infer_shape(hidden_state), input_dtype)
+ for i in range(seq_len) if not backwards else reversed(range(seq_len)):
+ x_t = input_seqs[i]
xwt = _op.nn.dense(x_t, w_inp)
if linear_before_reset:
hwt = _op.nn.dense(hidden_state, w_hid)
@@ -806,9 +845,21 @@ def gru_cell(
hidden_state = (hidden_state - n_gate) * z_gate + n_gate
+ if sequence_lens is not None:
+ hidden_state = hidden_state * mask_seqs[i]
+
outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)]
- return outputs_list, hidden_state
+ if sequence_lens is not None:
+ output = output * res_mask_seqs[i] + hidden_state
+ else:
+ output = hidden_state
+
+ # make sure initial_h_B is correct
+ if backwards and sequence_lens is not None:
+ hidden_state = hidden_state + initial_h * initial_h_mask_seqs[i]
+
+ return outputs_list, output
def lstm_cell(
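To make the masking in the hunk above easier to follow, here is an illustrative numpy sketch of the same idea (the numbers are arbitrary; the actual implementation builds these masks with relay ops):

    import numpy as np

    seq_len, batch_size = 4, 3
    sequence_lens = np.array([4, 2, 3], dtype="int32")            # per-batch valid lengths

    steps = np.arange(seq_len, dtype="int32")[:, None]            # (seq_len, 1)
    lens = np.broadcast_to(sequence_lens, (seq_len, batch_size))  # (seq_len, batch_size)

    mask = (steps < lens).astype("float32")[:, :, None]       # 1 for valid steps, 0 for padding
    res_mask = (steps >= lens).astype("float32")[:, :, None]  # 1 once a sequence has ended

    # Per time step t: hidden_state *= mask[t] zeroes out padded steps, and
    # output = output * res_mask[t] + hidden_state keeps the last valid hidden
    # state for sequences that have already finished.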
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 3470099100..a8ab626025 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3126,8 +3126,7 @@ class RNN(OnnxOpConverter):
Wp = inputs[1]
Rp = inputs[2]
Bp = inputs[3]
- # Sequence length currently unused as it can be inferred from shapes.
- # sequence_lens = inputs['sequence_lens']
+ sequence_lens = inputs[4]
Hp_0 = inputs[5]
num_directions = infer_shape(Wp)[0]
@@ -3158,11 +3157,11 @@ class RNN(OnnxOpConverter):
Bs = None
if Bp is not None:
Bs = _op.split(Bp, num_directions)
- return X_steps, H_ts, Ws, Rs, Bs, num_directions
+ return X_steps, H_ts, Ws, Rs, Bs, num_directions, sequence_lens
@classmethod
def _impl_common(cls, inputs, attr, layout):
- X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+ X_steps, H_ts, Ws, Rs, Bs, num_directions, _ = cls._inputs_helper(inputs, layout)
acts = cls._get_activations(attr, 1, num_directions, "RNN")
weights_dicts = []
@@ -3261,7 +3260,7 @@ class LSTM(RNN):
@classmethod
def _impl_common(cls, inputs, attr, layout):
- X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+ X_steps, H_ts, Ws, Rs, Bs, num_directions, _ = cls._inputs_helper(inputs, layout)
acts = cls._get_activations(attr, 3, num_directions, "LSTM")
# cell state
@@ -3346,6 +3345,7 @@ class GRU(RNN):
input_seqs,
weight_dicts,
acts,
+ sequence_lens=None,
):
"""
Bidirectional GRU cell
@@ -3356,6 +3356,7 @@ class GRU(RNN):
**weight_dicts[0],
rz_act=acts[0],
n_act=acts[1],
+ sequence_lens=sequence_lens,
)
reverse_outputs, rev_H_t = gru_cell(
@@ -3364,6 +3365,7 @@ class GRU(RNN):
rz_act=acts[2],
n_act=acts[3],
backwards=True,
+ sequence_lens=sequence_lens,
)
final_outputs = []
@@ -3383,7 +3385,9 @@ class GRU(RNN):
@classmethod
def _impl_common(cls, inputs, attr, layout):
- X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+ X_steps, H_ts, Ws, Rs, Bs, num_directions, sequence_lens = cls._inputs_helper(
+ inputs, layout
+ )
acts = cls._get_activations(attr, 2, num_directions, "GRU")
linear_before_reset = attr.get("linear_before_reset", 0)
@@ -3412,6 +3416,7 @@ class GRU(RNN):
input_seqs=X_steps,
weight_dicts=weights_dicts,
acts=acts,
+ sequence_lens=sequence_lens,
)
else:
# outputs shape = [seqs_num, (batch_size, hidden_size)]
@@ -3420,6 +3425,7 @@ class GRU(RNN):
**weights_dicts[0],
rz_act=acts[0],
n_act=acts[1],
+ sequence_lens=sequence_lens,
)
# output shape = (seqs_num, num_directions, batch_size, hidden_size)
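On the ONNX side, sequence_lens is the optional fifth input of the GRU node, which is why the converter reads inputs[4]. A minimal sketch of such a node (tensor names and hidden_size are made up for illustration):

    from onnx import helper

    gru_node = helper.make_node(
        "GRU",
        inputs=["X", "W", "R", "B", "sequence_lens", "initial_h"],
        outputs=["Y", "Y_h"],
        hidden_size=32,
        linear_before_reset=1,
    )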
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index dcd4f2defb..92a87ff6a7 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3897,6 +3897,7 @@ def verify_rnn(
atol=1e-5,
target=None,
dev=None,
+ use_sequence_lens=False,
):
"""verify_rnn"""
if rnn_type == "RNN":
@@ -3954,10 +3955,16 @@ def verify_rnn(
)
register(b_np, "B")
+ if use_sequence_lens:
+ sequence_np = np.random.uniform(0, seq_length, size=(batch_size)).astype("int32")
+ register(sequence_np, "sequence_lens")
+
if use_initial_state:
assert use_bias is True, "Initial states must have bias specified."
- sequence_np = np.repeat(seq_length, batch_size).astype("int32")
- register(sequence_np, "sequence_lens")
+
+ if not use_sequence_lens:
+ sequence_np = np.repeat(seq_length, batch_size).astype("int32")
+ register(sequence_np, "sequence_lens")
if layout == 1:
initial_h_np = np.random.uniform(size=(batch_size, directions, hidden_size)).astype(
@@ -4211,6 +4218,35 @@ def verify_rnn_helper(target, dev, rnn_type):
# dev=dev,
# )
+ # Testing with initial state
+ if rnn_type == "GRU":
+ verify_rnn(
+ seq_length=2,
+ batch_size=1,
+ input_size=16,
+ hidden_size=32,
+ use_bias=True,
+ use_initial_state=True,
+ rnn_type=rnn_type,
+ directions=directions,
+ target=target,
+ dev=dev,
+ use_sequence_lens=True,
+ )
+ verify_rnn(
+ seq_length=8,
+ batch_size=8,
+ input_size=16,
+ hidden_size=32,
+ use_bias=True,
+ use_initial_state=True,
+ rnn_type=rnn_type,
+ directions=directions,
+ target=target,
+ dev=dev,
+ use_sequence_lens=True,
+ )
+
# Testing with peepholes
if rnn_type == "LSTM":
verify_rnn(