Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/10/16 20:32:18 UTC

[GitHub] [incubator-tvm] sxjscience opened a new pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

sxjscience opened a new pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699


   Fix the MXNet 2.0 integration in Relay. I tested the BERT and ALBERT models with the new [GluonNLP 1.0](https://github.com/dmlc/gluon-nlp/tree/master), and both pass. I will later add unit tests on the GluonNLP side to ensure that the backbones can run with the graph runtime.
   
   ```python
   import mxnet as mx
   import numpy as np
   import gluonnlp
   from gluonnlp.models import get_backbone
   import numpy.testing as npt
   
   mx.npx.set_np()
   
   model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone('google_albert_base_v2')
   
   model = model_cls.from_cfg(cfg)
   model.load_parameters(backbone_param_path)
   model.hybridize()
   
   
   batch_size = 1
   seq_length = 128
   token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length), dtype=np.int32)
   token_types = mx.np.random.randint(0, 2, (batch_size, seq_length), dtype=np.int32)
   valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,), dtype=np.int32)
   mx_out = model(token_ids, token_types, valid_length)
   
   import tvm
   from tvm import relay
   import tvm.contrib.graph_runtime as runtime
   
   shape_dict = {
       'data0': (batch_size, seq_length),
       'data1': (batch_size, seq_length),
       'data2': (batch_size,)
   }
   
   dtype_dict = {
       'data0': 'int32',
       'data1': 'int32',
       'data2': 'int32'
   }
   
   sym = model._cached_graph[1]
   
   params = {}
   for k, v in model.collect_params().items():
       params[v._var_name] = tvm.nd.array(v.data().asnumpy())
   mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict, dtype=dtype_dict, arg_params=params)
   print(mod)
   # Target an AWS G4 instance (NVIDIA T4 GPU).
   target = "cuda -model=t4"
   
   with relay.build_config(opt_level=3, required_pass=["FastMath"]):
       graph, lib, cparams = relay.build(mod, target, params=params)
   
   ctx = tvm.gpu()
   rt = runtime.create(graph, lib, ctx)
   rt.set_input(**cparams)
   rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
   rt.run()
   for i in range(rt.get_num_outputs()):
       out = rt.get_output(i)
       print(out.asnumpy())
       # Verify correctness against the MXNet reference output.
       npt.assert_allclose(out.asnumpy(), mx_out[i].asnumpy(), rtol=1e-3, atol=1e-2)
   ```
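The final check uses `numpy.testing.assert_allclose`, which passes when `|actual - desired| <= atol + rtol * |desired|` holds elementwise. When a `FastMath` build fails that check, it helps to know how far off the outputs are. The helper below, `report_mismatch`, is hypothetical (not part of TVM, MXNet, or GluonNLP) and only assumes NumPy-convertible arrays; it is a sketch for diagnosis, not part of the PR.

```python
import numpy as np


def report_mismatch(actual, desired, rtol=1e-3, atol=1e-2):
    """Summarize how two arrays differ under assert_allclose's criterion.

    Hypothetical helper: returns the max absolute error, the max relative
    error (guarded against division by zero), and the fraction of elements
    that violate |actual - desired| <= atol + rtol * |desired|.
    """
    actual = np.asarray(actual, dtype=np.float64)
    desired = np.asarray(desired, dtype=np.float64)
    abs_err = np.abs(actual - desired)
    # Avoid dividing by zero where the reference value is exactly 0.
    rel_err = abs_err / np.maximum(np.abs(desired), np.finfo(np.float64).tiny)
    violating = abs_err > atol + rtol * np.abs(desired)
    return {
        "max_abs_err": float(abs_err.max()),
        "max_rel_err": float(rel_err.max()),
        "violation_ratio": float(violating.mean()),
    }
```

In the loop above it could be called as `report_mismatch(out.asnumpy(), mx_out[i].asnumpy())` before the assertion.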


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] jroesch commented on pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

Posted by GitBox <gi...@apache.org>.
jroesch commented on pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699#issuecomment-710677184


   As we add more tests, can we measure how much time this adds to CI? Integration tests are becoming increasingly slow and expensive to run. cc @areusch and @tkonolige 





[GitHub] [incubator-tvm] comaniac commented on pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

comaniac commented on pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699#issuecomment-711433430


   Thanks @sxjscience @yzhliu. The test simplification can be done in follow-up PRs.





[GitHub] [incubator-tvm] sxjscience commented on a change in pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

sxjscience commented on a change in pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699#discussion_r506775666



##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -627,6 +632,21 @@ def _mx_expand_dims(inputs, attrs):
     return _op.expand_dims(inputs[0], axis=axis)
 
 
+def _mx_where(inputs, attrs):

Review comment:
       I removed `_mx_where` and reverted to the old implementation.







[GitHub] [incubator-tvm] comaniac merged pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

comaniac merged pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699


   





[GitHub] [incubator-tvm] sxjscience commented on a change in pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

sxjscience commented on a change in pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699#discussion_r506766247



##########
File path: python/tvm/topi/x86/batch_matmul.py
##########
@@ -157,6 +163,10 @@ def batch_matmul_cblas(cfg, x, y):
     YB, N, YK = get_const_tuple(y.shape)
     assert XB == YB, "batch dimension doesn't match"
     assert XK == YK, "shapes of x and y is inconsistant"
+    if out_shape is not None:
+        assert out_shape[0] == XB, "got invalid output shape"
+        assert out_shape[1] == M, "got invalid output shape"
+        assert out_shape[2] == N, "got invalid output shape"

Review comment:
       The reason is that if we do not add this, running the end-to-end script with `target = "llvm -mcpu=skylake-avx512 -libs=cblas"` will trigger the following error:
   ```python
   TypeError: Traceback (most recent call last):
     [bt] (8) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::backend::GraphRuntimeCodegen::VisitExpr_(tvm::relay::CallNode const*)+0xf12) [0x7f8f383774b2]
     [bt] (7) /home/ubuntu/tvm/build/libtvm.so(+0xf87235) [0x7f8f3834b235]
     [bt] (6) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::CompileEngineImpl::LowerInternal(tvm::relay::CCacheKey const&)+0x8a1) [0x7f8f38355f81]
     [bt] (5) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::ScheduleGetter::Create(tvm::relay::Function const&)+0x25b) [0x7f8f3835265b]
     [bt] (4) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)+0xa9) [0x7f8f38358b89]
     [bt] (3) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x82) [0x7f8f38358952]
     [bt] (2) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)#6}::_FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)+0x27) [0x7f8f3834b717]
     [bt] (1) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*)+0x68c) [0x7f8f3835175c]
     [bt] (0) /home/ubuntu/tvm/build/libtvm.so(+0x112beab) [0x7f8f384efeab]
     File "tvm/_ffi/_cython/./packed_func.pxi", line 55, in tvm._ffi._cy3.core.tvm_callback
     File "/home/ubuntu/tvm/python/tvm/relay/backend/compile_engine.py", line 284, in lower_call
       best_impl, outputs = select_implementation(op, call.attrs, inputs, ret_type, target)
     File "/home/ubuntu/tvm/python/tvm/relay/backend/compile_engine.py", line 206, in select_implementation
       outs = impl.compute(attrs, inputs, out_type)
     File "/home/ubuntu/tvm/python/tvm/relay/op/op.py", line 91, in compute
       return _OpImplementationCompute(self, attrs, inputs, out_type)
     File "tvm/_ffi/_cython/./packed_func.pxi", line 321, in tvm._ffi._cy3.core.PackedFuncBase.__call__
     File "tvm/_ffi/_cython/./packed_func.pxi", line 266, in tvm._ffi._cy3.core.FuncCall
     File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
     [bt] (3) /home/ubuntu/tvm/build/libtvm.so(TVMFuncCall+0x65) [0x7f8f384f3205]
     [bt] (2) /home/ubuntu/tvm/build/libtvm.so(+0x104b8c8) [0x7f8f3840f8c8]
     [bt] (1) /home/ubuntu/tvm/build/libtvm.so(tvm::relay::OpImplementation::Compute(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)+0xb1) [0x7f8f3840f691]
     [bt] (0) /home/ubuntu/tvm/build/libtvm.so(+0x112beab) [0x7f8f384efeab]
     File "tvm/_ffi/_cython/./packed_func.pxi", line 55, in tvm._ffi._cy3.core.tvm_callback
     File "/home/ubuntu/tvm/python/tvm/relay/op/strategy/generic.py", line 686, in _compute_batch_matmul
       return [topi_compute(inputs[0], inputs[1], out_type.shape)]
     File "/home/ubuntu/tvm/python/tvm/autotvm/task/topi_integration.py", line 162, in wrapper
       node = topi_compute(cfg, *args)
   TypeError: batch_matmul_cblas() takes 3 positional arguments but 4 were given
   ```
   
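   The root cause is that the generic strategy passes the output shape to every batch_matmul compute as an extra argument (`topi_compute(inputs[0], inputs[1], out_type.shape)`), so `batch_matmul_cblas` has to accept it: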
   https://github.com/apache/incubator-tvm/blob/461e75bd5ffaf45a0f270998514d444463d11261/python/tvm/relay/op/strategy/generic.py#L685-L686
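The failure can be reproduced without TVM: the strategy calls the compute with three data arguments (`x`, `y`, `out_type.shape`) and the autotvm wrapper prepends `cfg` (`topi_compute(cfg, *args)` in the traceback), so a compute declared as `(cfg, x, y)` receives one argument too many. The sketch below uses hypothetical stand-ins (`compute_old`, `compute_new`, `call_like_strategy`) with NumPy arrays in place of TE tensors, and assumes the batch_matmul convention visible in the diff: `x` is `(B, M, K)`, `y` is `(B, N, K)`, and the result is `x @ y^T` with shape `(B, M, N)`.

```python
import numpy as np


def compute_old(cfg, x, y):
    # Old signature: no slot for the output shape.
    return np.matmul(x, np.transpose(y, (0, 2, 1)))


def compute_new(cfg, x, y, out_shape=None):
    # New signature: accept, and validate, an optional output shape,
    # mirroring the asserts added in this diff.
    XB, M, XK = x.shape
    YB, N, YK = y.shape
    assert XB == YB, "batch dimension doesn't match"
    assert XK == YK, "shapes of x and y are inconsistent"
    if out_shape is not None:
        assert tuple(out_shape) == (XB, M, N), "got invalid output shape"
    return np.matmul(x, np.transpose(y, (0, 2, 1)))


def call_like_strategy(compute, x, y, out_shape):
    # Mimics strategy + autotvm wrapper: topi_compute(cfg, x, y, out_shape).
    cfg = None  # stand-in for the autotvm config object
    args = (x, y, out_shape)
    return compute(cfg, *args)
```

Calling `call_like_strategy(compute_old, ...)` raises the same kind of `TypeError` as in the traceback, while `compute_new` returns the `(B, M, N)` result.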







[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

comaniac commented on a change in pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699#discussion_r506769623



##########
File path: python/tvm/topi/x86/batch_matmul.py
##########
@@ -157,6 +163,10 @@ def batch_matmul_cblas(cfg, x, y):
     YB, N, YK = get_const_tuple(y.shape)
     assert XB == YB, "batch dimension doesn't match"
     assert XK == YK, "shapes of x and y is inconsistant"
+    if out_shape is not None:
+        assert out_shape[0] == XB, "got invalid output shape"
+        assert out_shape[1] == M, "got invalid output shape"
+        assert out_shape[2] == N, "got invalid output shape"

Review comment:
       This is interesting... I searched all the batch_matmul computes and found that this is the only one missing that argument, which means this compute has never been used before.







[GitHub] [incubator-tvm] sxjscience commented on pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

sxjscience commented on pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699#issuecomment-711111989


   I've verified the TVM integration with 5 NLP backbones in GluonNLP: BERT, ALBERT, ELECTRA, RoBERTa, and BART.
   
   ```python
   import mxnet as mx
   import numpy as np
   import gluonnlp
   from gluonnlp.models import get_backbone
   import numpy.testing as npt
   import tvm
   from tvm import relay
   import tvm.contrib.graph_runtime as runtime
   
   
   mx.npx.set_np()
   
   instance_info = {
       'g4': {'target': "cuda -model=t4", 'use_gpu': True},
       'c4': {'target': 'llvm -mcpu=core-avx2 -libs=cblas', 'use_gpu': False},
       'c5': {'target': 'llvm -mcpu=skylake-avx512 -libs=cblas', 'use_gpu': False},
       'p3': {'target': 'cuda -model=v100', 'use_gpu': True}
   }
   
   
   def test_backbone(model_name, batch_size=2, seq_length=128, instance='g4',
                     required_pass=None, opt_level=3):
       if required_pass is None:
           required_pass = ["FastMath"]
       model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(model_name)
       model = model_cls.from_cfg(cfg)
       model.load_parameters(backbone_param_path)
       model.hybridize()
       token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length), dtype=np.int32)
       token_types = mx.np.random.randint(0, 2, (batch_size, seq_length), dtype=np.int32)
       valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,), dtype=np.int32)
       if 'bart' in model_name:
           mx_out = model(token_ids, valid_length, token_ids, valid_length)
           shape_dict = {
               'data0': token_ids.shape,
               'data1': valid_length.shape,
               'data2': token_ids.shape,
               'data3': valid_length.shape,
           }
           dtype_dict = {
               'data0': token_ids.dtype.name,
               'data1': valid_length.dtype.name,
               'data2': token_ids.dtype.name,
               'data3': valid_length.dtype.name,
           }
       elif 'roberta' in model_name or 'xlmr' in model_name:
           mx_out = model(token_ids, valid_length)
           shape_dict = {
               'data0': token_ids.shape,
               'data1': valid_length.shape,
           }
           dtype_dict = {
               'data0': token_ids.dtype.name,
               'data1': valid_length.dtype.name,
           }
       else:
           mx_out = model(token_ids, token_types, valid_length)
           shape_dict = {
               'data0': token_ids.shape,
               'data1': token_types.shape,
               'data2': valid_length.shape
           }
           dtype_dict = {
               'data0': token_ids.dtype.name,
               'data1': token_types.dtype.name,
               'data2': valid_length.dtype.name
           }
       sym = model._cached_graph[1]
       params = {}
       for k, v in model.collect_params().items():
           params[v._var_name] = tvm.nd.array(v.data().asnumpy())
       mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict, dtype=dtype_dict, arg_params=params)
       target = instance_info[instance]['target']
       use_gpu = instance_info[instance]['use_gpu']
       with relay.build_config(opt_level=opt_level, required_pass=required_pass):
           graph, lib, cparams = relay.build(mod, target, params=params)
       if use_gpu:
           ctx = tvm.gpu()
       else:
           ctx = tvm.cpu()
       rt = runtime.create(graph, lib, ctx)
       rt.set_input(**cparams)
       if 'bart' in model_name:
           rt.set_input(data0=token_ids, data1=valid_length, data2=token_ids, data3=valid_length)
       elif 'roberta' in model_name or 'xlmr' in model_name:
           rt.set_input(data0=token_ids, data1=valid_length)
       else:
           rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
       rt.run()
       for i in range(rt.get_num_outputs()):
           out = rt.get_output(i)
           if rt.get_num_outputs() == 1:
               mx_out_gt = mx_out.asnumpy()
           else:
               mx_out_gt = mx_out[i].asnumpy()
           if 'mobilebert' in model_name and len(out.shape) == 3:
               npt.assert_allclose(out.asnumpy()[:, 1:, :], mx_out_gt[:, 1:, :],
                                   rtol=6e-2, atol=6e-2)
           else:
               npt.assert_allclose(out.asnumpy(), mx_out_gt, rtol=6e-2, atol=6e-2)
   # test_backbone('google_en_cased_bert_base', instance='g4')
   test_model_names = ['google_albert_base_v2',
                       'google_en_cased_bert_base',
                       'google_electra_small',
                       'google_uncased_mobilebert',
                       'fairseq_roberta_base',
                       'fairseq_bart_base']
   for model_name in test_model_names:
       test_backbone(model_name, instance='g4')
   
   ```
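The three branches in `test_backbone` differ only in which arrays are fed and in what order. Assuming, as the script does, that the hybridized graph names its inputs `data0`, `data1`, ... in call order, a hypothetical helper like `make_input_dicts` below could derive the shape, dtype, and runtime-feed dictionaries from a single list, which keeps `set_input` in sync with the dictionaries passed to `relay.frontend.from_mxnet`. This is a sketch, not code from the PR.

```python
import numpy as np


def make_input_dicts(inputs, prefix="data"):
    """Hypothetical helper: name positional inputs prefix0, prefix1, ...

    Returns (shape_dict, dtype_dict, feed_dict); the first two are meant for
    relay.frontend.from_mxnet(shape=..., dtype=...), the last for
    rt.set_input(**feed_dict).
    """
    names = ["%s%d" % (prefix, i) for i in range(len(inputs))]
    shape_dict = {n: tuple(a.shape) for n, a in zip(names, inputs)}
    dtype_dict = {n: a.dtype.name for n, a in zip(names, inputs)}
    feed_dict = dict(zip(names, inputs))
    return shape_dict, dtype_dict, feed_dict
```

For BART, for instance, `inputs` would be `[token_ids, valid_length, token_ids, valid_length]`, followed by `mx_out = model(*inputs)` and later `rt.set_input(**feed_dict)`.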





[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

comaniac commented on a change in pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699#discussion_r506718116



##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -58,6 +58,11 @@
 _activation_map = {"sigmoid": _op.sigmoid, "tanh": _op.tanh, "relu": _op.nn.relu}
 
 
+def get_tuple_shape(shape_expr):

Review comment:
       Can we directly use `topi.util.get_const_tuple`?

##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -2312,23 +2345,76 @@ def _mx_npx_reshape(inputs, attrs):
     reverse = attrs.get_bool("reverse", False)
     shape_list = list(shape)
     new_shape_list = []
-    for num in shape_list:
-        if num > 0 or num == -1:
-            new_shape_list.append(num)
-        elif num == -2:
-            new_shape_list.append(0)
-        elif num == -4:
-            new_shape_list.append(-2)
-        elif num == -5:
-            new_shape_list.append(-3)
-        elif num == -6:
-            new_shape_list.append(-4)
-        else:
-            raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
-    shape = tuple(new_shape_list)
-    if reverse:
-        return _op.reverse_reshape(inputs[0], newshape=shape)
-    return _op.reshape(inputs[0], newshape=shape)
+    if -3 not in shape_list:
+        for num in shape_list:
+            if num > 0 or num == -1:
+                new_shape_list.append(num)
+            elif num == -2:
+                new_shape_list.append(0)
+            elif num == -4:
+                new_shape_list.append(-2)
+            elif num == -5:
+                new_shape_list.append(-3)
+            elif num == -6:
+                new_shape_list.append(-4)

Review comment:
       ```suggestion
               elif num in [-2, -4, -5, -6]:
                   new_shape_list.append(num + 2)
   ```
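The suggestion works because each of the four handled MXNet codes maps to the Relay code exactly two above it (-2 -> 0, -4 -> -2, -5 -> -3, -6 -> -4). The hypothetical standalone function below checks that identity against the explicit branches in the diff; it is not the frontend's actual code. Note that `-3` must stay excluded: `-3 + 2 = -1` would silently turn "drop a size-1 axis" into Relay's "infer this axis".

```python
# Explicit mapping taken from the branches in _mx_npx_reshape above.
EXPLICIT = {-2: 0, -4: -2, -5: -3, -6: -4}


def to_relay_code(num):
    """Hypothetical helper: translate one npx.reshape entry to Relay's reshape."""
    if num > 0 or num == -1:
        return num  # literal sizes and "infer" pass through unchanged
    if num in (-2, -4, -5, -6):
        return num + 2
    raise ValueError("Shape dimension %d is not supported" % num)


assert all(to_relay_code(k) == v for k, v in EXPLICIT.items())
```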

##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -2312,23 +2345,76 @@ def _mx_npx_reshape(inputs, attrs):
     reverse = attrs.get_bool("reverse", False)
     shape_list = list(shape)
     new_shape_list = []
-    for num in shape_list:
-        if num > 0 or num == -1:
-            new_shape_list.append(num)
-        elif num == -2:
-            new_shape_list.append(0)
-        elif num == -4:
-            new_shape_list.append(-2)
-        elif num == -5:
-            new_shape_list.append(-3)
-        elif num == -6:
-            new_shape_list.append(-4)
-        else:
-            raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
-    shape = tuple(new_shape_list)
-    if reverse:
-        return _op.reverse_reshape(inputs[0], newshape=shape)
-    return _op.reshape(inputs[0], newshape=shape)
+    if -3 not in shape_list:
+        for num in shape_list:
+            if num > 0 or num == -1:
+                new_shape_list.append(num)
+            elif num == -2:
+                new_shape_list.append(0)
+            elif num == -4:
+                new_shape_list.append(-2)
+            elif num == -5:
+                new_shape_list.append(-3)
+            elif num == -6:
+                new_shape_list.append(-4)
+            else:
+                raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
+        shape = tuple(new_shape_list)
+        if reverse:
+            return _op.reverse_reshape(inputs[0], newshape=shape)
+        return _op.reshape(inputs[0], newshape=shape)
+    else:
+        old_shape = get_tuple_shape(_infer_type(inputs[0]).checked_type.shape)
+        new_shape = []
+        if reverse:
+            old_shape = old_shape[::-1]
+            shape_list = shape_list[::-1]
+        ptr = 0
+        unknown_axis = None
+        src_ptr = 0
+        while src_ptr < len(shape_list):
+            ele = shape_list[src_ptr]
+            src_ptr += 1
+            if ele > 0:
+                new_shape.append(ele)
+                ptr += 1
+            elif ele == -1:
+                new_shape.append(-1)
+                assert unknown_axis is None, "Can only have one unknown axis."
+                unknown_axis = len(new_shape)
+                ptr += 1
+            elif ele == -2:
+                new_shape.append(old_shape[ptr])
+                ptr += 1
+            elif ele == -3:
+                assert old_shape[ptr] == 1

Review comment:
       It would be better to include an error message. The same goes for the other asserts.

##########
File path: python/tvm/topi/x86/batch_matmul.py
##########
@@ -157,6 +163,10 @@ def batch_matmul_cblas(cfg, x, y):
     YB, N, YK = get_const_tuple(y.shape)
     assert XB == YB, "batch dimension doesn't match"
     assert XK == YK, "shapes of x and y is inconsistant"
+    if out_shape is not None:
+        assert out_shape[0] == XB, "got invalid output shape"
+        assert out_shape[1] == M, "got invalid output shape"
+        assert out_shape[2] == N, "got invalid output shape"

Review comment:
       Why do we need an additional output shape argument if we can infer it in this function?

##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -2312,23 +2345,76 @@ def _mx_npx_reshape(inputs, attrs):
     reverse = attrs.get_bool("reverse", False)
     shape_list = list(shape)
     new_shape_list = []
-    for num in shape_list:
-        if num > 0 or num == -1:
-            new_shape_list.append(num)
-        elif num == -2:
-            new_shape_list.append(0)
-        elif num == -4:
-            new_shape_list.append(-2)
-        elif num == -5:
-            new_shape_list.append(-3)
-        elif num == -6:
-            new_shape_list.append(-4)
-        else:
-            raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
-    shape = tuple(new_shape_list)
-    if reverse:
-        return _op.reverse_reshape(inputs[0], newshape=shape)
-    return _op.reshape(inputs[0], newshape=shape)
+    if -3 not in shape_list:

Review comment:
       It would be better to add a comment explaining why `-3` needs special handling.
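One way to see why `-3` is special: Relay's `reshape` has codes for copying one axis (0), inferring (-1), copying the rest (-2), merging two axes (-3), and splitting (-4), but none for "drop this input axis if its size is 1". Resolving `-3` therefore needs the static input shape, which the `else` branch obtains through `_infer_type`. The plain-Python sketch below expands a spec against a known input shape. The function name is hypothetical, the meaning of each code is assumed from MXNet's `npx.reshape` documentation, and, like the PR's loop, it advances the input pointer for literal sizes and `-1`; it ignores `reverse`.

```python
def resolve_npx_reshape(old_shape, spec):
    """Hypothetical helper: expand an npx.reshape spec into concrete sizes.

    Assumed code semantics (MXNet npx.reshape):
      >0  literal size          -1  infer (left for the backend)
      -2  copy one input dim    -3  drop an input dim of size 1
      -4  copy all remaining    -5  merge two consecutive input dims
      -6  split one input dim into the two sizes that follow it
    """
    new_shape = []
    src = 0  # position in old_shape
    i = 0    # position in spec
    while i < len(spec):
        code = spec[i]
        i += 1
        if code > 0 or code == -1:
            new_shape.append(code)
            src += 1
        elif code == -2:
            new_shape.append(old_shape[src])
            src += 1
        elif code == -3:
            if old_shape[src] != 1:
                raise ValueError("-3 expects a size-1 axis, got %d" % old_shape[src])
            src += 1  # consume the axis without emitting anything
        elif code == -4:
            new_shape.extend(old_shape[src:])
            src = len(old_shape)
        elif code == -5:
            new_shape.append(old_shape[src] * old_shape[src + 1])
            src += 2
        elif code == -6:
            d1, d2 = spec[i], spec[i + 1]
            i += 2
            dim = old_shape[src]
            if d1 == -1:
                d1 = dim // d2
            elif d2 == -1:
                d2 = dim // d1
            if d1 * d2 != dim:
                raise ValueError("cannot split %d into %d x %d" % (dim, d1, d2))
            new_shape.extend([d1, d2])
            src += 1
        else:
            raise ValueError("Unsupported shape code %d" % code)
    if new_shape.count(-1) > 1:
        raise ValueError("Can only have one unknown axis.")
    return new_shape
```

For example, an input of shape `(2, 1, 3, 4)` with spec `[-2, -3, -5]` resolves to `[2, 12]`.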

##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -627,6 +632,21 @@ def _mx_expand_dims(inputs, attrs):
     return _op.expand_dims(inputs[0], axis=axis)
 
 
+def _mx_where(inputs, attrs):

Review comment:
       Should we add a unit test for this newly added op in the MXNet frontend?







[GitHub] [incubator-tvm] sxjscience edited a comment on pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

sxjscience edited a comment on pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699#issuecomment-711111989


   I've verified the TVM integration with 5 NLP backbones in GluonNLP: BERT, ALBERT, ELECTRA, RoBERTa, and BART.
   
   ```python
   import mxnet as mx
   import numpy as np
   import gluonnlp
   from gluonnlp.models import get_backbone
   import numpy.testing as npt
   import tvm
   from tvm import relay
   import tvm.contrib.graph_runtime as runtime
   
   
   mx.npx.set_np()
   
   instance_info = {
       'g4': {'target': "cuda -model=t4", 'use_gpu': True},
       'c4': {'target': 'llvm -mcpu=core-avx2 -libs=cblas', 'use_gpu': False},
       'c5': {'target': 'llvm -mcpu=skylake-avx512 -libs=cblas', 'use_gpu': False},
       'p3': {'target': 'cuda -model=v100', 'use_gpu': True}
   }
   
   
   def test_backbone(model_name, batch_size=2, seq_length=128, instance='g4',
                     required_pass=None, opt_level=3):
       if required_pass is None:
           required_pass = ["FastMath"]
       model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(model_name)
       model = model_cls.from_cfg(cfg)
       model.load_parameters(backbone_param_path)
       model.hybridize()
       token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length), dtype=np.int32)
       token_types = mx.np.random.randint(0, 2, (batch_size, seq_length), dtype=np.int32)
       valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,), dtype=np.int32)
       if 'bart' in model_name:
           mx_out = model(token_ids, valid_length, token_ids, valid_length)
           shape_dict = {
               'data0': token_ids.shape,
               'data1': valid_length.shape,
               'data2': token_ids.shape,
               'data3': valid_length.shape,
           }
           dtype_dict = {
               'data0': token_ids.dtype.name,
               'data1': valid_length.dtype.name,
               'data2': token_ids.dtype.name,
               'data3': valid_length.dtype.name,
           }
       elif 'roberta' in model_name or 'xlmr' in model_name:
           mx_out = model(token_ids, valid_length)
           shape_dict = {
               'data0': token_ids.shape,
               'data1': valid_length.shape,
           }
           dtype_dict = {
               'data0': token_ids.dtype.name,
               'data1': valid_length.dtype.name,
           }
       else:
           mx_out = model(token_ids, token_types, valid_length)
           shape_dict = {
               'data0': token_ids.shape,
               'data1': token_types.shape,
               'data2': valid_length.shape
           }
           dtype_dict = {
               'data0': token_ids.dtype.name,
               'data1': token_types.dtype.name,
               'data2': valid_length.dtype.name
           }
       sym = model._cached_graph[1]
       params = {}
       for k, v in model.collect_params().items():
           params[v._var_name] = tvm.nd.array(v.data().asnumpy())
       mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict, dtype=dtype_dict, arg_params=params)
       target = instance_info[instance]['target']
       use_gpu = instance_info[instance]['use_gpu']
       with relay.build_config(opt_level=opt_level, required_pass=required_pass):
           graph, lib, cparams = relay.build(mod, target, params=params)
       if use_gpu:
           ctx = tvm.gpu()
       else:
           ctx = tvm.cpu()
       rt = runtime.create(graph, lib, ctx)
       rt.set_input(**cparams)
       if 'bart' in model_name:
           rt.set_input(data0=token_ids, data1=valid_length, data2=token_ids, data3=valid_length)
       elif 'roberta' in model_name or 'xlmr' in model_name:
           rt.set_input(data0=token_ids, data1=valid_length)
       else:
           rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
       rt.run()
       for i in range(rt.get_num_outputs()):
           out = rt.get_output(i)
           if rt.get_num_outputs() == 1:
               mx_out_gt = mx_out.asnumpy()
           else:
               mx_out_gt = mx_out[i].asnumpy()
           if 'mobilebert' in model_name and len(out.shape) == 3:
               npt.assert_allclose(out.asnumpy()[:, 1:, :], mx_out_gt[:, 1:, :],
                                   rtol=6e-2, atol=6e-2)
           else:
               npt.assert_allclose(out.asnumpy(), mx_out_gt, rtol=6e-2, atol=6e-2)
   # test_backbone('google_en_cased_bert_base', instance='g4')
   test_model_names = ['google_albert_base_v2',
                       'google_en_cased_bert_base',
                       'google_electra_small',
                       'fairseq_roberta_base',
                       'fairseq_bart_base']
   for model_name in test_model_names:
       test_backbone(model_name, instance='g4')
   
   ```





[GitHub] [incubator-tvm] sxjscience commented on pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

sxjscience commented on pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699#issuecomment-710579653


   @yzhliu @comaniac @icemelon9 





[GitHub] [incubator-tvm] sxjscience commented on pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

sxjscience commented on pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699#issuecomment-710693179


   The integration tests take a very long time because there are too many parameter combinations. For example: https://github.com/apache/incubator-tvm/blob/461e75bd5ffaf45a0f270998514d444463d11261/tests/python/frontend/mxnet/test_forward.py#L2119-L2125
   
   We could simplify the tests by not testing the full Cartesian product of the parameters.
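The difference in test count is easy to quantify. With the full Cartesian product, the number of cases is the product of the option counts; an "each choice" scheme, which only guarantees that every value of every parameter appears in at least one case, needs only as many cases as the longest option list. The parameter names and values below are hypothetical placeholders, not the ones in `test_forward.py`.

```python
import itertools

# Hypothetical parameter space; the real test_forward.py lists differ.
PARAMS = {
    "dtype": ["float32", "float64"],
    "shape": [(2, 3), (4, 5), (1, 7)],
    "target": ["llvm", "cuda"],
    "kind": ["graph", "debug", "vm"],
}


def full_product(params):
    """Every combination: len(a) * len(b) * ... cases."""
    return list(itertools.product(*params.values()))


def each_choice(params):
    """Each value appears at least once: max(len(v)) cases."""
    n = max(len(v) for v in params.values())
    return [tuple(v[i % len(v)] for v in params.values()) for i in range(n)]
```

Here the full product yields 2 x 3 x 2 x 3 = 36 cases, while `each_choice` yields 3. The trade-off is that interactions between parameters are no longer exercised, so combinations known to be fragile would still need explicit cases; pairwise generation is a common middle ground.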





[GitHub] [incubator-tvm] sxjscience commented on a change in pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

sxjscience commented on a change in pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699#discussion_r506769946



##########
File path: python/tvm/topi/x86/batch_matmul.py
##########
@@ -157,6 +163,10 @@ def batch_matmul_cblas(cfg, x, y):
     YB, N, YK = get_const_tuple(y.shape)
     assert XB == YB, "batch dimension doesn't match"
     assert XK == YK, "shapes of x and y is inconsistant"
+    if out_shape is not None:
+        assert out_shape[0] == XB, "got invalid output shape"
+        assert out_shape[1] == M, "got invalid output shape"
+        assert out_shape[2] == N, "got invalid output shape"

Review comment:
       Yes, I hit this while following the blog post https://medium.com/apache-mxnet/speed-up-your-bert-inference-by-3x-on-cpus-using-apache-tvm-9cf7776cd7f8 .



