You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/03/30 03:24:31 UTC
[incubator-mxnet] branch v1.x updated: [v1.x] ONNX export support
for GRU (#20060)
This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 0a3719a [v1.x] ONNX export support for GRU (#20060)
0a3719a is described below
commit 0a3719a706bdf6a93316ad1e6e7326c86a73138e
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Mon Mar 29 20:22:31 2021 -0700
[v1.x] ONNX export support for GRU (#20060)
* export gru
* fix sanity
* reduce change state_size
Co-authored-by: Wei Chu <we...@amazon.com>
Co-authored-by: Zhaoqi Zhu <zh...@gmail.com>
---
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 350 ++++++++++++++-------
tests/python-pytest/onnx/test_operators.py | 31 +-
2 files changed, 251 insertions(+), 130 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 22e6282..eb91a3c 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -4240,10 +4240,6 @@ def convert_RNN(node, **kwargs):
name, input_nodes, attrs = get_inputs(node, kwargs)
- mode = str(attrs.get('mode'))
- if mode != 'lstm':
- raise NotImplementedError('Currently RNN onnx export only supports lstm mode')
-
bidirectional = str(attrs.get('bidirectional', 'False'))
if bidirectional != 'False':
raise NotImplementedError('Currently RNN onnx export only supports bidirectional is False')
@@ -4270,128 +4266,246 @@ def convert_RNN(node, **kwargs):
data = input_nodes[0]
param = input_nodes[1]
initial_h = input_nodes[2]
- initial_c = input_nodes[3]
nodes = []
- if num_layers == 2:
- create_tensor([0], name+'_0', kwargs['initializer'])
- create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
- create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+ mode = str(attrs.get('mode'))
+ if mode == 'lstm':
+ initial_c = input_nodes[3]
+ if num_layers == 2:
+ create_tensor([0], name+'_0', kwargs['initializer'])
+ create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
+ create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+ create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
+ create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+ create_tensor([4*4*state_size*state_size], name+'_WR_offset', kwargs['initializer'])
- create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
- create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+ nodes += [
+ make_node('Shape', [data], [name+'_data_shape']),
+ make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+
+ # Layer 0
+ # get W
+ make_node('Slice', [param, name+'_0', name+'_4*state_size^2'], [name+'_W0_1d']),
+ make_node('Split', [name+'_W0_1d'], [name+'_W00', name+'_W01', name+'_W02', name+'_W03']),
+ make_node('Concat', [name+'_W00', name+'_W03', name+'_W01', name+'_W02'], [name+'_W0_'], axis=0),
+ make_node('Reshape', [name+'_W0_', name+'_WR_shape'], [name+'_W0']),
+ # get R
+ make_node('Add', [name+'_4*state_size^2', name+'_4*state_size^2'], [name+'_R0_offset']),
+ make_node('Slice', [param, name+'_4*state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
+ make_node('Split', [name+'_R0_1d'], [name+'_R00', name+'_R01', name+'_R02', name+'_R03']),
+ make_node('Concat', [name+'_R00', name+'_R03', name+'_R01', name+'_R02'], [name+'_R0_'], axis=0),
+ make_node('Reshape', [name+'_R0_', name+'_WR_shape'], [name+'_R0']),
+ # get B
+ make_node('Add', [name+'_WR_offset', name+'_8*state_size'], [name+'_B0_offset']),
+ make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
+ make_node('Split', [name+'_B0_1d'], [name+'_B00', name+'_B01', name+'_B02', name+'_B03',
+ name+'_B04', name+'_B05', name+'_B06', name+'_B07']),
+ make_node('Concat', [name+'_B00', name+'_B03', name+'_B01', name+'_B02',
+ name+'_B04', name+'_B07', name+'_B05', name+'_B06'], [name+'_B0_'], axis=0),
+ make_node('Reshape', [name+'_B0_', name+'_B_shape'], [name+'_B0']),
+ # get initial states
+ make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
+ make_node('Split', [initial_c], [name+'_initial_c0', name+'_initial_c1'], axis=0),
+ # get seq_len
+ make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+ make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+ # Layer 0 LSTM
+ make_node('LSTM', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len',
+ name+'_initial_h0', name+'_initial_c0'],
+ [name+'_lstm0_out_', name+'_lstm0_h', name+'_lstm0_c'], hidden_size=state_size),
+ make_node('Squeeze', [name+'_lstm0_out_'], [name+'_lstm0_out'], axes=[1]),
+
+ # Layer 1
+ # get W
+ make_node('Add', [name+'_R0_offset', name+'_4*state_size^2'], [name+'_W1_offset']),
+ make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
+ make_node('Split', [name+'_W1_1d'], [name+'_W10', name+'_W11', name+'_W12', name+'_W13']),
+ make_node('Concat', [name+'_W10', name+'_W13', name+'_W11', name+'_W12'], [name+'_W1_'], axis=0),
+ make_node('Reshape', [name+'_W1_', name+'_WR_shape'], [name+'_W1']),
+ # get R
+ make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
+ make_node('Split', [name+'_R1_1d'], [name+'_R10', name+'_R11', name+'_R12', name+'_R13']),
+ make_node('Concat', [name+'_R10', name+'_R13', name+'_R11', name+'_R12'], [name+'_R1_'], axis=0),
+ make_node('Reshape', [name+'_R1_', name+'_WR_shape'], [name+'_R1']),
+ # get B
+ make_node('Add', [name+'_B0_offset', name+'_8*state_size'], [name+'_B1_offset']),
+ make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
+ make_node('Split', [name+'_B1_1d'], [name+'_B10', name+'_B11', name+'_B12', name+'_B13',
+ name+'_B14', name+'_B15', name+'_B16', name+'_B17']),
+ make_node('Concat', [name+'_B10', name+'_B13', name+'_B11', name+'_B12',
+ name+'_B14', name+'_B17', name+'_B15', name+'_B16'], [name+'_B1_'], axis=0),
+ make_node('Reshape', [name+'_B1_', name+'_B_shape'], [name+'_B1']),
+ # Layer 1 LSTM
+ make_node('LSTM', [name+'_lstm0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
+ name+'_initial_h1', name+'_initial_c1'],
+ [name+'_lstm1_out_', name+'_lstm1_h', name+'_lstm1_c'], hidden_size=state_size),
+ make_node('Squeeze', [name+'_lstm1_out_'], [name], axes=[1]),
+ make_node('Concat', [name+'_lstm0_h', name+'_lstm1_h'], [name+'1'], axis=0),
+ make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0),
+ ]
+ elif num_layers == 1:
+ create_tensor([0], name+'_0', kwargs['initializer'])
+ create_tensor([1], name+'_1', kwargs['initializer'])
+ create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
+ create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
+ create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
+ create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
+ create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
- create_tensor([4*4*state_size*state_size], name+'_WR_offset', kwargs['initializer'])
+ nodes += [
+ make_node('Shape', [data], [name+'_data_shape']),
+ make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+ # get W
+ make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
+ make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
+ make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
+ make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0),
+ make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
+ make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
+ # get R
+ make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
+ make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
+ make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
+ make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
+ make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
+ # get B
+ make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']),
+ make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
+ make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
+ name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
+ make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
+ name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
+ make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
+ # get seq_len
+ make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+ make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+ # compute LSTM
+ make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
+ [name+'0_', name+'1', name+'2'], hidden_size=state_size),
+ make_node('Squeeze', [name+'0_'], [name], axes=[1]),
+ ]
+ else:
+ raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
- nodes += [
- make_node('Shape', [data], [name+'_data_shape']),
- make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
-
- # Layer 0
- # get W
- make_node('Slice', [param, name+'_0', name+'_4*state_size^2'], [name+'_W0_1d']),
- make_node('Split', [name+'_W0_1d'], [name+'_W00', name+'_W01', name+'_W02', name+'_W03']),
- make_node('Concat', [name+'_W00', name+'_W03', name+'_W01', name+'_W02'], [name+'_W0_'], axis=0),
- make_node('Reshape', [name+'_W0_', name+'_WR_shape'], [name+'_W0']),
- # get R
- make_node('Add', [name+'_4*state_size^2', name+'_4*state_size^2'], [name+'_R0_offset']),
- make_node('Slice', [param, name+'_4*state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
- make_node('Split', [name+'_R0_1d'], [name+'_R00', name+'_R01', name+'_R02', name+'_R03']),
- make_node('Concat', [name+'_R00', name+'_R03', name+'_R01', name+'_R02'], [name+'_R0_'], axis=0),
- make_node('Reshape', [name+'_R0_', name+'_WR_shape'], [name+'_R0']),
- # get B
- make_node('Add', [name+'_WR_offset', name+'_8*state_size'], [name+'_B0_offset']),
- make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
- make_node('Split', [name+'_B0_1d'], [name+'_B00', name+'_B01', name+'_B02', name+'_B03',
- name+'_B04', name+'_B05', name+'_B06', name+'_B07']),
- make_node('Concat', [name+'_B00', name+'_B03', name+'_B01', name+'_B02',
- name+'_B04', name+'_B07', name+'_B05', name+'_B06'], [name+'_B0_'], axis=0),
- make_node('Reshape', [name+'_B0_', name+'_B_shape'], [name+'_B0']),
- # get initial states
- make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
- make_node('Split', [initial_c], [name+'_initial_c0', name+'_initial_c1'], axis=0),
- # get seq_len
- make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
- make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
- # Layer 0 LSTM
- make_node('LSTM', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len',
- name+'_initial_h0', name+'_initial_c0'],
- [name+'_lstm0_out_', name+'_lstm0_h', name+'_lstm0_c'], hidden_size=state_size),
- make_node('Squeeze', [name+'_lstm0_out_'], [name+'_lstm0_out'], axes=[1]),
-
- # Layer 1
- # get W
- make_node('Add', [name+'_R0_offset', name+'_4*state_size^2'], [name+'_W1_offset']),
- make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
- make_node('Split', [name+'_W1_1d'], [name+'_W10', name+'_W11', name+'_W12', name+'_W13']),
- make_node('Concat', [name+'_W10', name+'_W13', name+'_W11', name+'_W12'], [name+'_W1_'], axis=0),
- make_node('Reshape', [name+'_W1_', name+'_WR_shape'], [name+'_W1']),
- # get R
- make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
- make_node('Split', [name+'_R1_1d'], [name+'_R10', name+'_R11', name+'_R12', name+'_R13']),
- make_node('Concat', [name+'_R10', name+'_R13', name+'_R11', name+'_R12'], [name+'_R1_'], axis=0),
- make_node('Reshape', [name+'_R1_', name+'_WR_shape'], [name+'_R1']),
- # get B
- make_node('Add', [name+'_B0_offset', name+'_8*state_size'], [name+'_B1_offset']),
- make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
- make_node('Split', [name+'_B1_1d'], [name+'_B10', name+'_B11', name+'_B12', name+'_B13',
- name+'_B14', name+'_B15', name+'_B16', name+'_B17']),
- make_node('Concat', [name+'_B10', name+'_B13', name+'_B11', name+'_B12',
- name+'_B14', name+'_B17', name+'_B15', name+'_B16'], [name+'_B1_'], axis=0),
- make_node('Reshape', [name+'_B1_', name+'_B_shape'], [name+'_B1']),
- # Layer 1 LSTM
- make_node('LSTM', [name+'_lstm0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
- name+'_initial_h1', name+'_initial_c1'],
- [name+'_lstm1_out_', name+'_lstm1_h', name+'_lstm1_c'], hidden_size=state_size),
- make_node('Squeeze', [name+'_lstm1_out_'], [name], axes=[1]),
- make_node('Concat', [name+'_lstm0_h', name+'_lstm1_h'], [name+'1'], axis=0),
- make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0),
- ]
- elif num_layers == 1:
- create_tensor([0], name+'_0', kwargs['initializer'])
- create_tensor([1], name+'_1', kwargs['initializer'])
- create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
- create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
- create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
- create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
- create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
+ elif mode == 'gru':
+ if num_layers == 2:
+ create_tensor([0], name+'_0', kwargs['initializer'])
+ create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
+ create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer'])
+ create_tensor([1, 3*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
+ create_tensor([1, 6*state_size], name+'_B_shape', kwargs['initializer'])
+ create_tensor([4*3*state_size*state_size], name+'_WR_offset', kwargs['initializer'])
- nodes += [
- make_node('Shape', [data], [name+'_data_shape']),
- make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
- # get W
- make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
- make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
- make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
- make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0),
- make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
- make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
- # get R
- make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
- make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
- make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
- make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
- make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
- # get B
- make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']),
- make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
- make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
- name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
- make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
- name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
- make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
- # get seq_len
- make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
- make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
- # compute LSTM
- make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
- [name+'0_', name+'1', name+'2'], hidden_size=state_size),
- make_node('Squeeze', [name+'0_'], [name], axes=[1]),
- ]
- else:
- raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
+ nodes += [
+ make_node('Shape', [data], [name+'_data_shape']),
+ make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+
+ # Layer 0
+ # get W
+ make_node('Slice', [param, name+'_0', name+'_3*state_size^2'], [name+'_W0_1d']),
+ make_node('Split', [name+'_W0_1d'], [name+'_W00', name+'_W01', name+'_W02']),
+ make_node('Concat', [name+'_W01', name+'_W00', name+'_W02'], [name+'_W0_'], axis=0),
+ make_node('Reshape', [name+'_W0_', name+'_WR_shape'], [name+'_W0']),
+ # get R
+ make_node('Add', [name+'_3*state_size^2', name+'_3*state_size^2'], [name+'_R0_offset']),
+ make_node('Slice', [param, name+'_3*state_size^2', name+'_R0_offset'], [name+'_R0_1d']),
+ make_node('Split', [name+'_R0_1d'], [name+'_R00', name+'_R01', name+'_R02']),
+ make_node('Concat', [name+'_R01', name+'_R00', name+'_R02'], [name+'_R0_'], axis=0),
+ make_node('Reshape', [name+'_R0_', name+'_WR_shape'], [name+'_R0']),
+ # get B
+ make_node('Add', [name+'_WR_offset', name+'_6*state_size'], [name+'_B0_offset']),
+ make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']),
+ make_node('Split', [name+'_B0_1d'], [name+'_B00', name+'_B01', name+'_B02',
+ name+'_B03', name+'_B04', name+'_B05']),
+ make_node('Concat', [name+'_B01', name+'_B00', name+'_B02',
+ name+'_B04', name+'_B03', name+'_B05'], [name+'_B0_'], axis=0),
+ make_node('Reshape', [name+'_B0_', name+'_B_shape'], [name+'_B0']),
+ # get initial states
+ make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0),
+ # get seq_len
+ make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+ make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+ # Layer 0 GRU
+ make_node('GRU', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len',
+ name+'_initial_h0'],
+ [name+'_gru0_out_', name+'_gru0_h'], hidden_size=state_size, linear_before_reset=1),
+ make_node('Squeeze', [name+'_gru0_out_'], [name+'_gru0_out'], axes=[1]),
+
+ # Layer 1
+ # get W
+ make_node('Add', [name+'_R0_offset', name+'_3*state_size^2'], [name+'_W1_offset']),
+ make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']),
+ make_node('Split', [name+'_W1_1d'], [name+'_W10', name+'_W11', name+'_W12']),
+ make_node('Concat', [name+'_W11', name+'_W10', name+'_W12'], [name+'_W1_'], axis=0),
+ make_node('Reshape', [name+'_W1_', name+'_WR_shape'], [name+'_W1']),
+ # get R
+ make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']),
+ make_node('Split', [name+'_R1_1d'], [name+'_R10', name+'_R11', name+'_R12']),
+ make_node('Concat', [name+'_R11', name+'_R10', name+'_R12'], [name+'_R1_'], axis=0),
+ make_node('Reshape', [name+'_R1_', name+'_WR_shape'], [name+'_R1']),
+ # get B
+ make_node('Add', [name+'_B0_offset', name+'_6*state_size'], [name+'_B1_offset']),
+ make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']),
+ make_node('Split', [name+'_B1_1d'], [name+'_B10', name+'_B11', name+'_B12',
+ name+'_B13', name+'_B14', name+'_B15']),
+ make_node('Concat', [name+'_B11', name+'_B10', name+'_B12',
+ name+'_B14', name+'_B13', name+'_B15'], [name+'_B1_'], axis=0),
+ make_node('Reshape', [name+'_B1_', name+'_B_shape'], [name+'_B1']),
+ # Layer 1 GRU
+ make_node('GRU', [name+'_gru0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len',
+ name+'_initial_h1'],
+ [name+'_gru1_out_', name+'_gru1_h'], hidden_size=state_size, linear_before_reset=1),
+ make_node('Squeeze', [name+'_gru1_out_'], [name], axes=[1]),
+ make_node('Concat', [name+'_gru0_h', name+'_gru1_h'], [name+'1'], axis=0)
+ ]
+ elif num_layers == 1:
+ create_tensor([0], name+'_0', kwargs['initializer'])
+ create_tensor([1], name+'_1', kwargs['initializer'])
+ create_tensor([3*state_size], name+'_3*state_size', kwargs['initializer'])
+ create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer'])
+ create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer'])
+ create_tensor([1, 3*state_size, state_size], name+'_R_shape', kwargs['initializer'])
+ create_tensor([1, 6*state_size], name+'_B_shape', kwargs['initializer'])
+
+ nodes += [
+ make_node('Shape', [data], [name+'_data_shape']),
+ make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']),
+ # get W
+ make_node('Mul', [name+'_3*state_size', name+'_input_size'], [name+'_mul0']),
+ make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
+ make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2']),
+ make_node('Concat', [name+'_W1', name+'_W0', name+'_W2'], [name+'_W_'], axis=0),
+ make_node('Concat', [name+'_1', name+'_3*state_size', name+'_input_size'], [name+'_W_shape'], axis=0),
+ make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
+ # get R
+ make_node('Add', [name+'_mul0', name+'_3*state_size^2'], [name+'_add0']),
+ make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
+ make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2']),
+ make_node('Concat', [name+'_R1', name+'_R0', name+'_R2'], [name+'_R_'], axis=0),
+ make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
+ # get B
+ make_node('Add', [name+'_add0', name+'_6*state_size'], [name+'_add1']),
+ make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
+ make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2',
+ name+'_B3', name+'_B4', name+'_B5']),
+ make_node('Concat', [name+'_B1', name+'_B0', name+'_B2',
+ name+'_B4', name+'_B3', name+'_B5'], [name+'_B_'], axis=0),
+ make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
+ # get seq_len
+ make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
+ make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)),
+ # compute GRU
+ make_node('GRU', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h],
+ [name+'0_', name+'1'], hidden_size=state_size, linear_before_reset=1),
+ make_node('Squeeze', [name+'0_'], [name], axes=[1]),
+ ]
+ else:
+ raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2')
+
+ else:
+ raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode")
return nodes
@mx_op.register('_rnn_param_concat')
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 520ac40..51170e6 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -1221,27 +1221,34 @@ def test_onnx_export_sequence_reverse(tmp_path, dtype, params):
# onnx LSTM from opset 11 does not support float64
+@pytest.mark.parametrize('mode', ['lstm', 'gru'])
@pytest.mark.parametrize('dtype', ['float32'])
-@pytest.mark.parametrize('state_size', [32, 40])
-@pytest.mark.parametrize('input_size', [32, 40, 64])
+@pytest.mark.parametrize('state_size', [16, 32])
+@pytest.mark.parametrize('input_size', [16, 32, 64])
@pytest.mark.parametrize('num_layers', [1, 2])
-@pytest.mark.parametrize('batch_size', [1, 3, 5])
+@pytest.mark.parametrize('batch_size', [1, 2, 4])
@pytest.mark.parametrize('seq_length', [16, 32])
-def test_onnx_export_RNN(tmp_path, dtype, state_size, input_size, num_layers, batch_size, seq_length):
+def test_onnx_export_RNN(tmp_path, mode, dtype, state_size, input_size, num_layers, batch_size, seq_length):
# TODO: The current implementation fails assertion checks for large parm/state_size.
-
+
# for num_layers >= 2, input_size must equal to state_size
if num_layers >= 2 and input_size != state_size:
return
-
- M = def_model('RNN', mode='lstm', state_size=state_size, state_outputs=True, num_layers=num_layers, p=0)
+ factor = 3
+ if mode == 'lstm':
+ factor = 4
+
+ M = def_model('RNN', mode=mode, state_size=state_size, state_outputs=True, num_layers=num_layers, p=0)
x = mx.nd.random.normal(0, 10, (seq_length, batch_size, input_size), dtype=dtype)
- param = mx.nd.random.normal(0, 1, [num_layers*4*state_size*input_size +
- num_layers*4*state_size*state_size +
- num_layers*8*state_size], dtype=dtype)
+ param = mx.nd.random.normal(0, 1, [num_layers*factor*state_size*input_size +
+ num_layers*factor*state_size*state_size +
+ num_layers*2*factor*state_size], dtype=dtype)
state = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
- cell = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
- op_export_test('rnn', M, [x, param, state, cell], tmp_path)
+ if mode == 'lstm':
+ cell = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
+ op_export_test('rnn', M, [x, param, state, cell], tmp_path)
+ else:
+ op_export_test('rnn', M, [x, param, state], tmp_path)
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])