Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/12/13 00:38:40 UTC

[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #13602: [Relay][Frontend][Onnx] SequenceAt and SplitToSequence Operators

AndrewZhaoLuo commented on code in PR #13602:
URL: https://github.com/apache/tvm/pull/13602#discussion_r1046548968


##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -5565,6 +5581,66 @@ def _impl_v11(cls, inputs, attr, params):
         return _op.concatenate(inputs[0], axis=axis)
 
 
+class SplitToSequence(OnnxOpConverter):
+    """Operator converter for split to sequence op."""
+
+    @classmethod
+    def _impl_v11(cls, inputs, attr, params):
+        axis = attr.get("axis", 0)
+        keepdims = attr.get("keepdims", 1)
+
+        input_tensor = inputs[0]
+        input_shape = infer_shape(input_tensor)
+        split = inputs[1]
+
+        # If split is not provided, we split all values along axis.
+        if split is None:
+            output = _op.split(input_tensor, input_shape[axis], axis=axis)
+            # If keepdims is 0, then we need to squeeze off the axis.
+            if keepdims == 0:
+                output = [_op.squeeze(tensor_slice, axis=[axis]) for tensor_slice in output]
+            return _expr.Tuple(list(output))
+
+        # Otherwise, split based on provided split value.
+        else:
+            # For now we only support constant valued split.
+            assert isinstance(
+                split, _expr.Constant
+            ), "Only constant split supported for SplitToSequence"
+            split = split.data.numpy()
+            if len(split.shape) == 1 and split.shape[0] > 1:
+                # If split is a 1D tensor, it must be converted to indices for relay compatibility.
+                split = np.cumsum(split)
+                # Remove final invalid index.
+                split = split[:-1]
+            else:
+                # Otherwise get split as an integer.
+                split = int(split)
+
+            output = _op.split(input_tensor, split, axis=axis)
+
+            # If keepdims is set to 0, remove the split axis. Note that this is
+            # inconsistent with the ONNX spec but is needed for PyTorch compatibility.
+            if keepdims == 0:

Review Comment:
   nit `not keepdims` throughout
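
Applied to the hunk above, the nit would read roughly as follows (a sketch only, not the merged patch; names come from the converter shown above):

    # If keepdims is unset, squeeze the split axis off each slice.
    if not keepdims:
        output = [_op.squeeze(tensor_slice, axis=[axis]) for tensor_slice in output]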


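For context on the size handling in the hunk: ONNX's `split` input carries chunk sizes, while Relay's split takes the cut indices, hence the cumulative sum with the last entry dropped. A minimal NumPy sketch of that conversion (the values are illustrative only):

    import numpy as np

    sizes = np.array([2, 3, 4])        # ONNX-style chunk sizes along the split axis
    indices = np.cumsum(sizes)[:-1]    # [2, 5]: Relay-style cut points, total length dropped
    chunks = np.split(np.arange(9), indices)
    assert [len(c) for c in chunks] == [2, 3, 4]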

##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -4008,6 +4008,22 @@ def _impl_v1(cls, inputs, attr, params):
         for var in else_free_vars:
             graph_scope._nodes.update({var.name_hint: var})
 
+        # Sometimes the PyTorch-to-ONNX export will insert silly If statements that
+        # produce dynamic ranks. Often these don't contribute anything. If we see a
+        # dynamic-rank output, try to unify the branches so we can continue without breaking.
+        then_shape = infer_shape(then_expr)

Review Comment:
   Very sus, but I guess this is the best we can do in this situation. The only reason I will allow this is that it's in converting `If`, which is inherently going to be a bit more sus due to Relay limitations.
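
The hunk is truncated here, so the following is only a hypothetical illustration of the kind of rank unification the comment is describing, not the code under review: if the two branch shapes differ only by size-1 axes, the extra axes can be squeezed away so that both branches report the same rank.

    # Hypothetical helper: which axes to squeeze so that `shape` drops to `target_rank`.
    # Only size-1 axes are eligible; anything else means the branches genuinely disagree.
    def squeezable_axes(shape, target_rank):
        axes = [i for i, dim in enumerate(shape) if dim == 1]
        extra = len(shape) - target_rank
        if extra < 0 or len(axes) < extra:
            raise ValueError("branch shapes cannot be unified by squeezing")
        return axes[:extra]

    # Example: a then-branch of shape (1, 3, 4) against an else-branch of rank 2
    # would squeeze axis 0, leaving both branches at shape (3, 4).
    assert squeezable_axes((1, 3, 4), target_rank=2) == [0]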



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org