Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/05/09 09:27:47 UTC

[GitHub] [tvm] ah-cheng opened a new pull request, #11242: [frontend][onnx]fix matmul broadcast

ah-cheng opened a new pull request, #11242:
URL: https://github.com/apache/tvm/pull/11242

   While importing an ONNX model, I ran into a bug with MatMul.
   The MatMul input shapes are a: [1,1,31,64,16] and b: [3,31,64,16],
   so the broadcast shape should be [1,3,31,64,16]:
   both inputs need to be broadcast.
   So far our code only handles the case where one input is broadcast,
   so this PR fixes that. Thanks in advance for the reviews.
   CC: @AndrewZhaoLuo 
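
   For context, the expected shape can be checked against NumPy's
   broadcasting rules (a standalone sketch, not part of the PR's code):

   ```python
   import numpy as np

   # Broadcasting the two MatMul input shapes from the report:
   # a: [1, 1, 31, 64, 16] and b: [3, 31, 64, 16].
   # NumPy aligns shapes from the right and left-pads the shorter one
   # with 1s, so BOTH inputs participate in the broadcast.
   out_shape = np.broadcast_shapes((1, 1, 31, 64, 16), (3, 31, 64, 16))
   print(out_shape)  # (1, 3, 31, 64, 16)
   ```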


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] ah-cheng commented on pull request #11242: [frontend][onnx]fix matmul broadcast

Posted by GitBox <gi...@apache.org>.
ah-cheng commented on PR #11242:
URL: https://github.com/apache/tvm/pull/11242#issuecomment-1128471202

   Could you take a look at this PR? Thanks.
   @AndrewZhaoLuo 




[GitHub] [tvm] AndrewZhaoLuo merged pull request #11242: [frontend][onnx]fix matmul broadcast

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo merged PR #11242:
URL: https://github.com/apache/tvm/pull/11242




[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11242: [frontend][onnx]fix matmul broadcast

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #11242:
URL: https://github.com/apache/tvm/pull/11242#discussion_r876141924


##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -249,23 +249,39 @@ def flatten_to_nd(x, x_shape, nd=3):
             return out
 
         # Determine the output batch dimension.
+        new_a_shape = a_shape
+        new_b_shape = b_shape
         if a_rank > b_rank:
-            out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
+            rank_diff = a_rank - b_rank
+            new_b_shape = _op.concatenate(
+                [
+                    _expr.const([1] * rank_diff, dtype=infer_type(b_shape).checked_type.dtype),
+                    b_shape,
+                ],
+                0,
+            )
         elif a_rank < b_rank:
-            out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
-        # If its unclear how broadcasting should be applied, the output
-        # shape is determined by choosing the maximum value from each input.
-        else:
-            out_batch = _op.concatenate(
+            rank_diff = b_rank - a_rank
+            new_a_shape = _op.concatenate(
                 [
-                    _op.maximum(
-                        _op.strided_slice(a_shape, [i], [i + 1]),
-                        _op.strided_slice(b_shape, [i], [i + 1]),
-                    )
-                    for i in range(a_rank - 2)
+                    _expr.const([1] * rank_diff, dtype=infer_type(a_shape).checked_type.dtype),
+                    a_shape,
                 ],
                 0,
             )
+        else:

Review Comment:
   Oh nevermind, this is probably because of the linter check: https://github.com/apache/tvm/pull/11327





[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11242: [frontend][onnx]fix matmul broadcast

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #11242:
URL: https://github.com/apache/tvm/pull/11242#discussion_r876111689


##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -249,23 +249,39 @@ def flatten_to_nd(x, x_shape, nd=3):
             return out
 
         # Determine the output batch dimension.
+        new_a_shape = a_shape

Review Comment:
   I would prefer a less imperative style here; e.g. make it so each if/else branch assigns `new_a_shape` and `new_b_shape` so the branches aren't mutating state up here
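
   A hypothetical sketch of the style being suggested, using plain Python
   tuples in place of the relay shape expressions from onnx.py (the
   `pad_shapes` name is made up for illustration): each branch returns the
   new shapes instead of mutating variables initialized earlier.

   ```python
   def pad_shapes(a_shape, b_shape):
       """Return (new_a_shape, new_b_shape) with the lower-rank shape
       left-padded with 1s so both ranks match. Hypothetical helper;
       tuples stand in for the relay shape expressions."""
       rank_diff = len(a_shape) - len(b_shape)
       if rank_diff > 0:
           # b has lower rank: pad it on the left.
           return a_shape, (1,) * rank_diff + b_shape
       if rank_diff < 0:
           # a has lower rank: pad it on the left.
           return (1,) * -rank_diff + a_shape, b_shape
       # Equal ranks: nothing to pad.
       return a_shape, b_shape
   ```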



##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -249,23 +249,39 @@ def flatten_to_nd(x, x_shape, nd=3):
             return out
 
         # Determine the output batch dimension.
+        new_a_shape = a_shape
+        new_b_shape = b_shape
         if a_rank > b_rank:
-            out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
+            rank_diff = a_rank - b_rank
+            new_b_shape = _op.concatenate(
+                [
+                    _expr.const([1] * rank_diff, dtype=infer_type(b_shape).checked_type.dtype),
+                    b_shape,
+                ],
+                0,
+            )
         elif a_rank < b_rank:
-            out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
-        # If its unclear how broadcasting should be applied, the output
-        # shape is determined by choosing the maximum value from each input.
-        else:
-            out_batch = _op.concatenate(
+            rank_diff = b_rank - a_rank
+            new_a_shape = _op.concatenate(
                 [
-                    _op.maximum(
-                        _op.strided_slice(a_shape, [i], [i + 1]),
-                        _op.strided_slice(b_shape, [i], [i + 1]),
-                    )
-                    for i in range(a_rank - 2)
+                    _expr.const([1] * rank_diff, dtype=infer_type(a_shape).checked_type.dtype),
+                    a_shape,
                 ],
                 0,
             )
+        else:

Review Comment:
   You can just remove the else branch, I believe.
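
   A minimal sketch of why the else branch is redundant (plain tuples, not
   the actual relay code; `pad_to_rank` is a hypothetical helper): when the
   ranks already match, left-padding with zero 1s is a no-op, so the
   initial `new_a_shape` / `new_b_shape` assignments are already correct.

   ```python
   def pad_to_rank(shape, rank):
       # Left-pad `shape` with 1s up to `rank`.
       return (1,) * (rank - len(shape)) + tuple(shape)

   print(pad_to_rank((3, 31, 64, 16), 5))  # (1, 3, 31, 64, 16)
   print(pad_to_rank((3, 31, 64, 16), 4))  # (3, 31, 64, 16) -- unchanged when ranks match
   ```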


