Posted to commits@tvm.apache.org by co...@apache.org on 2021/04/24 17:11:57 UTC

[tvm] branch main updated: [Frontend][Tensorflow] SelectV2 and BroadcastArgs op support for tf2 models (#7901)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fad10d7  [Frontend][Tensorflow] SelectV2 and BroadcastArgs op support for tf2 models (#7901)
fad10d7 is described below

commit fad10d7914d3807e56501d7cc239d505da175e07
Author: srinidhigoud <sr...@gmail.com>
AuthorDate: Sat Apr 24 10:11:31 2021 -0700

    [Frontend][Tensorflow] SelectV2 and BroadcastArgs op support for tf2 models (#7901)
---
 python/tvm/relay/frontend/tensorflow.py          | 43 ++++++++++++++++++++++--
 tests/python/frontend/tensorflow/test_forward.py | 29 +++++++++++++++-
 2 files changed, 68 insertions(+), 4 deletions(-)
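
For context (not part of the commit): BroadcastArgs takes two shape vectors and returns their broadcast shape, and SelectV2 is an elementwise select that broadcasts its condition and both value operands against each other, which is why it can reuse the existing Select/_where() converter. A quick NumPy illustration of both semantics (assumes numpy >= 1.20 for np.broadcast_shapes):

import numpy as np

# BroadcastArgs((4, 1, 32, 32), (4, 8, 32, 32)) -> (4, 8, 32, 32)
print(np.broadcast_shapes((4, 1, 32, 32), (4, 8, 32, 32)))

# SelectV2(condition, t, e) behaves like np.where with broadcasting:
cond = np.array([True, False])            # shape (2,)
t = np.ones((3, 2), dtype="float32")      # shape (3, 2)
e = np.zeros((3, 2), dtype="float32")     # shape (3, 2)
print(np.where(cond, t, e))               # cond is broadcast over the rows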

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 997f68b..f566a3f 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -19,6 +19,7 @@
 """TF: Tensorflow frontend."""
 import warnings
 from collections import defaultdict
+from collections import deque
 
 # Numpy support
 import numpy as np
@@ -1770,6 +1771,43 @@ def _bias_add():
     return _impl
 
 
+def _broadcast_args():
+    def _impl(inputs, attr, params, mod):
+        if isinstance(inputs[0], _expr.Var):
+            s0 = params[inputs[0].name_hint]
+        else:
+            s0 = _infer_value(inputs[0], params, mod)
+        if isinstance(inputs[1], _expr.Var):
+            s1 = params[inputs[1].name_hint]
+        else:
+            s1 = _infer_value(inputs[1], params, mod)
+        s0 = list(s0.asnumpy().reshape([-1]))
+        s1 = list(s1.asnumpy().reshape([-1]))
+        s0_size, s1_size = len(s0), len(s1)
+
+        out = deque([])
+        for i in range(1, min(s0_size, s1_size) + 1):
+            if s0[s0_size - i] == s1[s1_size - i]:
+                out.appendleft(s0[s0_size - i])
+            elif s0[s0_size - i] == 1:
+                out.appendleft(s1[s1_size - i])
+            else:
+                assert s1[s1_size - i] == 1, "Incompatible broadcast type %s and %s" % (
+                    s0[s0_size - i],
+                    s1[s1_size - i],
+                )
+                out.appendleft(s0[s0_size - i])
+        if s0_size < s1_size:
+            for i in range(s0_size + 1, s1_size + 1):
+                out.appendleft(s1[s1_size - i])
+        if s1_size < s0_size:
+            for i in range(s1_size + 1, s0_size + 1):
+                out.appendleft(s0[s0_size - i])
+        return _expr.const(list(out), attr["T"].name)
+
+    return _impl
+
+
 def _broadcast_to():
     def _impl(inputs, attr, params, mod):
         if isinstance(inputs[1], _expr.Var):
@@ -2745,6 +2783,7 @@ _convert_map = {
     "BatchToSpaceND": _batch_to_space_nd(),
     "BiasAdd": _bias_add(),
     "BroadcastTo": _broadcast_to(),
+    "BroadcastArgs": _broadcast_args(),
     "Cast": _cast(),
     "Ceil": AttrCvt("ceil"),
     "CheckNumerics": _check_numerics(),
@@ -2838,6 +2877,7 @@ _convert_map = {
     "Round": AttrCvt("round"),
     "Rsqrt": _rsqrt(),
     "Select": _where(),
+    "SelectV2": _where(),
     "Selu": _selu(),
     "Shape": _shape(),
     "Sigmoid": AttrCvt("sigmoid"),
@@ -3941,7 +3981,6 @@ class GraphProto(object):
             raise ImportError("Unable to import tensorflow which is required {}".format(e))
 
         input_op_name = node_name.split(":")[0].split("^")[-1]
-
         if input_op_name not in self._nodes:
             node = self._tf_node_map[input_op_name]
             attr = self._parse_attr(node.attr)
@@ -4002,7 +4041,6 @@ class GraphProto(object):
                         inputs[i] = actual_input
 
                 op = self._convert_operator(node.op, node.name, inputs, attr)
-
             if isinstance(op, np.ndarray):
                 self._params[node.name] = tvm.nd.array(op)
                 op = [
@@ -4024,7 +4062,6 @@ class GraphProto(object):
             tn = node_name.split(":")
             tensor_slot = int(tn[1]) if len(tn) > 1 else 0
             return out[tensor_slot]
-
         return out[0]
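
The new _broadcast_args() converter above resolves both input shapes to constants (from params or via _infer_value) and folds the broadcast-shape computation at import time. A standalone sketch of the same right-aligned walk, not part of the commit and shown only to make the deque logic easy to follow:

from collections import deque

def broadcast_shape(s0, s1):
    """Return the broadcast shape of two shape vectors."""
    out = deque()
    n0, n1 = len(s0), len(s1)
    # Walk the trailing dimensions that both shapes have.
    for i in range(1, min(n0, n1) + 1):
        d0, d1 = s0[n0 - i], s1[n1 - i]
        if d0 == d1 or d1 == 1:
            out.appendleft(d0)
        elif d0 == 1:
            out.appendleft(d1)
        else:
            raise ValueError("Incompatible broadcast dims %s and %s" % (d0, d1))
    # Leading dimensions of the longer shape are copied through unchanged.
    longer, n = (s1, n1) if n0 < n1 else (s0, n0)
    for i in range(min(n0, n1) + 1, n + 1):
        out.appendleft(longer[n - i])
    return list(out)

assert broadcast_shape([4, 1, 32, 32], [4, 8, 32, 32]) == [4, 8, 32, 32]
assert broadcast_shape([32, 32, 16], [6, 32, 32, 16]) == [6, 32, 32, 16]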
 
 
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 80a70e4..f4e7522 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -3079,6 +3079,33 @@ def test_forward_resize():
 
 
 #######################################################################
+# BroadcastArgs
+# -----------
+
+
+def _test_broadcast_args(in_shape_1, in_shape_2):
+    """ One iteration of broadcast_args"""
+
+    shape_1 = np.array(in_shape_1).astype("int32")
+    shape_2 = np.array(in_shape_2).astype("int32")
+
+    with tf.Graph().as_default():
+        shape_1 = constant_op.constant(shape_1, shape=shape_1.shape, dtype=shape_1.dtype)
+        shape_2 = constant_op.constant(shape_2, shape=shape_2.shape, dtype=shape_2.dtype)
+        tf.raw_ops.BroadcastArgs(s0=shape_1, s1=shape_2)
+
+        compare_tf_with_tvm(None, "", "BroadcastArgs:0", opt_level=0)
+
+
+def test_forward_broadcast_args():
+    """ Resize Bilinear """
+
+    _test_broadcast_args((4, 1, 32, 32), [4, 8, 32, 32])
+    _test_broadcast_args((6, 32, 32, 1), [6, 32, 32, 16])
+    _test_broadcast_args((32, 32, 16), [6, 32, 32, 16])
+
+
+#######################################################################
 # BroadcastTo
 # -----------
 
@@ -3621,7 +3648,7 @@ def test_forward_logical():
 
 
 #######################################################################
-# Where, Select
+# Where, Select, SelectV2
 # -------------
 def test_forward_where():
     """ Where: return elements depending on conditions"""