You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/20 19:38:27 UTC

[GitHub] [tvm] mbrookhart commented on a change in pull request #9039: [ONNX][Relay] Add dynamic unsqueeze / expand_dims op

mbrookhart commented on a change in pull request #9039:
URL: https://github.com/apache/tvm/pull/9039#discussion_r712457695



##########
File path: tests/python/frontend/onnx/test_forward.py
##########
@@ -5015,16 +5015,13 @@ def verify_eyelike(indata):
     "test_training_dropout_mask",
     "test_training_dropout_zero_ratio",
     "test_training_dropout_zero_ratio_mask",
-    "test_unique_sorted_with_axis",
-    "test_unique_sorted_with_axis_3d",
-    "test_unique_sorted_with_negative_axis",
-    "test_unsqueeze_axis_0",
-    "test_unsqueeze_axis_1",
-    "test_unsqueeze_axis_2",
-    "test_unsqueeze_negative_axes",
+    # These unsqueeze tests work, but take 2+ hrs to run
     "test_unsqueeze_three_axes",
     "test_unsqueeze_two_axes",
     "test_unsqueeze_unsorted_axes",

Review comment:
       Can you add these to the device-specific skips below?

##########
File path: python/tvm/relay/op/transform.py
##########
@@ -110,7 +110,17 @@ def expand_dims(data, axis, num_newaxis=1):
     result : relay.Expr
         The reshaped result.
     """
-    return _make.expand_dims(data, axis, num_newaxis)
+    if isinstance(axis, int):
+        return _make.expand_dims(data, axis, num_newaxis)
+    if isinstance(axis, Expr):
+        # TODO (AndrewZhaoLuo): investigate performance issues with consecutive
+        # dynamic expand_dims on non-llvm targets.
+        for _ in range(num_newaxis):
+            # Dynamic rank is not well supported so we can only increase rank
+            # by a static amount (e.g. 1) so we have to do this
+            data = _dyn_make.expand_dims(data, axis)
+        return data
+    raise ValueError(f"Unknown type for axis: {type(axis)}")

Review comment:
       Why can't we do it all at once if we know that num_newaxis is static?

##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -1462,6 +1462,26 @@ def _impl_v1(cls, inputs, attr, params):
             inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1)
         return inputs[0]
 
+    @classmethod
+    def _impl_v12(cls, inputs, attr, params):
+        rank_input = len(infer_type(inputs[0]).checked_type.shape)
+        num_new_axis = int(infer_type(inputs[1]).checked_type.shape[0])
+        axes = relay.split(inputs[1], num_new_axis).astuple()
+
+        result = inputs[0]
+
+        # TODO (AndrewZhaoLuo): investigate performance issues with consecutive
+        # dynamic expand_dims on non-llvm targets.
+        for i in range(num_new_axis):
+            axis = relay.TupleGetItem(axes, i)
+            # Unpack scalar
+            axis = relay.reshape(axis, [])
+            axis = relay.If(
+                axis >= relay.const(0, "int64"), axis, axis + relay.const(rank_input, "int64")
+            )
+            result = _op.expand_dims(result, axis)
+        return result
+

Review comment:
       Again, shouldn't this be doable in a single call?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org