Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/05/18 16:40:35 UTC

[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11242: [frontend][onnx]fix matmul broadcast

AndrewZhaoLuo commented on code in PR #11242:
URL: https://github.com/apache/tvm/pull/11242#discussion_r876111689


##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -249,23 +249,39 @@ def flatten_to_nd(x, x_shape, nd=3):
             return out
 
         # Determine the output batch dimension.
+        new_a_shape = a_shape

Review Comment:
   I would prefer a less imperative style here, e.g. have each if/else branch assign both `new_a_shape` and `new_b_shape`, so the branches aren't mutating state that was initialized up here.
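   A minimal sketch of the suggested style follows. It assumes plain Python tuples in place of the Relay shape tensors the frontend actually builds (so it runs without TVM), and the helper name `align_ranks` is hypothetical. Each branch produces both shapes together, so nothing is pre-assigned and later overwritten:

```python
from typing import Tuple

Shape = Tuple[int, ...]


def align_ranks(a_shape: Shape, b_shape: Shape) -> Tuple[Shape, Shape]:
    """Left-pad the lower-rank shape with 1s so both shapes share one rank.

    Hypothetical helper: plain tuples stand in for the Relay shape tensors
    that the ONNX frontend manipulates with _op.concatenate.
    """
    a_rank, b_rank = len(a_shape), len(b_shape)
    if a_rank > b_rank:
        # This branch yields both shapes at once; no shared state is mutated.
        return a_shape, (1,) * (a_rank - b_rank) + b_shape
    if a_rank < b_rank:
        return (1,) * (b_rank - a_rank) + a_shape, b_shape
    # Equal ranks: nothing to pad.
    return a_shape, b_shape


print(align_ranks((2, 3, 4, 5), (5, 6)))  # ((2, 3, 4, 5), (1, 1, 5, 6))
```

   Returning a pair from every path keeps the two shapes consistent by construction, which is the property the comment is asking for.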



##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -249,23 +249,39 @@ def flatten_to_nd(x, x_shape, nd=3):
             return out
 
         # Determine the output batch dimension.
+        new_a_shape = a_shape
+        new_b_shape = b_shape
         if a_rank > b_rank:
-            out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
+            rank_diff = a_rank - b_rank
+            new_b_shape = _op.concatenate(
+                [
+                    _expr.const([1] * rank_diff, dtype=infer_type(b_shape).checked_type.dtype),
+                    b_shape,
+                ],
+                0,
+            )
         elif a_rank < b_rank:
-            out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
-        # If its unclear how broadcasting should be applied, the output
-        # shape is determined by choosing the maximum value from each input.
-        else:
-            out_batch = _op.concatenate(
+            rank_diff = b_rank - a_rank
+            new_a_shape = _op.concatenate(
                 [
-                    _op.maximum(
-                        _op.strided_slice(a_shape, [i], [i + 1]),
-                        _op.strided_slice(b_shape, [i], [i + 1]),
-                    )
-                    for i in range(a_rank - 2)
+                    _expr.const([1] * rank_diff, dtype=infer_type(a_shape).checked_type.dtype),
+                    a_shape,
                 ],
                 0,
             )
+        else:

Review Comment:
   You can just remove the `else` branch, I believe.
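   A sketch of why the equal-rank case needs no branch of its own: once the lower-rank shape is left-padded with 1s, the batch dimensions can be broadcast uniformly whatever the original ranks were. The helper `matmul_output_shape` is hypothetical, uses plain tuples instead of Relay tensors, assumes both inputs have rank >= 2, and follows NumPy-style matmul broadcasting (which ONNX MatMul is specified to follow):

```python
from typing import Tuple

Shape = Tuple[int, ...]


def matmul_output_shape(a_shape: Shape, b_shape: Shape) -> Shape:
    """Output shape of an N-D matmul with NumPy-style batch broadcasting.

    Hypothetical helper for illustration; assumes both ranks are >= 2.
    """
    # Left-pad the lower-rank shape with 1s. When ranks are already equal the
    # padding is empty, so no separate ``else`` branch is required.
    rank = max(len(a_shape), len(b_shape))
    a = (1,) * (rank - len(a_shape)) + a_shape
    b = (1,) * (rank - len(b_shape)) + b_shape

    batch = []
    for da, db in zip(a[:-2], b[:-2]):
        if da != db and 1 not in (da, db):
            raise ValueError(f"batch dims {da} and {db} are not broadcastable")
        # Compatible dims: broadcasting keeps the non-1 extent.
        batch.append(db if da == 1 else da)

    if a[-1] != b[-2]:
        raise ValueError(f"inner dims {a[-1]} and {b[-2]} do not match")
    return tuple(batch) + (a[-2], b[-1])


print(matmul_output_shape((2, 1, 3, 4), (5, 4, 6)))  # (2, 5, 3, 6)
```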



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
