Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/10/26 15:08:12 UTC

[GitHub] [incubator-tvm] giuseros commented on a change in pull request #6739: [Relay] Fix dynamic case for Squeeze and Split

giuseros commented on a change in pull request #6739:
URL: https://github.com/apache/incubator-tvm/pull/6739#discussion_r512026785



##########
File path: python/tvm/relay/op/_transform.py
##########
@@ -575,11 +575,14 @@ def transpose_shape_func(attrs, inputs, _):
 
 
 @script
-def _squeeze_shape_func(data_shape, keep_axes):
+def _squeeze_shape_func(data_shape, keep_axes, remove_axes):
     out = output_tensor((len(keep_axes),), "int64")
     for i in const_range(len(keep_axes)):
         out[i] = data_shape[keep_axes[i]]
 
+    for i in const_range(len(remove_axes)):
+        assert data_shape[remove_axes[i]] == 1, "Removed dimension must have size 1"
+

Review comment:
       Are you using `remove_axes` only for the correctness check? If so, is it worth passing them?
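A plain-Python sketch of what the patched shape function computes may help weigh this question. It is not TVM hybrid-script code: `squeeze_output_shape` is a hypothetical helper, and it assumes `data_shape` is a tuple of ints and `axis` holds non-negative indices. The output shape comes from `keep_axes` alone; `remove_axes` only drives the size-1 check.

```python
def squeeze_output_shape(data_shape, axis):
    """Hypothetical plain-Python mirror of the patched shape function.

    Assumes `axis` is a tuple of non-negative ints (the `axis is None`
    branch of the real shape function is not modeled here).
    """
    ndim = len(data_shape)
    keep_axes = [i for i in range(ndim) if i not in axis]
    remove_axes = [i for i in range(ndim) if i in axis]

    # The output shape is determined by keep_axes alone.
    out = tuple(data_shape[i] for i in keep_axes)

    # remove_axes only feeds a validation step: every squeezed dim must be 1.
    # (The hybrid script uses `assert`; ValueError is used here so the check
    # is not stripped when Python runs with -O.)
    for i in remove_axes:
        if data_shape[i] != 1:
            raise ValueError("Removed dimension must have size 1")
    return out


print(squeeze_output_shape((1, 3, 1, 4), axis=(0, 2)))  # (3, 4)
```

Without the second loop, a call such as `squeeze_output_shape((2, 3), axis=(0,))` would silently return `(3,)`; with it, the mismatch is reported. That is the trade-off being asked about.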

##########
File path: python/tvm/relay/op/_transform.py
##########
@@ -590,18 +593,21 @@ def squeeze_shape_func(attrs, inputs, _):
     """
     axis = attrs.axis if attrs.axis is None else get_const_tuple(attrs.axis)
     keep_axes = []
+    remove_axes = []
     if axis is not None:
         for i in range(inputs[0].shape[0].value):
             if i not in axis:
                 keep_axes.append(i)
+            else:
+                remove_axes.append(i)

Review comment:
       This could be written in a more concise form like `keep_axes.append(i) if i not in axis else remove_axes.append(i)`.
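The suggested one-liner evaluates a conditional expression purely for its side effect; it builds the same two lists as the explicit `if`/`else`. The sketch below compares the two forms and adds a comprehension-based variant as one possible alternative (the `ndim` and `axis` values are illustrative assumptions, not taken from the patch).

```python
ndim = 4
axis = (0, 2)

# Form used in the patch: explicit if/else.
keep_a, remove_a = [], []
for i in range(ndim):
    if i not in axis:
        keep_a.append(i)
    else:
        remove_a.append(i)

# Reviewer's suggestion: a conditional expression used only for its side effect.
keep_b, remove_b = [], []
for i in range(ndim):
    keep_b.append(i) if i not in axis else remove_b.append(i)

# Another option: two comprehensions, no side effects inside expressions.
keep_c = [i for i in range(ndim) if i not in axis]
remove_c = [i for i in range(ndim) if i in axis]

assert keep_a == keep_b == keep_c == [1, 3]
assert remove_a == remove_b == remove_c == [0, 2]
```

The comprehension variant iterates over the axes twice, which is negligible for tensor ranks but does avoid mixing control flow into an expression statement.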

##########
File path: python/tvm/relay/op/_transform.py
##########
@@ -590,18 +593,21 @@ def squeeze_shape_func(attrs, inputs, _):
     """
     axis = attrs.axis if attrs.axis is None else get_const_tuple(attrs.axis)
     keep_axes = []
+    remove_axes = []

Review comment:
       Isn't `remove_axes` the same as `axis`?
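By construction, the loop in the patch appends indices in increasing order and only those in `range(ndim)`, so `remove_axes` matches `axis` exactly when `axis` is already sorted, duplicate-free, and non-negative, and can differ otherwise. Whether `axis` is normalized before this point is not shown in the diff; `build_remove_axes` below is a hypothetical helper mirroring that loop, and the inputs are illustrative.

```python
def build_remove_axes(ndim, axis):
    """Mirror of the patch's loop: collect indices in range(ndim) found in axis."""
    return [i for i in range(ndim) if i in axis]


# Sorted, non-negative axes: remove_axes matches axis element for element.
assert build_remove_axes(4, (0, 2)) == [0, 2]

# Unsorted axes: same elements, but returned in increasing order.
print(build_remove_axes(4, (2, 0)))  # [0, 2]

# Negative axes never match `i in axis`, so they are silently dropped.
print(build_remove_axes(4, (-1,)))   # []
```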

##########
File path: python/tvm/relay/op/_transform.py
##########
@@ -705,6 +711,9 @@ def split_shape_func(attrs, inputs, _):
 
     axis = get_const_int(attrs.axis)
 
+    if axis < 0:
+        axis += inputs[0].shape[0]

Review comment:
       Don't you need to call `get_const_int(inputs[0].shape[0])` here?
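In the squeeze hunk, the same expression is read with `.value` (`inputs[0].shape[0].value`), which suggests it is a constant-integer node rather than a Python `int`; converting it first would keep `axis` a plain integer. The normalization itself is sketched below in plain Python, with `ndim` standing in for that converted value. `normalize_axis` and its range check are hypothetical additions, not part of the patch.

```python
def normalize_axis(axis, ndim):
    """Map a possibly negative axis into [0, ndim) (hypothetical helper)."""
    if axis < 0:
        axis += ndim
    if not 0 <= axis < ndim:
        raise ValueError(f"axis out of range for a rank-{ndim} tensor")
    return axis


print(normalize_axis(-1, 3))  # 2
print(normalize_axis(1, 3))   # 1
```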




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org