You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/06/16 21:05:50 UTC
[tvm] branch main updated: [Frontend, Tensorflow] Support for broadcasting in batch_matmul when shapes differ (#8251)
This is an automated email from the ASF dual-hosted git repository.
mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ec6a817 [Frontend, Tensorflow] Support for broadcasting in batch_matmul when shapes differ (#8251)
ec6a817 is described below
commit ec6a817eaed246ffcf925f295b587cfc0af15035
Author: Rohan Mukherjee <mu...@amazon.com>
AuthorDate: Wed Jun 16 14:05:33 2021 -0700
[Frontend, Tensorflow] Support for broadcasting in batch_matmul when shapes differ (#8251)
* Support for broadcasting in batch_matmul when shapes differ
* refactor
* refactor logic for reshape in conditional
* refactor
---
python/tvm/relay/frontend/tensorflow_ops.py | 16 +++++++++-------
tests/python/frontend/tensorflow/test_forward.py | 17 +++++++++++++++++
2 files changed, 26 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index c738556..3c4a9b6 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -1132,22 +1132,23 @@ def _batch_matmul():
orig_shape_x = _infer_shape(input_x, mod)
orig_shape_y = _infer_shape(input_y, mod)
ndim = len(orig_shape_x)
+ ndim_y = len(orig_shape_y)
is_static = not check_symbolic_shape(orig_shape_x)
- if ndim > 3 and not is_static:
- shape_of_x = list_shape_of(inputs[0], ndim)
- shape_of_y = list_shape_of(inputs[1], ndim)
-
# reshape n-dimensional batch matmul into 3d
if ndim > 3:
outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)]
if is_static:
num_outer_elts = np.prod(outer_dims)
new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1])
- new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
+ if ndim_y > 2:
+ new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
+ elif ndim_y == 2:
+ new_shape_y = (1, orig_shape_y[-2], orig_shape_y[-1])
else: # handle dynamic shape (dyn.reshape op)
- # new shape = [prod(shape[:-2]), -2, -1]
+ shape_of_x = list_shape_of(inputs[0], ndim)
+ shape_of_y = list_shape_of(inputs[1], ndim)
new_shape_x = [_op.const(1), shape_of_x[-2], shape_of_x[-1]]
new_shape_y = [_op.const(1), shape_of_y[-2], shape_of_y[-1]]
for i in range(ndim - 2):
@@ -1158,7 +1159,8 @@ def _batch_matmul():
input_x = _op.reshape(input_x, newshape=new_shape_x)
input_y = _op.reshape(input_y, newshape=new_shape_y)
-
+ elif ndim_y == 2:
+ input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1]))
adj_x = attr["adj_x"]
adj_y = attr["adj_y"]
input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 3315533..57497d0 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1843,6 +1843,9 @@ def test_forward_batch_matmul():
_test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), "float32", True, True)
_test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), "int32", True, False)
_test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True)
+ _test_batch_matmul((1, 8, 64, 2), (2, 1), "float32", False, False)
+ _test_batch_matmul((1, 8, 8, 64), (64, 1), "float32", False, False)
+ _test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False)
@tvm.testing.requires_cuda
@@ -1870,6 +1873,20 @@ def test_forward_batch_matmul_dynamic():
(2, 3, 4, 6, 5),
"float32",
)
+ _test_batch_matmul_dynamic(
+ (None, None, None, 5, 6),
+ (6, None),
+ (2, 3, 4, 5, 6),
+ (6, 1),
+ "float32",
+ )
+ _test_batch_matmul_dynamic(
+ (None, 5, 6),
+ (6, None),
+ (24, 5, 6),
+ (6, 1),
+ "float32",
+ )
#######################################################################