You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/10/16 21:08:58 UTC

[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6699: [Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP

comaniac commented on a change in pull request #6699:
URL: https://github.com/apache/incubator-tvm/pull/6699#discussion_r506718116



##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -58,6 +58,11 @@
 _activation_map = {"sigmoid": _op.sigmoid, "tanh": _op.tanh, "relu": _op.nn.relu}
 
 
+def get_tuple_shape(shape_expr):

Review comment:
       Can we directly use `topi.util.get_const_tuple`?

##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -2312,23 +2345,76 @@ def _mx_npx_reshape(inputs, attrs):
     reverse = attrs.get_bool("reverse", False)
     shape_list = list(shape)
     new_shape_list = []
-    for num in shape_list:
-        if num > 0 or num == -1:
-            new_shape_list.append(num)
-        elif num == -2:
-            new_shape_list.append(0)
-        elif num == -4:
-            new_shape_list.append(-2)
-        elif num == -5:
-            new_shape_list.append(-3)
-        elif num == -6:
-            new_shape_list.append(-4)
-        else:
-            raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
-    shape = tuple(new_shape_list)
-    if reverse:
-        return _op.reverse_reshape(inputs[0], newshape=shape)
-    return _op.reshape(inputs[0], newshape=shape)
+    if -3 not in shape_list:
+        for num in shape_list:
+            if num > 0 or num == -1:
+                new_shape_list.append(num)
+            elif num == -2:
+                new_shape_list.append(0)
+            elif num == -4:
+                new_shape_list.append(-2)
+            elif num == -5:
+                new_shape_list.append(-3)
+            elif num == -6:
+                new_shape_list.append(-4)

Review comment:
       ```suggestion
               elif num in [-2, -4, -5, -6]:
                   new_shape_list.append(num + 2)
   ```

##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -2312,23 +2345,76 @@ def _mx_npx_reshape(inputs, attrs):
     reverse = attrs.get_bool("reverse", False)
     shape_list = list(shape)
     new_shape_list = []
-    for num in shape_list:
-        if num > 0 or num == -1:
-            new_shape_list.append(num)
-        elif num == -2:
-            new_shape_list.append(0)
-        elif num == -4:
-            new_shape_list.append(-2)
-        elif num == -5:
-            new_shape_list.append(-3)
-        elif num == -6:
-            new_shape_list.append(-4)
-        else:
-            raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
-    shape = tuple(new_shape_list)
-    if reverse:
-        return _op.reverse_reshape(inputs[0], newshape=shape)
-    return _op.reshape(inputs[0], newshape=shape)
+    if -3 not in shape_list:
+        for num in shape_list:
+            if num > 0 or num == -1:
+                new_shape_list.append(num)
+            elif num == -2:
+                new_shape_list.append(0)
+            elif num == -4:
+                new_shape_list.append(-2)
+            elif num == -5:
+                new_shape_list.append(-3)
+            elif num == -6:
+                new_shape_list.append(-4)
+            else:
+                raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
+        shape = tuple(new_shape_list)
+        if reverse:
+            return _op.reverse_reshape(inputs[0], newshape=shape)
+        return _op.reshape(inputs[0], newshape=shape)
+    else:
+        old_shape = get_tuple_shape(_infer_type(inputs[0]).checked_type.shape)
+        new_shape = []
+        if reverse:
+            old_shape = old_shape[::-1]
+            shape_list = shape_list[::-1]
+        ptr = 0
+        unknown_axis = None
+        src_ptr = 0
+        while src_ptr < len(shape_list):
+            ele = shape_list[src_ptr]
+            src_ptr += 1
+            if ele > 0:
+                new_shape.append(ele)
+                ptr += 1
+            elif ele == -1:
+                new_shape.append(-1)
+                assert unknown_axis is None, "Can only have one unknown axis."
+                unknown_axis = len(new_shape)
+                ptr += 1
+            elif ele == -2:
+                new_shape.append(old_shape[ptr])
+                ptr += 1
+            elif ele == -3:
+                assert old_shape[ptr] == 1

Review comment:
       It would be better to include an error message here. The same applies to the other asserts.

##########
File path: python/tvm/topi/x86/batch_matmul.py
##########
@@ -157,6 +163,10 @@ def batch_matmul_cblas(cfg, x, y):
     YB, N, YK = get_const_tuple(y.shape)
     assert XB == YB, "batch dimension doesn't match"
     assert XK == YK, "shapes of x and y is inconsistant"
+    if out_shape is not None:
+        assert out_shape[0] == XB, "got invalid output shape"
+        assert out_shape[1] == M, "got invalid output shape"
+        assert out_shape[2] == N, "got invalid output shape"

Review comment:
       Why do we need an additional output shape argument if we can figure it out in this function?

##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -2312,23 +2345,76 @@ def _mx_npx_reshape(inputs, attrs):
     reverse = attrs.get_bool("reverse", False)
     shape_list = list(shape)
     new_shape_list = []
-    for num in shape_list:
-        if num > 0 or num == -1:
-            new_shape_list.append(num)
-        elif num == -2:
-            new_shape_list.append(0)
-        elif num == -4:
-            new_shape_list.append(-2)
-        elif num == -5:
-            new_shape_list.append(-3)
-        elif num == -6:
-            new_shape_list.append(-4)
-        else:
-            raise tvm.error.OpAttributeInvalid("Shape dimension %d is not supported" % num)
-    shape = tuple(new_shape_list)
-    if reverse:
-        return _op.reverse_reshape(inputs[0], newshape=shape)
-    return _op.reshape(inputs[0], newshape=shape)
+    if -3 not in shape_list:

Review comment:
       It would be better to add a comment explaining why `-3` needs special handling.

##########
File path: python/tvm/relay/frontend/mxnet.py
##########
@@ -627,6 +632,21 @@ def _mx_expand_dims(inputs, attrs):
     return _op.expand_dims(inputs[0], axis=axis)
 
 
+def _mx_where(inputs, attrs):

Review comment:
       Should we have a unit test for this newly added op in the MXNet frontend?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org