You are viewing a plain-text version of this content. The canonical version is linked from the original archive page; that hyperlink was lost in this plain-text conversion.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/01/15 21:13:35 UTC
[tvm] branch main updated: [Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K] (#9911)
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6eb4ed8 [Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K] (#9911)
6eb4ed8 is described below
commit 6eb4ed813ebcdcd9558f0906a1870db8302ff1e0
Author: Will Zhang <wi...@163.com>
AuthorDate: Sun Jan 16 05:13:15 2022 +0800
[Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K] (#9911)
* [Fix] relay onnx frontend bug when [A, B, M, N] * [1, B, N, K]
* fix line
Co-authored-by: tomoyazhang <to...@tencent.com>
---
python/tvm/relay/frontend/onnx.py | 54 ++++++++++++++++++++----------
tests/python/frontend/onnx/test_forward.py | 1 +
2 files changed, 37 insertions(+), 18 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 60319d6..234beec 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -238,24 +238,6 @@ def matmul_out_dtype(inputs, out_dtype):
out = _op.reshape(x, fold_constant(newshape))
return out
- b_type = infer_type(inputs[1])
- # Convert to dense if the second matrix is 2d and non-dynamic
- if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
- a = flatten_to_nd(inputs[0], a_shape, 2)
- b = _op.transpose(inputs[1])
- output = _op.nn.dense(a, b, out_dtype=out_dtype)
- else:
- # Convert a and b into 3 dimensional tensors.
- a = flatten_to_nd(inputs[0], a_shape, 3)
- b = flatten_to_nd(inputs[1], b_shape, 3)
- if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
- # Transpose matrix dimensions of b.
- bt = _op.transpose(b, [0, 2, 1])
- # Perform a NT batch matmul.
- output = _op.nn.batch_matmul(a, bt, out_dtype=out_dtype)
- else:
- # Perform a NN batch matmul.
- output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
# Determine the output batch dimension.
if a_rank > b_rank:
out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
@@ -274,6 +256,42 @@ def matmul_out_dtype(inputs, out_dtype):
],
0,
)
+
+ b_type = infer_type(inputs[1])
+ # Convert to dense if the second matrix is 2d and non-dynamic
+ if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
+ a = flatten_to_nd(inputs[0], a_shape, 2)
+ b = _op.transpose(inputs[1])
+ output = _op.nn.dense(a, b, out_dtype=out_dtype)
+ else:
+ # broadcast a and b
+ a_broadcasted_shape = _op.concatenate(
+ [
+ out_batch,
+ _op.strided_slice(a_shape, [a_rank - 2], [a_rank]),
+ ],
+ 0,
+ )
+ b_broadcasted_shape = _op.concatenate(
+ [
+ out_batch,
+ _op.strided_slice(b_shape, [b_rank - 2], [b_rank]),
+ ],
+ 0,
+ )
+ a = _op.transform.broadcast_to(inputs[0], fold_constant(a_broadcasted_shape))
+ b = _op.transform.broadcast_to(inputs[1], fold_constant(b_broadcasted_shape))
+ # Convert a and b into 3 dimensional tensors.
+ a = flatten_to_nd(a, shape_of(a), 3)
+ b = flatten_to_nd(b, shape_of(b), 3)
+ if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
+ # Transpose matrix dimensions of b.
+ bt = _op.transpose(b, [0, 2, 1])
+ # Perform a NT batch matmul.
+ output = _op.nn.batch_matmul(a, bt, out_dtype=out_dtype)
+ else:
+ # Perform a NN batch matmul.
+ output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
# Reshape output to original dimensions.
final_shape = _op.concatenate(
[
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 2bdc5f7..2e0d927 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1273,6 +1273,7 @@ def test_batch_matmul(target, dev):
verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4))
verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32))
verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16))
+ verify_batch_matmul((4, 32, 16, 32), (1, 32, 32, 16), (4, 32, 16, 16))
# Test transb=False
verify_batch_matmul(
(2, 3, 4, 3),