You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this text.)
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/06/16 21:05:50 UTC

[tvm] branch main updated: [Frontend, Tensorflow] Support for broadcasting in batch_matmul when shapes differ (#8251)

This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ec6a817  [Frontend, Tensorflow] Support for broadcasting in batch_matmul when shapes differ (#8251)
ec6a817 is described below

commit ec6a817eaed246ffcf925f295b587cfc0af15035
Author: Rohan Mukherjee <mu...@amazon.com>
AuthorDate: Wed Jun 16 14:05:33 2021 -0700

    [Frontend, Tensorflow] Support for broadcasting in batch_matmul when shapes differ (#8251)
    
    * Support for broadcasting in batch_matmul when shapes differ
    
    * refactor
    
    * refactor logic for reshape in conditional
    
    * refactor
---
 python/tvm/relay/frontend/tensorflow_ops.py      | 16 +++++++++-------
 tests/python/frontend/tensorflow/test_forward.py | 17 +++++++++++++++++
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index c738556..3c4a9b6 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -1132,22 +1132,23 @@ def _batch_matmul():
         orig_shape_x = _infer_shape(input_x, mod)
         orig_shape_y = _infer_shape(input_y, mod)
         ndim = len(orig_shape_x)
+        ndim_y = len(orig_shape_y)
 
         is_static = not check_symbolic_shape(orig_shape_x)
 
-        if ndim > 3 and not is_static:
-            shape_of_x = list_shape_of(inputs[0], ndim)
-            shape_of_y = list_shape_of(inputs[1], ndim)
-
         # reshape n-dimensional batch matmul into 3d
         if ndim > 3:
             outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)]
             if is_static:
                 num_outer_elts = np.prod(outer_dims)
                 new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1])
-                new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
+                if ndim_y > 2:
+                    new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
+                elif ndim_y == 2:
+                    new_shape_y = (1, orig_shape_y[-2], orig_shape_y[-1])
             else:  # handle dynamic shape (dyn.reshape op)
-                # new shape = [prod(shape[:-2]), -2, -1]
+                shape_of_x = list_shape_of(inputs[0], ndim)
+                shape_of_y = list_shape_of(inputs[1], ndim)
                 new_shape_x = [_op.const(1), shape_of_x[-2], shape_of_x[-1]]
                 new_shape_y = [_op.const(1), shape_of_y[-2], shape_of_y[-1]]
                 for i in range(ndim - 2):
@@ -1158,7 +1159,8 @@ def _batch_matmul():
 
             input_x = _op.reshape(input_x, newshape=new_shape_x)
             input_y = _op.reshape(input_y, newshape=new_shape_y)
-
+        elif ndim_y == 2:
+            input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1]))
         adj_x = attr["adj_x"]
         adj_y = attr["adj_y"]
         input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 3315533..57497d0 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1843,6 +1843,9 @@ def test_forward_batch_matmul():
     _test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), "float32", True, True)
     _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), "int32", True, False)
     _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True)
+    _test_batch_matmul((1, 8, 64, 2), (2, 1), "float32", False, False)
+    _test_batch_matmul((1, 8, 8, 64), (64, 1), "float32", False, False)
+    _test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False)
 
 
 @tvm.testing.requires_cuda
@@ -1870,6 +1873,20 @@ def test_forward_batch_matmul_dynamic():
         (2, 3, 4, 6, 5),
         "float32",
     )
+    _test_batch_matmul_dynamic(
+        (None, None, None, 5, 6),
+        (6, None),
+        (2, 3, 4, 5, 6),
+        (6, 1),
+        "float32",
+    )
+    _test_batch_matmul_dynamic(
+        (None, 5, 6),
+        (6, None),
+        (24, 5, 6),
+        (6, 1),
+        "float32",
+    )
 
 
 #######################################################################