You are viewing a plain text version of this content; the canonical link to the original message is available in the HTML version of this archive page.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/01/15 21:13:35 UTC

[tvm] branch main updated: [Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K] (#9911)

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6eb4ed8  [Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K] (#9911)
6eb4ed8 is described below

commit 6eb4ed813ebcdcd9558f0906a1870db8302ff1e0
Author: Will Zhang <wi...@163.com>
AuthorDate: Sun Jan 16 05:13:15 2022 +0800

    [Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K] (#9911)
    
    * [Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K]
    
    * fix line
    
    Co-authored-by: tomoyazhang <to...@tencent.com>
---
 python/tvm/relay/frontend/onnx.py          | 54 ++++++++++++++++++++----------
 tests/python/frontend/onnx/test_forward.py |  1 +
 2 files changed, 37 insertions(+), 18 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 60319d6..234beec 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -238,24 +238,6 @@ def matmul_out_dtype(inputs, out_dtype):
             out = _op.reshape(x, fold_constant(newshape))
             return out
 
-        b_type = infer_type(inputs[1])
-        # Convert to dense if the second matrix is 2d and non-dynamic
-        if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
-            a = flatten_to_nd(inputs[0], a_shape, 2)
-            b = _op.transpose(inputs[1])
-            output = _op.nn.dense(a, b, out_dtype=out_dtype)
-        else:
-            # Convert a and b into 3 dimensional tensors.
-            a = flatten_to_nd(inputs[0], a_shape, 3)
-            b = flatten_to_nd(inputs[1], b_shape, 3)
-            if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
-                # Transpose matrix dimensions of b.
-                bt = _op.transpose(b, [0, 2, 1])
-                # Perform a NT batch matmul.
-                output = _op.nn.batch_matmul(a, bt, out_dtype=out_dtype)
-            else:
-                # Perform a NN batch matmul.
-                output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
         # Determine the output batch dimension.
         if a_rank > b_rank:
             out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
@@ -274,6 +256,42 @@ def matmul_out_dtype(inputs, out_dtype):
                 ],
                 0,
             )
+
+        b_type = infer_type(inputs[1])
+        # Convert to dense if the second matrix is 2d and non-dynamic
+        if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
+            a = flatten_to_nd(inputs[0], a_shape, 2)
+            b = _op.transpose(inputs[1])
+            output = _op.nn.dense(a, b, out_dtype=out_dtype)
+        else:
+            # broadcast a and b
+            a_broadcasted_shape = _op.concatenate(
+                [
+                    out_batch,
+                    _op.strided_slice(a_shape, [a_rank - 2], [a_rank]),
+                ],
+                0,
+            )
+            b_broadcasted_shape = _op.concatenate(
+                [
+                    out_batch,
+                    _op.strided_slice(b_shape, [b_rank - 2], [b_rank]),
+                ],
+                0,
+            )
+            a = _op.transform.broadcast_to(inputs[0], fold_constant(a_broadcasted_shape))
+            b = _op.transform.broadcast_to(inputs[1], fold_constant(b_broadcasted_shape))
+            # Convert a and b into 3 dimensional tensors.
+            a = flatten_to_nd(a, shape_of(a), 3)
+            b = flatten_to_nd(b, shape_of(b), 3)
+            if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
+                # Transpose matrix dimensions of b.
+                bt = _op.transpose(b, [0, 2, 1])
+                # Perform a NT batch matmul.
+                output = _op.nn.batch_matmul(a, bt, out_dtype=out_dtype)
+            else:
+                # Perform a NN batch matmul.
+                output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
         # Reshape output to original dimensions.
         final_shape = _op.concatenate(
             [
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 2bdc5f7..2e0d927 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1273,6 +1273,7 @@ def test_batch_matmul(target, dev):
     verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4))
     verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32))
     verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16))
+    verify_batch_matmul((4, 32, 16, 32), (1, 32, 32, 16), (4, 32, 16, 16))
     # Test transb=False
     verify_batch_matmul(
         (2, 3, 4, 3),