You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/03/23 01:36:03 UTC
[incubator-mxnet] branch v1.x updated: [v1.x] Onnx Support for
Transformer (#20048)
This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 833cb89 [v1.x] Onnx Support for Transformer (#20048)
833cb89 is described below
commit 833cb89e3e7a5262151a3b512d18a82d6de917be
Author: Zhaoqi Zhu <zh...@gmail.com>
AuthorDate: Mon Mar 22 18:34:10 2021 -0700
[v1.x] Onnx Support for Transformer (#20048)
* add ops
* add transformer test
* fix test
* add unit test
* fix sanity
* add to ci
---
ci/docker/runtime_functions.sh | 1 +
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 105 ++++++++++++-
tests/python-pytest/onnx/test_onnxruntime.py | 169 ++++++++++++++++++++-
tests/python-pytest/onnx/test_operators.py | 26 ++++
4 files changed, 295 insertions(+), 6 deletions(-)
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index a3484fc..9bfc841 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -1274,6 +1274,7 @@ integrationtest_ubuntu_cpu_onnx() {
pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_action_recognition_model_inference_onnxruntime[inceptionv3_kinetics400]
pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_dynamic_shape_bert_inference_onnxruntime
pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_dynamic_shape_cv_inference_onnxruntime
+ pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py::test_transformer_pretrained_inference_onnxruntime
}
integrationtest_ubuntu_gpu_python() {
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 462564a..0903376 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -2064,14 +2064,78 @@ def convert_broadcast_lesser(node, **kwargs):
"""Map MXNet's broadcast_lesser operator attributes to onnx's Less operator
and return the created node.
"""
- return create_basic_op_node('Less', node, kwargs)
+ from onnx.helper import make_node
+ name, input_nodes, _ = get_inputs(node, kwargs)
+ input_dtypes = get_input_dtypes(node, kwargs)
+
+ dtype = input_dtypes[0]
+ dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+
+ nodes = [
+ make_node('Less', [input_nodes[0], input_nodes[1]], [name+'_lt']),
+ make_node('Cast', [name+'_lt'], [name], to=dtype_t)
+ ]
+
+ return nodes
+
+
+@mx_op.register("broadcast_lesser_equal")
+def convert_broadcast_lesser_equal(node, **kwargs):
+ """Map MXNet's broadcast_lesser_equal operator
+ """
+ from onnx.helper import make_node
+ name, input_nodes, _ = get_inputs(node, kwargs)
+ input_dtypes = get_input_dtypes(node, kwargs)
+
+ dtype = input_dtypes[0]
+ dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+
+ nodes = [
+ make_node('LessOrEqual', [input_nodes[0], input_nodes[1]], [name+'_lt']),
+ make_node('Cast', [name+'_lt'], [name], to=dtype_t)
+ ]
+
+ return nodes
+
+
+@mx_op.register("broadcast_greater_equal")
+def convert_broadcast_greater_equal(node, **kwargs):
+ """Map MXNet's broadcast_greater_equal operator
+ """
+ from onnx.helper import make_node
+ name, input_nodes, _ = get_inputs(node, kwargs)
+ input_dtypes = get_input_dtypes(node, kwargs)
+
+ dtype = input_dtypes[0]
+ dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+
+ nodes = [
+ make_node('GreaterOrEqual', [input_nodes[0], input_nodes[1]], [name+'_gt']),
+ make_node('Cast', [name+'_gt'], [name], to=dtype_t)
+ ]
+
+ return nodes
+
@mx_op.register("broadcast_greater")
def convert_broadcast_greater(node, **kwargs):
"""Map MXNet's broadcast_greater operator attributes to onnx's Greater operator
and return the created node.
"""
- return create_basic_op_node('Greater', node, kwargs)
+ from onnx.helper import make_node
+ name, input_nodes, _ = get_inputs(node, kwargs)
+ input_dtypes = get_input_dtypes(node, kwargs)
+
+ dtype = input_dtypes[0]
+ dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+
+ nodes = [
+ make_node('Greater', [input_nodes[0], input_nodes[1]], [name+'_gt']),
+ make_node('Cast', [name+'_gt'], [name], to=dtype_t)
+ ]
+
+ return nodes
+
@mx_op.register("broadcast_equal")
def convert_broadcast_equal(node, **kwargs):
@@ -2498,7 +2562,6 @@ def convert_layer_norm(node, **kwargs):
axes = int(attrs.get('axis', -1))
eps = attrs.get('eps', 9.99999975e-06)
-
create_tensor([axes], name+"_axes", kwargs["initializer"])
create_tensor([axes+1], name+"_axes+1", kwargs["initializer"])
create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
@@ -2519,7 +2582,11 @@ def convert_layer_norm(node, **kwargs):
if axes == -1:
nodes += [
make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]),
- make_node("Add", [name+"_mul0_out", input_nodes[2]], [name], name=name)
+ # make_node("Add", [name+"_mul0_out", input_nodes[2]], [name])
+ # the Add operator triggers a weird NaN issue in onnxruntime
+ # a workaround is to use Neg + Sub
+ make_node('Neg', [input_nodes[2]], [name+'_neg']),
+ make_node("Sub", [name+"_mul0_out", name+'_neg'], [name])
]
else:
nodes += [
@@ -4301,7 +4368,7 @@ def convert_RNN(node, **kwargs):
@mx_op.register('_rnn_param_concat')
def convert_rnn_param_concat(node, **kwargs):
- """Map MXNet’s _rnn_param_concat operator
+ """Map MXNet's _rnn_param_concat operator
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)
@@ -4313,3 +4380,31 @@ def convert_rnn_param_concat(node, **kwargs):
]
return nodes
+
+
+@mx_op.register('_contrib_div_sqrt_dim')
+def convert_contrib_div_sqrt_dim(node, **kwargs):
+ """Map MXNet's _contrib_div_sqrt_dim operator
+ """
+ from onnx.helper import make_node
+ name, input_nodes, _ = get_inputs(node, kwargs)
+ input_dtypes = get_input_dtypes(node, kwargs)
+
+ dtype = input_dtypes[0]
+ dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+
+ create_tensor([0], name+'_0', kwargs['initializer'])
+ create_tensor([1], name+'_1', kwargs['initializer'])
+ create_tensor([1], name+'_1_f', kwargs['initializer'], dtype=dtype)
+ nodes = [
+ make_node('Shape', [input_nodes[0]], [name+'_shape']),
+ make_node('Shape', [name+'_shape'], [name+'_dim']),
+ make_node('Sub', [name+'_dim', name+'_1'], [name+'_dim_m1']),
+ make_node('Slice', [name+'_shape', name+'_dim_m1', name+'_dim', name+'_0'], [name+'_c_']),
+ make_node('Cast', [name+'_c_'], [name+'_c'], to=dtype_t),
+ make_node('Sqrt', [name+'_c'], [name+'_c_sqrt']),
+ make_node('Div', [name+'_1_f', name+'_c_sqrt'], [name+'_1_over_c_sqrt']),
+ make_node('Mul', [input_nodes[0], name+'_1_over_c_sqrt'], [name])
+ ]
+
+ return nodes
diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
index e2a8329..ca4c0fe 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -723,7 +723,7 @@ def test_distilbert_inference_onnxruntime(tmp_path, model_name):
@with_seed()
@pytest.mark.parametrize('model_name', [('standard_lstm_lm_200', 200), ('standard_lstm_lm_650', 650),
('standard_lstm_lm_1500', 1500)])
-@pytest.mark.parametrize('seq_length', [16, 32])
+@pytest.mark.parametrize('seq_length', [64, 128])
def test_standard_rnn_lstm_pretrained_inference_onnxruntime(tmp_path, model_name, seq_length):
try:
import gluonnlp as nlp
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
finally:
shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+ tmp_path = str(tmp_path)
+ try:
+ import gluonnlp as nlp
+ dataset = 'WMT2014'
+ ctx = mx.cpu(0)
+ model, _, _ = nlp.model.get_model(
+ name=model_name,
+ ctx=ctx,
+ pretrained=True,
+ dataset_name=dataset)
+
+ model.hybridize(static_alloc=False)
+
+ batch = 7
+ seq_length = 16
+ C_in = 512
+ C_out = 512
+ src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+ step_input = mx.nd.random.uniform(0, 36794, shape=(batch,), dtype='float32')
+ src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+ encoder_outputs, encoder_additional_outputs = model.encode(src,
+ valid_length=src_valid_length)
+
+ decoder_states = model.decoder.init_state_from_encoder(encoder_outputs, src_valid_length)
+
+ step_output, states, additional_outputs = model.decode_step(step_input, decoder_states)
+
+ # skip export of 'decoder' as it's used for training only
+ for component in ['encoder', 'one_step_ahead_decoder', 'src_embed', 'tgt_embed',
+ 'tgt_proj']:
+
+ prefix = "%s/%s" %(tmp_path, component)
+ component = getattr(model, component)
+ component.export(prefix)
+ sym_file = "%s-symbol.json" % prefix
+ params_file = "%s-0000.params" % prefix
+ onnx_file = "%s.onnx" % prefix
+
+ def export_to_onnx(prefix, input_shapes, input_types, **kwargs):
+ sym_file = "%s-symbol.json" % prefix
+ params_file = "%s-0000.params" % prefix
+ onnx_file = "%s.onnx" % prefix
+ return mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, input_types,
+ onnx_file, **kwargs)
+
+ def onnx_runtime_predict(onnx_file, onnx_inputs):
+ ses_opt = onnxruntime.SessionOptions()
+ ses_opt.log_severity_level = 3
+ session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+ input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy())
+ for i in range(len(onnx_inputs)))
+ return session.run(None, input_dict)
+
+ def verify_encoder():
+ inputs = mx.nd.random.uniform(-1, 1, shape=(batch, seq_length, C_in), dtype='float32')
+ valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+ pred = model.encoder(inputs, valid_length=valid_length)
+
+ prefix = "%s/encoder" %tmp_path
+ input_shapes = [(batch, seq_length, C_in), (batch,)]
+ input_types = [np.float32, np.float32]
+ onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+ onnx_inputs = [inputs, valid_length]
+ pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+ assert_almost_equal(pred[0], pred_onx[0])
+
+ def verify_src_embed():
+ src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+ pred = model.src_embed(src)
+
+ prefix = "%s/src_embed" %tmp_path
+ input_shapes = [(batch, seq_length)]
+ input_types = [np.float32]
+ onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+ onnx_inputs = [src]
+ pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+ assert_almost_equal(pred, pred_onx[0])
+
+ def verify_tgt_embed():
+ tgt = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32')
+ pred = model.tgt_embed(tgt)
+
+ prefix = "%s/tgt_embed" %tmp_path
+ input_shapes = [(batch, seq_length)]
+ input_types = [np.float32]
+ onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+ onnx_inputs = [tgt]
+ pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+ assert_almost_equal(pred, pred_onx[0])
+
+ def verify_tgt_proj():
+ decoder_out = mx.nd.random.uniform(0, 512, shape=(batch, seq_length, C_out),
+ dtype='float32')
+ pred = model.tgt_proj(decoder_out)
+
+ prefix = "%s/tgt_proj" %tmp_path
+ input_shapes = [(batch, seq_length, C_out)]
+ input_types = [np.float32]
+ onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+ onnx_inputs = [decoder_out]
+ pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+ assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)
+
+ def verify_one_step_ahead_decoder():
+ prefix = "%s/one_step_ahead_decoder" %tmp_path
+
+        # the exported symbol expects its inputs in a permuted order relative to the
+        # natural (step_input, states[0], states[1]) ordering
+ perm = [2, 0, 1]
+ input_shapes = [(batch, seq_length, C_in), (batch, seq_length, C_out),
+ (batch, seq_length)]
+ input_shapes = [input_shapes[i] for i in perm]
+ dynamic_input_shapes = [(batch, 'seq_length', C_in), (batch, 'seq_length', C_out),
+ (batch, 'seq_length')]
+ dynamic_input_shapes = [dynamic_input_shapes[i] for i in perm]
+ input_types = [np.float32, np.float32, np.float32]
+ # do a dynamic export
+ onnx_file = export_to_onnx(prefix, input_shapes, input_types, dynamic=True,
+ dynamic_input_shapes=dynamic_input_shapes)
+
+ # step 0
+ step_input = mx.nd.random.uniform(-1, 1, shape=(batch, C_in), dtype='float32')
+ # mxnet
+ pred, step_states, _ = model.one_step_ahead_decoder(step_input, decoder_states)
+ # onnx
+ # note that we need to expand the sequence axis just like in here:
+ # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L831
+ input_onx = mx.nd.expand_dims(step_input, axis=1)
+ onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]]
+ onnx_inputs = [onnx_inputs[i] for i in perm]
+ pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+ assert_almost_equal(pred, pred_onx[0])
+
+ # step >= 1
+ for i in range(20):
+ step_input = mx.nd.random.uniform(-10*i, 10*i, shape=(batch, C_in), dtype='float32')
+ # mxnet
+ pred, step_states, _ = model.one_step_ahead_decoder(step_input, step_states)
+ # onnx
+            # note that we need to concat the step_input with the previous inputs
+ # just like in here:
+ # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/src/gluonnlp/model/transformer.py#L828
+ input_onx = mx.nd.concat(input_onx, mx.nd.expand_dims(step_input, axis=1), dim=1)
+ onnx_inputs = [input_onx, decoder_states[0], decoder_states[1]]
+ onnx_inputs = [onnx_inputs[i] for i in perm]
+ pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+ assert_almost_equal(pred, pred_onx[0])
+
+ verify_encoder()
+ verify_src_embed()
+ verify_tgt_embed()
+ verify_tgt_proj()
+ verify_one_step_ahead_decoder()
+
+ finally:
+ shutil.rmtree(tmp_path)
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 36b687a..220f259 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -1234,3 +1234,29 @@ def test_onnx_export_RNN(tmp_path, dtype, state_size, input_size, num_layers, ba
state = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
cell = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype)
op_export_test('rnn', M, [x, param, state, cell], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])
+@pytest.mark.parametrize('shapes', [((3, 3, 3), (1, 3)), ((4, 5, 6, 7), (6, 7))])
+def test_onnx_export_broadcast_lesser_equal(tmp_path, dtype, shapes):
+ A = mx.nd.random.uniform(0, 5, shapes[0]).astype('int32').astype(dtype)
+ B = mx.nd.random.uniform(0, 5, shapes[1]).astype('int32').astype(dtype)
+ M = def_model('broadcast_lesser_equal')
+ op_export_test('broadcast_lesser_equal', M, [A, B], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])
+@pytest.mark.parametrize('shapes', [((3, 3, 3), (1, 3)), ((4, 5, 6, 7), (6, 7))])
+def test_onnx_export_broadcast_greater_equal(tmp_path, dtype, shapes):
+ A = mx.nd.random.uniform(0, 5, shapes[0]).astype('int32').astype(dtype)
+ B = mx.nd.random.uniform(0, 5, shapes[1]).astype('int32').astype(dtype)
+ M = def_model('broadcast_greater_equal')
+ op_export_test('broadcast_greater_equal', M, [A, B], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
+@pytest.mark.parametrize('shape', [(3, 4, 5), (6, 7), (8,)])
+def test_onnx_export_contrib_div_sqrt_dim(tmp_path, dtype, shape):
+ A = mx.nd.random.uniform(-100, 100, shape).astype(dtype)
+ M = def_model('contrib.div_sqrt_dim')
+ op_export_test('contrib_div_sqrt_dim', M, [A], tmp_path)